diff --git a/Python/bytecodes.c b/Python/bytecodes.c index 69ee741d5df..1575b5390fb 100644 --- a/Python/bytecodes.c +++ b/Python/bytecodes.c @@ -76,13 +76,9 @@ do { \ #define NAME_ERROR_MSG \ "name '%.200s' is not defined" -typedef struct { - PyObject *kwnames; -} CallShape; - // Dummy variables for stack effects. static PyObject *value, *value1, *value2, *left, *right, *res, *sum, *prod, *sub; -static PyObject *container, *start, *stop, *v; +static PyObject *container, *start, *stop, *v, *lhs, *rhs; static PyObject * dummy_func( @@ -101,6 +97,8 @@ dummy_func( binaryfunc binary_ops[] ) { + _PyInterpreterFrame entry_frame; + switch (opcode) { // BEGIN BYTECODES // @@ -193,7 +191,21 @@ dummy_func( ERROR_IF(res == NULL, error); } - inst(BINARY_OP_MULTIPLY_INT, (left, right -- prod)) { + family(binary_op, INLINE_CACHE_ENTRIES_BINARY_OP) = { + BINARY_OP, + BINARY_OP_ADD_FLOAT, + BINARY_OP_ADD_INT, + BINARY_OP_ADD_UNICODE, + BINARY_OP_GENERIC, + // BINARY_OP_INPLACE_ADD_UNICODE, // This is an odd duck. + BINARY_OP_MULTIPLY_FLOAT, + BINARY_OP_MULTIPLY_INT, + BINARY_OP_SUBTRACT_FLOAT, + BINARY_OP_SUBTRACT_INT, + }; + + + inst(BINARY_OP_MULTIPLY_INT, (left, right, unused/1 -- prod)) { assert(cframe.use_tracing == 0); DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP); DEOPT_IF(!PyLong_CheckExact(right), BINARY_OP); @@ -202,10 +214,9 @@ dummy_func( _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free); ERROR_IF(prod == NULL, error); - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); } - inst(BINARY_OP_MULTIPLY_FLOAT, (left, right -- prod)) { + inst(BINARY_OP_MULTIPLY_FLOAT, (left, right, unused/1 -- prod)) { assert(cframe.use_tracing == 0); DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP); DEOPT_IF(!PyFloat_CheckExact(right), BINARY_OP); @@ -216,10 +227,9 @@ dummy_func( _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc); ERROR_IF(prod == NULL, error); - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); } - inst(BINARY_OP_SUBTRACT_INT, (left, right -- sub)) { + inst(BINARY_OP_SUBTRACT_INT, (left, right, unused/1 -- sub)) { assert(cframe.use_tracing == 0); DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP); DEOPT_IF(!PyLong_CheckExact(right), BINARY_OP); @@ -228,10 +238,9 @@ dummy_func( _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free); ERROR_IF(sub == NULL, error); - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); } - inst(BINARY_OP_SUBTRACT_FLOAT, (left, right -- sub)) { + inst(BINARY_OP_SUBTRACT_FLOAT, (left, right, unused/1 -- sub)) { assert(cframe.use_tracing == 0); DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP); DEOPT_IF(!PyFloat_CheckExact(right), BINARY_OP); @@ -241,10 +250,9 @@ dummy_func( _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc); ERROR_IF(sub == NULL, error); - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); } - inst(BINARY_OP_ADD_UNICODE, (left, right -- res)) { + inst(BINARY_OP_ADD_UNICODE, (left, right, unused/1 -- res)) { assert(cframe.use_tracing == 0); DEOPT_IF(!PyUnicode_CheckExact(left), BINARY_OP); DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP); @@ -253,7 +261,6 @@ dummy_func( _Py_DECREF_SPECIALIZED(left, _PyUnicode_ExactDealloc); _Py_DECREF_SPECIALIZED(right, _PyUnicode_ExactDealloc); ERROR_IF(res == NULL, error); - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); } // This is a subtle one. It's a super-instruction for @@ -292,7 +299,7 @@ dummy_func( JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP + 1); } - inst(BINARY_OP_ADD_FLOAT, (left, right -- sum)) { + inst(BINARY_OP_ADD_FLOAT, (left, right, unused/1 -- sum)) { assert(cframe.use_tracing == 0); DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP); DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP); @@ -303,10 +310,9 @@ dummy_func( _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc); ERROR_IF(sum == NULL, error); - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); } - inst(BINARY_OP_ADD_INT, (left, right -- sum)) { + inst(BINARY_OP_ADD_INT, (left, right, unused/1 -- sum)) { assert(cframe.use_tracing == 0); DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP); DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP); @@ -315,7 +321,6 @@ dummy_func( _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free); ERROR_IF(sum == NULL, error); - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); } inst(BINARY_SUBSCR, (container, sub -- res)) { @@ -3691,30 +3696,21 @@ dummy_func( PUSH(Py_NewRef(peek)); } - // stack effect: (__0 -- ) - inst(BINARY_OP_GENERIC) { - PyObject *rhs = POP(); - PyObject *lhs = TOP(); + inst(BINARY_OP_GENERIC, (lhs, rhs, unused/1 -- res)) { assert(0 <= oparg); assert((unsigned)oparg < Py_ARRAY_LENGTH(binary_ops)); assert(binary_ops[oparg]); - PyObject *res = binary_ops[oparg](lhs, rhs); + res = binary_ops[oparg](lhs, rhs); Py_DECREF(lhs); Py_DECREF(rhs); - SET_TOP(res); - if (res == NULL) { - goto error; - } - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); + ERROR_IF(res == NULL, error); } - // stack effect: (__0 -- ) - inst(BINARY_OP) { + // This always dispatches, so the result is unused. + inst(BINARY_OP, (lhs, rhs, unused/1 -- unused)) { _PyBinaryOpCache *cache = (_PyBinaryOpCache *)next_instr; if (ADAPTIVE_COUNTER_IS_ZERO(cache->counter)) { assert(cframe.use_tracing == 0); - PyObject *lhs = SECOND(); - PyObject *rhs = TOP(); next_instr--; _Py_Specialize_BinaryOp(lhs, rhs, next_instr, oparg, &GETLOCAL(0)); DISPATCH_SAME_OPARG(); @@ -3761,13 +3757,8 @@ dummy_func( ; } -// Families go below this point // +// Future families go below this point // -family(binary_op) = { - BINARY_OP, BINARY_OP_ADD_FLOAT, - BINARY_OP_ADD_INT, BINARY_OP_ADD_UNICODE, BINARY_OP_GENERIC, BINARY_OP_INPLACE_ADD_UNICODE, - BINARY_OP_MULTIPLY_FLOAT, BINARY_OP_MULTIPLY_INT, BINARY_OP_SUBTRACT_FLOAT, - BINARY_OP_SUBTRACT_INT }; family(binary_subscr) = { BINARY_SUBSCR, BINARY_SUBSCR_DICT, BINARY_SUBSCR_GETITEM, BINARY_SUBSCR_LIST_INT, BINARY_SUBSCR_TUPLE_INT }; diff --git a/Python/generated_cases.c.h b/Python/generated_cases.c.h index 552d0e6d017..b8bc66b1488 100644 --- a/Python/generated_cases.c.h +++ b/Python/generated_cases.c.h @@ -145,9 +145,9 @@ _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free); if (prod == NULL) goto pop_2_error; - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); STACK_SHRINK(1); POKE(1, prod); + next_instr += 1; DISPATCH(); } @@ -165,9 +165,9 @@ _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc); if (prod == NULL) goto pop_2_error; - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); STACK_SHRINK(1); POKE(1, prod); + next_instr += 1; DISPATCH(); } @@ -183,9 +183,9 @@ _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free); if (sub == NULL) goto pop_2_error; - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); STACK_SHRINK(1); POKE(1, sub); + next_instr += 1; DISPATCH(); } @@ -202,9 +202,9 @@ _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc); if (sub == NULL) goto pop_2_error; - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); STACK_SHRINK(1); POKE(1, sub); + next_instr += 1; DISPATCH(); } @@ -220,9 +220,9 @@ _Py_DECREF_SPECIALIZED(left, _PyUnicode_ExactDealloc); _Py_DECREF_SPECIALIZED(right, _PyUnicode_ExactDealloc); if (res == NULL) goto pop_2_error; - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); STACK_SHRINK(1); POKE(1, res); + next_instr += 1; DISPATCH(); } @@ -274,9 +274,9 @@ _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc); if (sum == NULL) goto pop_2_error; - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); STACK_SHRINK(1); POKE(1, sum); + next_instr += 1; DISPATCH(); } @@ -292,9 +292,9 @@ _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free); if (sum == NULL) goto pop_2_error; - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); STACK_SHRINK(1); POKE(1, sum); + next_instr += 1; DISPATCH(); } @@ -3703,29 +3703,30 @@ TARGET(BINARY_OP_GENERIC) { PREDICTED(BINARY_OP_GENERIC); - PyObject *rhs = POP(); - PyObject *lhs = TOP(); + PyObject *rhs = PEEK(1); + PyObject *lhs = PEEK(2); + PyObject *res; assert(0 <= oparg); assert((unsigned)oparg < Py_ARRAY_LENGTH(binary_ops)); assert(binary_ops[oparg]); - PyObject *res = binary_ops[oparg](lhs, rhs); + res = binary_ops[oparg](lhs, rhs); Py_DECREF(lhs); Py_DECREF(rhs); - SET_TOP(res); - if (res == NULL) { - goto error; - } - JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP); + if (res == NULL) goto pop_2_error; + STACK_SHRINK(1); + POKE(1, res); + next_instr += 1; DISPATCH(); } TARGET(BINARY_OP) { PREDICTED(BINARY_OP); + assert(INLINE_CACHE_ENTRIES_BINARY_OP == 1); + PyObject *rhs = PEEK(1); + PyObject *lhs = PEEK(2); _PyBinaryOpCache *cache = (_PyBinaryOpCache *)next_instr; if (ADAPTIVE_COUNTER_IS_ZERO(cache->counter)) { assert(cframe.use_tracing == 0); - PyObject *lhs = SECOND(); - PyObject *rhs = TOP(); next_instr--; _Py_Specialize_BinaryOp(lhs, rhs, next_instr, oparg, &GETLOCAL(0)); DISPATCH_SAME_OPARG(); diff --git a/Tools/cases_generator/generate_cases.py b/Tools/cases_generator/generate_cases.py index b4f5f8f01dc..d0165317509 100644 --- a/Tools/cases_generator/generate_cases.py +++ b/Tools/cases_generator/generate_cases.py @@ -18,7 +18,6 @@ RE_PREDICTED = r"(?s)(?:PREDICT\(|GO_TO_INSTRUCTION\(|DEOPT_IF\(.*?,\s*)(\w+)\); arg_parser = argparse.ArgumentParser() arg_parser.add_argument("-i", "--input", type=str, default="Python/bytecodes.c") arg_parser.add_argument("-o", "--output", type=str, default="Python/generated_cases.c.h") -arg_parser.add_argument("-c", "--compare", action="store_true") arg_parser.add_argument("-q", "--quiet", action="store_true") @@ -40,7 +39,6 @@ def parse_cases( families: list[parser.Family] = [] while not psr.eof(): if inst := psr.inst_def(): - assert inst.block instrs.append(inst) elif sup := psr.super_def(): supers.append(sup) @@ -69,17 +67,45 @@ def always_exits(block: parser.Block) -> bool: return line.startswith(("goto ", "return ", "DISPATCH", "GO_TO_", "Py_UNREACHABLE()")) -def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, dedent: int = 0): - assert instr.block +def find_cache_size(instr: InstDef, families: list[parser.Family]) -> str | None: + for family in families: + if instr.name == family.members[0]: + return family.size + + +def write_instr( + instr: InstDef, predictions: set[str], indent: str, f: TextIO, dedent: int = 0, cache_size: str | None = None +) -> int: + # Returns cache offset if dedent < 0: indent += " " * -dedent + # Separate stack inputs from cache inputs + input_names: set[str] = set() + stack: list[parser.StackEffect] = [] + cache: list[parser.CacheEffect] = [] + for input in instr.inputs: + if isinstance(input, parser.StackEffect): + stack.append(input) + input_names.add(input.name) + else: + assert isinstance(input, parser.CacheEffect), input + cache.append(input) + outputs = instr.outputs + cache_offset = 0 + for ceffect in cache: + if ceffect.name != "unused": + bits = ceffect.size * 16 + f.write(f"{indent} PyObject *{ceffect.name} = read{bits}(next_instr + {cache_offset});\n") + cache_offset += ceffect.size + if cache_size: + f.write(f"{indent} assert({cache_size} == {cache_offset});\n") # TODO: Is it better to count forward or backward? - for i, input in enumerate(reversed(instr.inputs), 1): - f.write(f"{indent} PyObject *{input} = PEEK({i});\n") + for i, effect in enumerate(reversed(stack), 1): + if effect.name != "unused": + f.write(f"{indent} PyObject *{effect.name} = PEEK({i});\n") for output in instr.outputs: - if output not in instr.inputs: - f.write(f"{indent} PyObject *{output};\n") - assert instr.block is not None + if output.name not in input_names and output.name != "unused": + f.write(f"{indent} PyObject *{output.name};\n") blocklines = instr.block.to_text(dedent=dedent).splitlines(True) # Remove blank lines from ends while blocklines and not blocklines[0].strip(): @@ -95,7 +121,7 @@ def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, d while blocklines and not blocklines[-1].strip(): blocklines.pop() # Write the body - ninputs = len(instr.inputs or ()) + ninputs = len(stack) for line in blocklines: if m := re.match(r"(\s*)ERROR_IF\((.+), (\w+)\);\s*$", line): space, cond, label = m.groups() @@ -107,46 +133,56 @@ def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, d f.write(f"{space}if ({cond}) goto {label};\n") else: f.write(line) - noutputs = len(instr.outputs or ()) + if always_exits(instr.block): + # None of the rest matters + return cache_offset + # Stack effect + noutputs = len(outputs) diff = noutputs - ninputs if diff > 0: f.write(f"{indent} STACK_GROW({diff});\n") elif diff < 0: f.write(f"{indent} STACK_SHRINK({-diff});\n") - for i, output in enumerate(reversed(instr.outputs or ()), 1): - if output not in (instr.inputs or ()): - f.write(f"{indent} POKE({i}, {output});\n") - assert instr.block + for i, output in enumerate(reversed(outputs), 1): + if output.name not in input_names and output.name != "unused": + f.write(f"{indent} POKE({i}, {output.name});\n") + # Cache effect + if cache_offset: + f.write(f"{indent} next_instr += {cache_offset};\n") + return cache_offset -def write_cases(f: TextIO, instrs: list[InstDef], supers: list[parser.Super]): + +def write_cases( + f: TextIO, instrs: list[InstDef], supers: list[parser.Super], families: list[parser.Family] +) -> dict[str, tuple[int, int, int]]: predictions: set[str] = set() for instr in instrs: - assert isinstance(instr, InstDef) - assert instr.block is not None for target in re.findall(RE_PREDICTED, instr.block.text): predictions.add(target) indent = " " f.write(f"// This file is generated by {os.path.relpath(__file__)}\n") f.write(f"// Do not edit!\n") instr_index: dict[str, InstDef] = {} + effects_table: dict[str, tuple[int, int, int]] = {} # name -> (ninputs, noutputs, cache_offset) for instr in instrs: instr_index[instr.name] = instr f.write(f"\n{indent}TARGET({instr.name}) {{\n") if instr.name in predictions: f.write(f"{indent} PREDICTED({instr.name});\n") - write_instr(instr, predictions, indent, f) - assert instr.block + cache_offset = write_instr( + instr, predictions, indent, f, + cache_size=find_cache_size(instr, families) + ) + effects_table[instr.name] = len(instr.inputs), len(instr.outputs), cache_offset if not always_exits(instr.block): f.write(f"{indent} DISPATCH();\n") # Write trailing '}' f.write(f"{indent}}}\n") for sup in supers: - assert isinstance(sup, parser.Super) components = [instr_index[name] for name in sup.ops] f.write(f"\n{indent}TARGET({sup.name}) {{\n") for i, instr in enumerate(components): - assert instr.block if i > 0: f.write(f"{indent} NEXTOPARG();\n") f.write(f"{indent} next_instr++;\n") @@ -156,6 +192,8 @@ def write_cases(f: TextIO, instrs: list[InstDef], supers: list[parser.Super]): f.write(f"{indent} DISPATCH();\n") f.write(f"{indent}}}\n") + return effects_table + def main(): args = arg_parser.parse_args() @@ -176,12 +214,28 @@ def main(): file=sys.stderr, ) with eopen(args.output, "w") as f: - write_cases(f, instrs, supers) + effects_table = write_cases(f, instrs, supers, families) if not args.quiet: print( f"Wrote {ninstrs + nsupers} instructions to {args.output}", file=sys.stderr, ) + # Check that families have consistent effects + errors = 0 + for family in families: + head = effects_table[family.members[0]] + for member in family.members: + if effects_table[member] != head: + errors += 1 + print( + f"Family {family.name!r} has inconsistent effects (inputs, outputs, cache units):", + file=sys.stderr, + ) + print( + f" {family.members[0]} = {head}; {member} = {effects_table[member]}", + ) + if errors: + sys.exit(1) if __name__ == "__main__": diff --git a/Tools/cases_generator/parser.py b/Tools/cases_generator/parser.py index 9e95cdb42d4..1f855312aeb 100644 --- a/Tools/cases_generator/parser.py +++ b/Tools/cases_generator/parser.py @@ -56,11 +56,28 @@ class Block(Node): tokens: list[lx.Token] +@dataclass +class Effect(Node): + pass + + +@dataclass +class StackEffect(Effect): + name: str + # TODO: type, condition + + +@dataclass +class CacheEffect(Effect): + name: str + size: int + + @dataclass class InstHeader(Node): name: str - inputs: list[str] - outputs: list[str] + inputs: list[Effect] + outputs: list[Effect] @dataclass @@ -69,16 +86,17 @@ class InstDef(Node): block: Block @property - def name(self): + def name(self) -> str: return self.header.name @property - def inputs(self): + def inputs(self) -> list[Effect]: return self.header.inputs @property - def outputs(self): - return self.header.outputs + def outputs(self) -> list[StackEffect]: + # This is always true + return [x for x in self.header.outputs if isinstance(x, StackEffect)] @dataclass @@ -90,6 +108,7 @@ class Super(Node): @dataclass class Family(Node): name: str + size: str # Variable giving the cache size in code units members: list[str] @@ -123,18 +142,16 @@ class Parser(PLexer): return InstHeader(name, [], []) return None - def check_overlaps(self, inp: list[str], outp: list[str]): + def check_overlaps(self, inp: list[Effect], outp: list[Effect]): for i, name in enumerate(inp): - try: - j = outp.index(name) - except ValueError: - continue - else: - if i != j: - raise self.make_syntax_error( - f"Input {name!r} at pos {i} repeated in output at different pos {j}") + for j, name2 in enumerate(outp): + if name == name2: + if i != j: + raise self.make_syntax_error( + f"Input {name!r} at pos {i} repeated in output at different pos {j}") + break - def stack_effect(self) -> tuple[list[str], list[str]]: + def stack_effect(self) -> tuple[list[Effect], list[Effect]]: # '(' [inputs] '--' [outputs] ')' if self.expect(lx.LPAREN): inp = self.inputs() or [] @@ -144,8 +161,8 @@ class Parser(PLexer): return inp, outp raise self.make_syntax_error("Expected stack effect") - def inputs(self) -> list[str] | None: - # input (, input)* + def inputs(self) -> list[Effect] | None: + # input (',' input)* here = self.getpos() if inp := self.input(): near = self.getpos() @@ -157,27 +174,25 @@ class Parser(PLexer): self.setpos(here) return None - def input(self) -> str | None: - # IDENTIFIER + @contextual + def input(self) -> Effect | None: + # IDENTIFIER '/' INTEGER (CacheEffect) + # IDENTIFIER (StackEffect) if (tkn := self.expect(lx.IDENTIFIER)): - if self.expect(lx.LBRACKET): - if arg := self.expect(lx.IDENTIFIER): - if self.expect(lx.RBRACKET): - return f"{tkn.text}[{arg.text}]" - if self.expect(lx.TIMES): - if num := self.expect(lx.NUMBER): - if self.expect(lx.RBRACKET): - return f"{tkn.text}[{arg.text}*{num.text}]" - raise self.make_syntax_error("Expected argument in brackets", tkn) + if self.expect(lx.DIVIDE): + if num := self.expect(lx.NUMBER): + try: + size = int(num.text) + except ValueError: + raise self.make_syntax_error( + f"Expected integer, got {num.text!r}") + else: + return CacheEffect(tkn.text, size) + raise self.make_syntax_error("Expected integer") + else: + return StackEffect(tkn.text) - return tkn.text - if self.expect(lx.CONDOP): - while self.expect(lx.CONDOP): - pass - return "??" - return None - - def outputs(self) -> list[str] | None: + def outputs(self) -> list[Effect] | None: # output (, output)* here = self.getpos() if outp := self.output(): @@ -190,8 +205,10 @@ class Parser(PLexer): self.setpos(here) return None - def output(self) -> str | None: - return self.input() # TODO: They're not quite the same. + @contextual + def output(self) -> Effect | None: + if (tkn := self.expect(lx.IDENTIFIER)): + return StackEffect(tkn.text) @contextual def super_def(self) -> Super | None: @@ -216,24 +233,35 @@ class Parser(PLexer): @contextual def family_def(self) -> Family | None: if (tkn := self.expect(lx.IDENTIFIER)) and tkn.text == "family": + size = None if self.expect(lx.LPAREN): if (tkn := self.expect(lx.IDENTIFIER)): + if self.expect(lx.COMMA): + if not (size := self.expect(lx.IDENTIFIER)): + raise self.make_syntax_error( + "Expected identifier") if self.expect(lx.RPAREN): if self.expect(lx.EQUALS): + if not self.expect(lx.LBRACE): + raise self.make_syntax_error("Expected {") if members := self.members(): - if self.expect(lx.SEMI): - return Family(tkn.text, members) + if self.expect(lx.RBRACE) and self.expect(lx.SEMI): + return Family(tkn.text, size.text if size else "", members) return None def members(self) -> list[str] | None: here = self.getpos() if tkn := self.expect(lx.IDENTIFIER): - near = self.getpos() - if self.expect(lx.COMMA): - if rest := self.members(): - return [tkn.text] + rest - self.setpos(near) - return [tkn.text] + members = [tkn.text] + while self.expect(lx.COMMA): + if tkn := self.expect(lx.IDENTIFIER): + members.append(tkn.text) + else: + break + peek = self.peek() + if not peek or peek.kind != lx.RBRACE: + raise self.make_syntax_error("Expected comma or right paren") + return members self.setpos(here) return None @@ -274,5 +302,5 @@ if __name__ == "__main__": filename = None src = "if (x) { x.foo; // comment\n}" parser = Parser(src, filename) - x = parser.inst_def() + x = parser.inst_def() or parser.super_def() or parser.family_def() print(x)