GH-98831: Implement basic cache effects (#99313)

This commit is contained in:
Guido van Rossum 2022-11-15 19:59:19 -08:00 committed by GitHub
parent 4636df9feb
commit e37744f289
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 201 additions and 127 deletions

View File

@ -76,13 +76,9 @@ do { \
#define NAME_ERROR_MSG \
"name '%.200s' is not defined"
typedef struct {
PyObject *kwnames;
} CallShape;
// Dummy variables for stack effects.
static PyObject *value, *value1, *value2, *left, *right, *res, *sum, *prod, *sub;
static PyObject *container, *start, *stop, *v;
static PyObject *container, *start, *stop, *v, *lhs, *rhs;
static PyObject *
dummy_func(
@ -101,6 +97,8 @@ dummy_func(
binaryfunc binary_ops[]
)
{
_PyInterpreterFrame entry_frame;
switch (opcode) {
// BEGIN BYTECODES //
@ -193,7 +191,21 @@ dummy_func(
ERROR_IF(res == NULL, error);
}
inst(BINARY_OP_MULTIPLY_INT, (left, right -- prod)) {
family(binary_op, INLINE_CACHE_ENTRIES_BINARY_OP) = {
BINARY_OP,
BINARY_OP_ADD_FLOAT,
BINARY_OP_ADD_INT,
BINARY_OP_ADD_UNICODE,
BINARY_OP_GENERIC,
// BINARY_OP_INPLACE_ADD_UNICODE, // This is an odd duck.
BINARY_OP_MULTIPLY_FLOAT,
BINARY_OP_MULTIPLY_INT,
BINARY_OP_SUBTRACT_FLOAT,
BINARY_OP_SUBTRACT_INT,
};
inst(BINARY_OP_MULTIPLY_INT, (left, right, unused/1 -- prod)) {
assert(cframe.use_tracing == 0);
DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP);
DEOPT_IF(!PyLong_CheckExact(right), BINARY_OP);
@ -202,10 +214,9 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
ERROR_IF(prod == NULL, error);
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
}
inst(BINARY_OP_MULTIPLY_FLOAT, (left, right -- prod)) {
inst(BINARY_OP_MULTIPLY_FLOAT, (left, right, unused/1 -- prod)) {
assert(cframe.use_tracing == 0);
DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP);
DEOPT_IF(!PyFloat_CheckExact(right), BINARY_OP);
@ -216,10 +227,9 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
ERROR_IF(prod == NULL, error);
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
}
inst(BINARY_OP_SUBTRACT_INT, (left, right -- sub)) {
inst(BINARY_OP_SUBTRACT_INT, (left, right, unused/1 -- sub)) {
assert(cframe.use_tracing == 0);
DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP);
DEOPT_IF(!PyLong_CheckExact(right), BINARY_OP);
@ -228,10 +238,9 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
ERROR_IF(sub == NULL, error);
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
}
inst(BINARY_OP_SUBTRACT_FLOAT, (left, right -- sub)) {
inst(BINARY_OP_SUBTRACT_FLOAT, (left, right, unused/1 -- sub)) {
assert(cframe.use_tracing == 0);
DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP);
DEOPT_IF(!PyFloat_CheckExact(right), BINARY_OP);
@ -241,10 +250,9 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
ERROR_IF(sub == NULL, error);
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
}
inst(BINARY_OP_ADD_UNICODE, (left, right -- res)) {
inst(BINARY_OP_ADD_UNICODE, (left, right, unused/1 -- res)) {
assert(cframe.use_tracing == 0);
DEOPT_IF(!PyUnicode_CheckExact(left), BINARY_OP);
DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP);
@ -253,7 +261,6 @@ dummy_func(
_Py_DECREF_SPECIALIZED(left, _PyUnicode_ExactDealloc);
_Py_DECREF_SPECIALIZED(right, _PyUnicode_ExactDealloc);
ERROR_IF(res == NULL, error);
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
}
// This is a subtle one. It's a super-instruction for
@ -292,7 +299,7 @@ dummy_func(
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP + 1);
}
inst(BINARY_OP_ADD_FLOAT, (left, right -- sum)) {
inst(BINARY_OP_ADD_FLOAT, (left, right, unused/1 -- sum)) {
assert(cframe.use_tracing == 0);
DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP);
DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP);
@ -303,10 +310,9 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
ERROR_IF(sum == NULL, error);
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
}
inst(BINARY_OP_ADD_INT, (left, right -- sum)) {
inst(BINARY_OP_ADD_INT, (left, right, unused/1 -- sum)) {
assert(cframe.use_tracing == 0);
DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP);
DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP);
@ -315,7 +321,6 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
ERROR_IF(sum == NULL, error);
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
}
inst(BINARY_SUBSCR, (container, sub -- res)) {
@ -3691,30 +3696,21 @@ dummy_func(
PUSH(Py_NewRef(peek));
}
// stack effect: (__0 -- )
inst(BINARY_OP_GENERIC) {
PyObject *rhs = POP();
PyObject *lhs = TOP();
inst(BINARY_OP_GENERIC, (lhs, rhs, unused/1 -- res)) {
assert(0 <= oparg);
assert((unsigned)oparg < Py_ARRAY_LENGTH(binary_ops));
assert(binary_ops[oparg]);
PyObject *res = binary_ops[oparg](lhs, rhs);
res = binary_ops[oparg](lhs, rhs);
Py_DECREF(lhs);
Py_DECREF(rhs);
SET_TOP(res);
if (res == NULL) {
goto error;
}
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
ERROR_IF(res == NULL, error);
}
// stack effect: (__0 -- )
inst(BINARY_OP) {
// This always dispatches, so the result is unused.
inst(BINARY_OP, (lhs, rhs, unused/1 -- unused)) {
_PyBinaryOpCache *cache = (_PyBinaryOpCache *)next_instr;
if (ADAPTIVE_COUNTER_IS_ZERO(cache->counter)) {
assert(cframe.use_tracing == 0);
PyObject *lhs = SECOND();
PyObject *rhs = TOP();
next_instr--;
_Py_Specialize_BinaryOp(lhs, rhs, next_instr, oparg, &GETLOCAL(0));
DISPATCH_SAME_OPARG();
@ -3761,13 +3757,8 @@ dummy_func(
;
}
// Families go below this point //
// Future families go below this point //
family(binary_op) = {
BINARY_OP, BINARY_OP_ADD_FLOAT,
BINARY_OP_ADD_INT, BINARY_OP_ADD_UNICODE, BINARY_OP_GENERIC, BINARY_OP_INPLACE_ADD_UNICODE,
BINARY_OP_MULTIPLY_FLOAT, BINARY_OP_MULTIPLY_INT, BINARY_OP_SUBTRACT_FLOAT,
BINARY_OP_SUBTRACT_INT };
family(binary_subscr) = {
BINARY_SUBSCR, BINARY_SUBSCR_DICT,
BINARY_SUBSCR_GETITEM, BINARY_SUBSCR_LIST_INT, BINARY_SUBSCR_TUPLE_INT };

View File

@ -145,9 +145,9 @@
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
if (prod == NULL) goto pop_2_error;
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1);
POKE(1, prod);
next_instr += 1;
DISPATCH();
}
@ -165,9 +165,9 @@
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
if (prod == NULL) goto pop_2_error;
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1);
POKE(1, prod);
next_instr += 1;
DISPATCH();
}
@ -183,9 +183,9 @@
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
if (sub == NULL) goto pop_2_error;
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1);
POKE(1, sub);
next_instr += 1;
DISPATCH();
}
@ -202,9 +202,9 @@
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
if (sub == NULL) goto pop_2_error;
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1);
POKE(1, sub);
next_instr += 1;
DISPATCH();
}
@ -220,9 +220,9 @@
_Py_DECREF_SPECIALIZED(left, _PyUnicode_ExactDealloc);
_Py_DECREF_SPECIALIZED(right, _PyUnicode_ExactDealloc);
if (res == NULL) goto pop_2_error;
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1);
POKE(1, res);
next_instr += 1;
DISPATCH();
}
@ -274,9 +274,9 @@
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
if (sum == NULL) goto pop_2_error;
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1);
POKE(1, sum);
next_instr += 1;
DISPATCH();
}
@ -292,9 +292,9 @@
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
if (sum == NULL) goto pop_2_error;
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1);
POKE(1, sum);
next_instr += 1;
DISPATCH();
}
@ -3703,29 +3703,30 @@
TARGET(BINARY_OP_GENERIC) {
PREDICTED(BINARY_OP_GENERIC);
PyObject *rhs = POP();
PyObject *lhs = TOP();
PyObject *rhs = PEEK(1);
PyObject *lhs = PEEK(2);
PyObject *res;
assert(0 <= oparg);
assert((unsigned)oparg < Py_ARRAY_LENGTH(binary_ops));
assert(binary_ops[oparg]);
PyObject *res = binary_ops[oparg](lhs, rhs);
res = binary_ops[oparg](lhs, rhs);
Py_DECREF(lhs);
Py_DECREF(rhs);
SET_TOP(res);
if (res == NULL) {
goto error;
}
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
if (res == NULL) goto pop_2_error;
STACK_SHRINK(1);
POKE(1, res);
next_instr += 1;
DISPATCH();
}
TARGET(BINARY_OP) {
PREDICTED(BINARY_OP);
assert(INLINE_CACHE_ENTRIES_BINARY_OP == 1);
PyObject *rhs = PEEK(1);
PyObject *lhs = PEEK(2);
_PyBinaryOpCache *cache = (_PyBinaryOpCache *)next_instr;
if (ADAPTIVE_COUNTER_IS_ZERO(cache->counter)) {
assert(cframe.use_tracing == 0);
PyObject *lhs = SECOND();
PyObject *rhs = TOP();
next_instr--;
_Py_Specialize_BinaryOp(lhs, rhs, next_instr, oparg, &GETLOCAL(0));
DISPATCH_SAME_OPARG();

View File

@ -18,7 +18,6 @@ RE_PREDICTED = r"(?s)(?:PREDICT\(|GO_TO_INSTRUCTION\(|DEOPT_IF\(.*?,\s*)(\w+)\);
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("-i", "--input", type=str, default="Python/bytecodes.c")
arg_parser.add_argument("-o", "--output", type=str, default="Python/generated_cases.c.h")
arg_parser.add_argument("-c", "--compare", action="store_true")
arg_parser.add_argument("-q", "--quiet", action="store_true")
@ -40,7 +39,6 @@ def parse_cases(
families: list[parser.Family] = []
while not psr.eof():
if inst := psr.inst_def():
assert inst.block
instrs.append(inst)
elif sup := psr.super_def():
supers.append(sup)
@ -69,17 +67,45 @@ def always_exits(block: parser.Block) -> bool:
return line.startswith(("goto ", "return ", "DISPATCH", "GO_TO_", "Py_UNREACHABLE()"))
def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, dedent: int = 0):
assert instr.block
def find_cache_size(instr: InstDef, families: list[parser.Family]) -> str | None:
for family in families:
if instr.name == family.members[0]:
return family.size
def write_instr(
instr: InstDef, predictions: set[str], indent: str, f: TextIO, dedent: int = 0, cache_size: str | None = None
) -> int:
# Returns cache offset
if dedent < 0:
indent += " " * -dedent
# Separate stack inputs from cache inputs
input_names: set[str] = set()
stack: list[parser.StackEffect] = []
cache: list[parser.CacheEffect] = []
for input in instr.inputs:
if isinstance(input, parser.StackEffect):
stack.append(input)
input_names.add(input.name)
else:
assert isinstance(input, parser.CacheEffect), input
cache.append(input)
outputs = instr.outputs
cache_offset = 0
for ceffect in cache:
if ceffect.name != "unused":
bits = ceffect.size * 16
f.write(f"{indent} PyObject *{ceffect.name} = read{bits}(next_instr + {cache_offset});\n")
cache_offset += ceffect.size
if cache_size:
f.write(f"{indent} assert({cache_size} == {cache_offset});\n")
# TODO: Is it better to count forward or backward?
for i, input in enumerate(reversed(instr.inputs), 1):
f.write(f"{indent} PyObject *{input} = PEEK({i});\n")
for i, effect in enumerate(reversed(stack), 1):
if effect.name != "unused":
f.write(f"{indent} PyObject *{effect.name} = PEEK({i});\n")
for output in instr.outputs:
if output not in instr.inputs:
f.write(f"{indent} PyObject *{output};\n")
assert instr.block is not None
if output.name not in input_names and output.name != "unused":
f.write(f"{indent} PyObject *{output.name};\n")
blocklines = instr.block.to_text(dedent=dedent).splitlines(True)
# Remove blank lines from ends
while blocklines and not blocklines[0].strip():
@ -95,7 +121,7 @@ def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, d
while blocklines and not blocklines[-1].strip():
blocklines.pop()
# Write the body
ninputs = len(instr.inputs or ())
ninputs = len(stack)
for line in blocklines:
if m := re.match(r"(\s*)ERROR_IF\((.+), (\w+)\);\s*$", line):
space, cond, label = m.groups()
@ -107,46 +133,56 @@ def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, d
f.write(f"{space}if ({cond}) goto {label};\n")
else:
f.write(line)
noutputs = len(instr.outputs or ())
if always_exits(instr.block):
# None of the rest matters
return cache_offset
# Stack effect
noutputs = len(outputs)
diff = noutputs - ninputs
if diff > 0:
f.write(f"{indent} STACK_GROW({diff});\n")
elif diff < 0:
f.write(f"{indent} STACK_SHRINK({-diff});\n")
for i, output in enumerate(reversed(instr.outputs or ()), 1):
if output not in (instr.inputs or ()):
f.write(f"{indent} POKE({i}, {output});\n")
assert instr.block
for i, output in enumerate(reversed(outputs), 1):
if output.name not in input_names and output.name != "unused":
f.write(f"{indent} POKE({i}, {output.name});\n")
# Cache effect
if cache_offset:
f.write(f"{indent} next_instr += {cache_offset};\n")
return cache_offset
def write_cases(f: TextIO, instrs: list[InstDef], supers: list[parser.Super]):
def write_cases(
f: TextIO, instrs: list[InstDef], supers: list[parser.Super], families: list[parser.Family]
) -> dict[str, tuple[int, int, int]]:
predictions: set[str] = set()
for instr in instrs:
assert isinstance(instr, InstDef)
assert instr.block is not None
for target in re.findall(RE_PREDICTED, instr.block.text):
predictions.add(target)
indent = " "
f.write(f"// This file is generated by {os.path.relpath(__file__)}\n")
f.write(f"// Do not edit!\n")
instr_index: dict[str, InstDef] = {}
effects_table: dict[str, tuple[int, int, int]] = {} # name -> (ninputs, noutputs, cache_offset)
for instr in instrs:
instr_index[instr.name] = instr
f.write(f"\n{indent}TARGET({instr.name}) {{\n")
if instr.name in predictions:
f.write(f"{indent} PREDICTED({instr.name});\n")
write_instr(instr, predictions, indent, f)
assert instr.block
cache_offset = write_instr(
instr, predictions, indent, f,
cache_size=find_cache_size(instr, families)
)
effects_table[instr.name] = len(instr.inputs), len(instr.outputs), cache_offset
if not always_exits(instr.block):
f.write(f"{indent} DISPATCH();\n")
# Write trailing '}'
f.write(f"{indent}}}\n")
for sup in supers:
assert isinstance(sup, parser.Super)
components = [instr_index[name] for name in sup.ops]
f.write(f"\n{indent}TARGET({sup.name}) {{\n")
for i, instr in enumerate(components):
assert instr.block
if i > 0:
f.write(f"{indent} NEXTOPARG();\n")
f.write(f"{indent} next_instr++;\n")
@ -156,6 +192,8 @@ def write_cases(f: TextIO, instrs: list[InstDef], supers: list[parser.Super]):
f.write(f"{indent} DISPATCH();\n")
f.write(f"{indent}}}\n")
return effects_table
def main():
args = arg_parser.parse_args()
@ -176,12 +214,28 @@ def main():
file=sys.stderr,
)
with eopen(args.output, "w") as f:
write_cases(f, instrs, supers)
effects_table = write_cases(f, instrs, supers, families)
if not args.quiet:
print(
f"Wrote {ninstrs + nsupers} instructions to {args.output}",
file=sys.stderr,
)
# Check that families have consistent effects
errors = 0
for family in families:
head = effects_table[family.members[0]]
for member in family.members:
if effects_table[member] != head:
errors += 1
print(
f"Family {family.name!r} has inconsistent effects (inputs, outputs, cache units):",
file=sys.stderr,
)
print(
f" {family.members[0]} = {head}; {member} = {effects_table[member]}",
)
if errors:
sys.exit(1)
if __name__ == "__main__":

View File

@ -56,11 +56,28 @@ class Block(Node):
tokens: list[lx.Token]
@dataclass
class Effect(Node):
pass
@dataclass
class StackEffect(Effect):
name: str
# TODO: type, condition
@dataclass
class CacheEffect(Effect):
name: str
size: int
@dataclass
class InstHeader(Node):
name: str
inputs: list[str]
outputs: list[str]
inputs: list[Effect]
outputs: list[Effect]
@dataclass
@ -69,16 +86,17 @@ class InstDef(Node):
block: Block
@property
def name(self):
def name(self) -> str:
return self.header.name
@property
def inputs(self):
def inputs(self) -> list[Effect]:
return self.header.inputs
@property
def outputs(self):
return self.header.outputs
def outputs(self) -> list[StackEffect]:
# This is always true
return [x for x in self.header.outputs if isinstance(x, StackEffect)]
@dataclass
@ -90,6 +108,7 @@ class Super(Node):
@dataclass
class Family(Node):
name: str
size: str # Variable giving the cache size in code units
members: list[str]
@ -123,18 +142,16 @@ class Parser(PLexer):
return InstHeader(name, [], [])
return None
def check_overlaps(self, inp: list[str], outp: list[str]):
def check_overlaps(self, inp: list[Effect], outp: list[Effect]):
for i, name in enumerate(inp):
try:
j = outp.index(name)
except ValueError:
continue
else:
if i != j:
raise self.make_syntax_error(
f"Input {name!r} at pos {i} repeated in output at different pos {j}")
for j, name2 in enumerate(outp):
if name == name2:
if i != j:
raise self.make_syntax_error(
f"Input {name!r} at pos {i} repeated in output at different pos {j}")
break
def stack_effect(self) -> tuple[list[str], list[str]]:
def stack_effect(self) -> tuple[list[Effect], list[Effect]]:
# '(' [inputs] '--' [outputs] ')'
if self.expect(lx.LPAREN):
inp = self.inputs() or []
@ -144,8 +161,8 @@ class Parser(PLexer):
return inp, outp
raise self.make_syntax_error("Expected stack effect")
def inputs(self) -> list[str] | None:
# input (, input)*
def inputs(self) -> list[Effect] | None:
# input (',' input)*
here = self.getpos()
if inp := self.input():
near = self.getpos()
@ -157,27 +174,25 @@ class Parser(PLexer):
self.setpos(here)
return None
def input(self) -> str | None:
# IDENTIFIER
@contextual
def input(self) -> Effect | None:
# IDENTIFIER '/' INTEGER (CacheEffect)
# IDENTIFIER (StackEffect)
if (tkn := self.expect(lx.IDENTIFIER)):
if self.expect(lx.LBRACKET):
if arg := self.expect(lx.IDENTIFIER):
if self.expect(lx.RBRACKET):
return f"{tkn.text}[{arg.text}]"
if self.expect(lx.TIMES):
if num := self.expect(lx.NUMBER):
if self.expect(lx.RBRACKET):
return f"{tkn.text}[{arg.text}*{num.text}]"
raise self.make_syntax_error("Expected argument in brackets", tkn)
if self.expect(lx.DIVIDE):
if num := self.expect(lx.NUMBER):
try:
size = int(num.text)
except ValueError:
raise self.make_syntax_error(
f"Expected integer, got {num.text!r}")
else:
return CacheEffect(tkn.text, size)
raise self.make_syntax_error("Expected integer")
else:
return StackEffect(tkn.text)
return tkn.text
if self.expect(lx.CONDOP):
while self.expect(lx.CONDOP):
pass
return "??"
return None
def outputs(self) -> list[str] | None:
def outputs(self) -> list[Effect] | None:
# output (, output)*
here = self.getpos()
if outp := self.output():
@ -190,8 +205,10 @@ class Parser(PLexer):
self.setpos(here)
return None
def output(self) -> str | None:
return self.input() # TODO: They're not quite the same.
@contextual
def output(self) -> Effect | None:
if (tkn := self.expect(lx.IDENTIFIER)):
return StackEffect(tkn.text)
@contextual
def super_def(self) -> Super | None:
@ -216,24 +233,35 @@ class Parser(PLexer):
@contextual
def family_def(self) -> Family | None:
if (tkn := self.expect(lx.IDENTIFIER)) and tkn.text == "family":
size = None
if self.expect(lx.LPAREN):
if (tkn := self.expect(lx.IDENTIFIER)):
if self.expect(lx.COMMA):
if not (size := self.expect(lx.IDENTIFIER)):
raise self.make_syntax_error(
"Expected identifier")
if self.expect(lx.RPAREN):
if self.expect(lx.EQUALS):
if not self.expect(lx.LBRACE):
raise self.make_syntax_error("Expected {")
if members := self.members():
if self.expect(lx.SEMI):
return Family(tkn.text, members)
if self.expect(lx.RBRACE) and self.expect(lx.SEMI):
return Family(tkn.text, size.text if size else "", members)
return None
def members(self) -> list[str] | None:
here = self.getpos()
if tkn := self.expect(lx.IDENTIFIER):
near = self.getpos()
if self.expect(lx.COMMA):
if rest := self.members():
return [tkn.text] + rest
self.setpos(near)
return [tkn.text]
members = [tkn.text]
while self.expect(lx.COMMA):
if tkn := self.expect(lx.IDENTIFIER):
members.append(tkn.text)
else:
break
peek = self.peek()
if not peek or peek.kind != lx.RBRACE:
raise self.make_syntax_error("Expected comma or right paren")
return members
self.setpos(here)
return None
@ -274,5 +302,5 @@ if __name__ == "__main__":
filename = None
src = "if (x) { x.foo; // comment\n}"
parser = Parser(src, filename)
x = parser.inst_def()
x = parser.inst_def() or parser.super_def() or parser.family_def()
print(x)