GH-98831: Implement basic cache effects (#99313)

This commit is contained in:
Guido van Rossum 2022-11-15 19:59:19 -08:00 committed by GitHub
parent 4636df9feb
commit e37744f289
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 201 additions and 127 deletions

View File

@ -76,13 +76,9 @@ do { \
#define NAME_ERROR_MSG \ #define NAME_ERROR_MSG \
"name '%.200s' is not defined" "name '%.200s' is not defined"
typedef struct {
PyObject *kwnames;
} CallShape;
// Dummy variables for stack effects. // Dummy variables for stack effects.
static PyObject *value, *value1, *value2, *left, *right, *res, *sum, *prod, *sub; static PyObject *value, *value1, *value2, *left, *right, *res, *sum, *prod, *sub;
static PyObject *container, *start, *stop, *v; static PyObject *container, *start, *stop, *v, *lhs, *rhs;
static PyObject * static PyObject *
dummy_func( dummy_func(
@ -101,6 +97,8 @@ dummy_func(
binaryfunc binary_ops[] binaryfunc binary_ops[]
) )
{ {
_PyInterpreterFrame entry_frame;
switch (opcode) { switch (opcode) {
// BEGIN BYTECODES // // BEGIN BYTECODES //
@ -193,7 +191,21 @@ dummy_func(
ERROR_IF(res == NULL, error); ERROR_IF(res == NULL, error);
} }
inst(BINARY_OP_MULTIPLY_INT, (left, right -- prod)) { family(binary_op, INLINE_CACHE_ENTRIES_BINARY_OP) = {
BINARY_OP,
BINARY_OP_ADD_FLOAT,
BINARY_OP_ADD_INT,
BINARY_OP_ADD_UNICODE,
BINARY_OP_GENERIC,
// BINARY_OP_INPLACE_ADD_UNICODE, // This is an odd duck.
BINARY_OP_MULTIPLY_FLOAT,
BINARY_OP_MULTIPLY_INT,
BINARY_OP_SUBTRACT_FLOAT,
BINARY_OP_SUBTRACT_INT,
};
inst(BINARY_OP_MULTIPLY_INT, (left, right, unused/1 -- prod)) {
assert(cframe.use_tracing == 0); assert(cframe.use_tracing == 0);
DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP); DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP);
DEOPT_IF(!PyLong_CheckExact(right), BINARY_OP); DEOPT_IF(!PyLong_CheckExact(right), BINARY_OP);
@ -202,10 +214,9 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
ERROR_IF(prod == NULL, error); ERROR_IF(prod == NULL, error);
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
} }
inst(BINARY_OP_MULTIPLY_FLOAT, (left, right -- prod)) { inst(BINARY_OP_MULTIPLY_FLOAT, (left, right, unused/1 -- prod)) {
assert(cframe.use_tracing == 0); assert(cframe.use_tracing == 0);
DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP); DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP);
DEOPT_IF(!PyFloat_CheckExact(right), BINARY_OP); DEOPT_IF(!PyFloat_CheckExact(right), BINARY_OP);
@ -216,10 +227,9 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
ERROR_IF(prod == NULL, error); ERROR_IF(prod == NULL, error);
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
} }
inst(BINARY_OP_SUBTRACT_INT, (left, right -- sub)) { inst(BINARY_OP_SUBTRACT_INT, (left, right, unused/1 -- sub)) {
assert(cframe.use_tracing == 0); assert(cframe.use_tracing == 0);
DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP); DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP);
DEOPT_IF(!PyLong_CheckExact(right), BINARY_OP); DEOPT_IF(!PyLong_CheckExact(right), BINARY_OP);
@ -228,10 +238,9 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
ERROR_IF(sub == NULL, error); ERROR_IF(sub == NULL, error);
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
} }
inst(BINARY_OP_SUBTRACT_FLOAT, (left, right -- sub)) { inst(BINARY_OP_SUBTRACT_FLOAT, (left, right, unused/1 -- sub)) {
assert(cframe.use_tracing == 0); assert(cframe.use_tracing == 0);
DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP); DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP);
DEOPT_IF(!PyFloat_CheckExact(right), BINARY_OP); DEOPT_IF(!PyFloat_CheckExact(right), BINARY_OP);
@ -241,10 +250,9 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
ERROR_IF(sub == NULL, error); ERROR_IF(sub == NULL, error);
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
} }
inst(BINARY_OP_ADD_UNICODE, (left, right -- res)) { inst(BINARY_OP_ADD_UNICODE, (left, right, unused/1 -- res)) {
assert(cframe.use_tracing == 0); assert(cframe.use_tracing == 0);
DEOPT_IF(!PyUnicode_CheckExact(left), BINARY_OP); DEOPT_IF(!PyUnicode_CheckExact(left), BINARY_OP);
DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP); DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP);
@ -253,7 +261,6 @@ dummy_func(
_Py_DECREF_SPECIALIZED(left, _PyUnicode_ExactDealloc); _Py_DECREF_SPECIALIZED(left, _PyUnicode_ExactDealloc);
_Py_DECREF_SPECIALIZED(right, _PyUnicode_ExactDealloc); _Py_DECREF_SPECIALIZED(right, _PyUnicode_ExactDealloc);
ERROR_IF(res == NULL, error); ERROR_IF(res == NULL, error);
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
} }
// This is a subtle one. It's a super-instruction for // This is a subtle one. It's a super-instruction for
@ -292,7 +299,7 @@ dummy_func(
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP + 1); JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP + 1);
} }
inst(BINARY_OP_ADD_FLOAT, (left, right -- sum)) { inst(BINARY_OP_ADD_FLOAT, (left, right, unused/1 -- sum)) {
assert(cframe.use_tracing == 0); assert(cframe.use_tracing == 0);
DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP); DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP);
DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP); DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP);
@ -303,10 +310,9 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
ERROR_IF(sum == NULL, error); ERROR_IF(sum == NULL, error);
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
} }
inst(BINARY_OP_ADD_INT, (left, right -- sum)) { inst(BINARY_OP_ADD_INT, (left, right, unused/1 -- sum)) {
assert(cframe.use_tracing == 0); assert(cframe.use_tracing == 0);
DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP); DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP);
DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP); DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP);
@ -315,7 +321,6 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
ERROR_IF(sum == NULL, error); ERROR_IF(sum == NULL, error);
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
} }
inst(BINARY_SUBSCR, (container, sub -- res)) { inst(BINARY_SUBSCR, (container, sub -- res)) {
@ -3691,30 +3696,21 @@ dummy_func(
PUSH(Py_NewRef(peek)); PUSH(Py_NewRef(peek));
} }
// stack effect: (__0 -- ) inst(BINARY_OP_GENERIC, (lhs, rhs, unused/1 -- res)) {
inst(BINARY_OP_GENERIC) {
PyObject *rhs = POP();
PyObject *lhs = TOP();
assert(0 <= oparg); assert(0 <= oparg);
assert((unsigned)oparg < Py_ARRAY_LENGTH(binary_ops)); assert((unsigned)oparg < Py_ARRAY_LENGTH(binary_ops));
assert(binary_ops[oparg]); assert(binary_ops[oparg]);
PyObject *res = binary_ops[oparg](lhs, rhs); res = binary_ops[oparg](lhs, rhs);
Py_DECREF(lhs); Py_DECREF(lhs);
Py_DECREF(rhs); Py_DECREF(rhs);
SET_TOP(res); ERROR_IF(res == NULL, error);
if (res == NULL) {
goto error;
}
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
} }
// stack effect: (__0 -- ) // This always dispatches, so the result is unused.
inst(BINARY_OP) { inst(BINARY_OP, (lhs, rhs, unused/1 -- unused)) {
_PyBinaryOpCache *cache = (_PyBinaryOpCache *)next_instr; _PyBinaryOpCache *cache = (_PyBinaryOpCache *)next_instr;
if (ADAPTIVE_COUNTER_IS_ZERO(cache->counter)) { if (ADAPTIVE_COUNTER_IS_ZERO(cache->counter)) {
assert(cframe.use_tracing == 0); assert(cframe.use_tracing == 0);
PyObject *lhs = SECOND();
PyObject *rhs = TOP();
next_instr--; next_instr--;
_Py_Specialize_BinaryOp(lhs, rhs, next_instr, oparg, &GETLOCAL(0)); _Py_Specialize_BinaryOp(lhs, rhs, next_instr, oparg, &GETLOCAL(0));
DISPATCH_SAME_OPARG(); DISPATCH_SAME_OPARG();
@ -3761,13 +3757,8 @@ dummy_func(
; ;
} }
// Families go below this point // // Future families go below this point //
family(binary_op) = {
BINARY_OP, BINARY_OP_ADD_FLOAT,
BINARY_OP_ADD_INT, BINARY_OP_ADD_UNICODE, BINARY_OP_GENERIC, BINARY_OP_INPLACE_ADD_UNICODE,
BINARY_OP_MULTIPLY_FLOAT, BINARY_OP_MULTIPLY_INT, BINARY_OP_SUBTRACT_FLOAT,
BINARY_OP_SUBTRACT_INT };
family(binary_subscr) = { family(binary_subscr) = {
BINARY_SUBSCR, BINARY_SUBSCR_DICT, BINARY_SUBSCR, BINARY_SUBSCR_DICT,
BINARY_SUBSCR_GETITEM, BINARY_SUBSCR_LIST_INT, BINARY_SUBSCR_TUPLE_INT }; BINARY_SUBSCR_GETITEM, BINARY_SUBSCR_LIST_INT, BINARY_SUBSCR_TUPLE_INT };

View File

@ -145,9 +145,9 @@
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
if (prod == NULL) goto pop_2_error; if (prod == NULL) goto pop_2_error;
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1); STACK_SHRINK(1);
POKE(1, prod); POKE(1, prod);
next_instr += 1;
DISPATCH(); DISPATCH();
} }
@ -165,9 +165,9 @@
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
if (prod == NULL) goto pop_2_error; if (prod == NULL) goto pop_2_error;
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1); STACK_SHRINK(1);
POKE(1, prod); POKE(1, prod);
next_instr += 1;
DISPATCH(); DISPATCH();
} }
@ -183,9 +183,9 @@
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
if (sub == NULL) goto pop_2_error; if (sub == NULL) goto pop_2_error;
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1); STACK_SHRINK(1);
POKE(1, sub); POKE(1, sub);
next_instr += 1;
DISPATCH(); DISPATCH();
} }
@ -202,9 +202,9 @@
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
if (sub == NULL) goto pop_2_error; if (sub == NULL) goto pop_2_error;
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1); STACK_SHRINK(1);
POKE(1, sub); POKE(1, sub);
next_instr += 1;
DISPATCH(); DISPATCH();
} }
@ -220,9 +220,9 @@
_Py_DECREF_SPECIALIZED(left, _PyUnicode_ExactDealloc); _Py_DECREF_SPECIALIZED(left, _PyUnicode_ExactDealloc);
_Py_DECREF_SPECIALIZED(right, _PyUnicode_ExactDealloc); _Py_DECREF_SPECIALIZED(right, _PyUnicode_ExactDealloc);
if (res == NULL) goto pop_2_error; if (res == NULL) goto pop_2_error;
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1); STACK_SHRINK(1);
POKE(1, res); POKE(1, res);
next_instr += 1;
DISPATCH(); DISPATCH();
} }
@ -274,9 +274,9 @@
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc); _Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
if (sum == NULL) goto pop_2_error; if (sum == NULL) goto pop_2_error;
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1); STACK_SHRINK(1);
POKE(1, sum); POKE(1, sum);
next_instr += 1;
DISPATCH(); DISPATCH();
} }
@ -292,9 +292,9 @@
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free); _Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
if (sum == NULL) goto pop_2_error; if (sum == NULL) goto pop_2_error;
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1); STACK_SHRINK(1);
POKE(1, sum); POKE(1, sum);
next_instr += 1;
DISPATCH(); DISPATCH();
} }
@ -3703,29 +3703,30 @@
TARGET(BINARY_OP_GENERIC) { TARGET(BINARY_OP_GENERIC) {
PREDICTED(BINARY_OP_GENERIC); PREDICTED(BINARY_OP_GENERIC);
PyObject *rhs = POP(); PyObject *rhs = PEEK(1);
PyObject *lhs = TOP(); PyObject *lhs = PEEK(2);
PyObject *res;
assert(0 <= oparg); assert(0 <= oparg);
assert((unsigned)oparg < Py_ARRAY_LENGTH(binary_ops)); assert((unsigned)oparg < Py_ARRAY_LENGTH(binary_ops));
assert(binary_ops[oparg]); assert(binary_ops[oparg]);
PyObject *res = binary_ops[oparg](lhs, rhs); res = binary_ops[oparg](lhs, rhs);
Py_DECREF(lhs); Py_DECREF(lhs);
Py_DECREF(rhs); Py_DECREF(rhs);
SET_TOP(res); if (res == NULL) goto pop_2_error;
if (res == NULL) { STACK_SHRINK(1);
goto error; POKE(1, res);
} next_instr += 1;
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
DISPATCH(); DISPATCH();
} }
TARGET(BINARY_OP) { TARGET(BINARY_OP) {
PREDICTED(BINARY_OP); PREDICTED(BINARY_OP);
assert(INLINE_CACHE_ENTRIES_BINARY_OP == 1);
PyObject *rhs = PEEK(1);
PyObject *lhs = PEEK(2);
_PyBinaryOpCache *cache = (_PyBinaryOpCache *)next_instr; _PyBinaryOpCache *cache = (_PyBinaryOpCache *)next_instr;
if (ADAPTIVE_COUNTER_IS_ZERO(cache->counter)) { if (ADAPTIVE_COUNTER_IS_ZERO(cache->counter)) {
assert(cframe.use_tracing == 0); assert(cframe.use_tracing == 0);
PyObject *lhs = SECOND();
PyObject *rhs = TOP();
next_instr--; next_instr--;
_Py_Specialize_BinaryOp(lhs, rhs, next_instr, oparg, &GETLOCAL(0)); _Py_Specialize_BinaryOp(lhs, rhs, next_instr, oparg, &GETLOCAL(0));
DISPATCH_SAME_OPARG(); DISPATCH_SAME_OPARG();

View File

@ -18,7 +18,6 @@ RE_PREDICTED = r"(?s)(?:PREDICT\(|GO_TO_INSTRUCTION\(|DEOPT_IF\(.*?,\s*)(\w+)\);
arg_parser = argparse.ArgumentParser() arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("-i", "--input", type=str, default="Python/bytecodes.c") arg_parser.add_argument("-i", "--input", type=str, default="Python/bytecodes.c")
arg_parser.add_argument("-o", "--output", type=str, default="Python/generated_cases.c.h") arg_parser.add_argument("-o", "--output", type=str, default="Python/generated_cases.c.h")
arg_parser.add_argument("-c", "--compare", action="store_true")
arg_parser.add_argument("-q", "--quiet", action="store_true") arg_parser.add_argument("-q", "--quiet", action="store_true")
@ -40,7 +39,6 @@ def parse_cases(
families: list[parser.Family] = [] families: list[parser.Family] = []
while not psr.eof(): while not psr.eof():
if inst := psr.inst_def(): if inst := psr.inst_def():
assert inst.block
instrs.append(inst) instrs.append(inst)
elif sup := psr.super_def(): elif sup := psr.super_def():
supers.append(sup) supers.append(sup)
@ -69,17 +67,45 @@ def always_exits(block: parser.Block) -> bool:
return line.startswith(("goto ", "return ", "DISPATCH", "GO_TO_", "Py_UNREACHABLE()")) return line.startswith(("goto ", "return ", "DISPATCH", "GO_TO_", "Py_UNREACHABLE()"))
def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, dedent: int = 0): def find_cache_size(instr: InstDef, families: list[parser.Family]) -> str | None:
assert instr.block for family in families:
if instr.name == family.members[0]:
return family.size
def write_instr(
instr: InstDef, predictions: set[str], indent: str, f: TextIO, dedent: int = 0, cache_size: str | None = None
) -> int:
# Returns cache offset
if dedent < 0: if dedent < 0:
indent += " " * -dedent indent += " " * -dedent
# Separate stack inputs from cache inputs
input_names: set[str] = set()
stack: list[parser.StackEffect] = []
cache: list[parser.CacheEffect] = []
for input in instr.inputs:
if isinstance(input, parser.StackEffect):
stack.append(input)
input_names.add(input.name)
else:
assert isinstance(input, parser.CacheEffect), input
cache.append(input)
outputs = instr.outputs
cache_offset = 0
for ceffect in cache:
if ceffect.name != "unused":
bits = ceffect.size * 16
f.write(f"{indent} PyObject *{ceffect.name} = read{bits}(next_instr + {cache_offset});\n")
cache_offset += ceffect.size
if cache_size:
f.write(f"{indent} assert({cache_size} == {cache_offset});\n")
# TODO: Is it better to count forward or backward? # TODO: Is it better to count forward or backward?
for i, input in enumerate(reversed(instr.inputs), 1): for i, effect in enumerate(reversed(stack), 1):
f.write(f"{indent} PyObject *{input} = PEEK({i});\n") if effect.name != "unused":
f.write(f"{indent} PyObject *{effect.name} = PEEK({i});\n")
for output in instr.outputs: for output in instr.outputs:
if output not in instr.inputs: if output.name not in input_names and output.name != "unused":
f.write(f"{indent} PyObject *{output};\n") f.write(f"{indent} PyObject *{output.name};\n")
assert instr.block is not None
blocklines = instr.block.to_text(dedent=dedent).splitlines(True) blocklines = instr.block.to_text(dedent=dedent).splitlines(True)
# Remove blank lines from ends # Remove blank lines from ends
while blocklines and not blocklines[0].strip(): while blocklines and not blocklines[0].strip():
@ -95,7 +121,7 @@ def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, d
while blocklines and not blocklines[-1].strip(): while blocklines and not blocklines[-1].strip():
blocklines.pop() blocklines.pop()
# Write the body # Write the body
ninputs = len(instr.inputs or ()) ninputs = len(stack)
for line in blocklines: for line in blocklines:
if m := re.match(r"(\s*)ERROR_IF\((.+), (\w+)\);\s*$", line): if m := re.match(r"(\s*)ERROR_IF\((.+), (\w+)\);\s*$", line):
space, cond, label = m.groups() space, cond, label = m.groups()
@ -107,46 +133,56 @@ def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, d
f.write(f"{space}if ({cond}) goto {label};\n") f.write(f"{space}if ({cond}) goto {label};\n")
else: else:
f.write(line) f.write(line)
noutputs = len(instr.outputs or ()) if always_exits(instr.block):
# None of the rest matters
return cache_offset
# Stack effect
noutputs = len(outputs)
diff = noutputs - ninputs diff = noutputs - ninputs
if diff > 0: if diff > 0:
f.write(f"{indent} STACK_GROW({diff});\n") f.write(f"{indent} STACK_GROW({diff});\n")
elif diff < 0: elif diff < 0:
f.write(f"{indent} STACK_SHRINK({-diff});\n") f.write(f"{indent} STACK_SHRINK({-diff});\n")
for i, output in enumerate(reversed(instr.outputs or ()), 1): for i, output in enumerate(reversed(outputs), 1):
if output not in (instr.inputs or ()): if output.name not in input_names and output.name != "unused":
f.write(f"{indent} POKE({i}, {output});\n") f.write(f"{indent} POKE({i}, {output.name});\n")
assert instr.block # Cache effect
if cache_offset:
f.write(f"{indent} next_instr += {cache_offset};\n")
return cache_offset
def write_cases(f: TextIO, instrs: list[InstDef], supers: list[parser.Super]):
def write_cases(
f: TextIO, instrs: list[InstDef], supers: list[parser.Super], families: list[parser.Family]
) -> dict[str, tuple[int, int, int]]:
predictions: set[str] = set() predictions: set[str] = set()
for instr in instrs: for instr in instrs:
assert isinstance(instr, InstDef)
assert instr.block is not None
for target in re.findall(RE_PREDICTED, instr.block.text): for target in re.findall(RE_PREDICTED, instr.block.text):
predictions.add(target) predictions.add(target)
indent = " " indent = " "
f.write(f"// This file is generated by {os.path.relpath(__file__)}\n") f.write(f"// This file is generated by {os.path.relpath(__file__)}\n")
f.write(f"// Do not edit!\n") f.write(f"// Do not edit!\n")
instr_index: dict[str, InstDef] = {} instr_index: dict[str, InstDef] = {}
effects_table: dict[str, tuple[int, int, int]] = {} # name -> (ninputs, noutputs, cache_offset)
for instr in instrs: for instr in instrs:
instr_index[instr.name] = instr instr_index[instr.name] = instr
f.write(f"\n{indent}TARGET({instr.name}) {{\n") f.write(f"\n{indent}TARGET({instr.name}) {{\n")
if instr.name in predictions: if instr.name in predictions:
f.write(f"{indent} PREDICTED({instr.name});\n") f.write(f"{indent} PREDICTED({instr.name});\n")
write_instr(instr, predictions, indent, f) cache_offset = write_instr(
assert instr.block instr, predictions, indent, f,
cache_size=find_cache_size(instr, families)
)
effects_table[instr.name] = len(instr.inputs), len(instr.outputs), cache_offset
if not always_exits(instr.block): if not always_exits(instr.block):
f.write(f"{indent} DISPATCH();\n") f.write(f"{indent} DISPATCH();\n")
# Write trailing '}' # Write trailing '}'
f.write(f"{indent}}}\n") f.write(f"{indent}}}\n")
for sup in supers: for sup in supers:
assert isinstance(sup, parser.Super)
components = [instr_index[name] for name in sup.ops] components = [instr_index[name] for name in sup.ops]
f.write(f"\n{indent}TARGET({sup.name}) {{\n") f.write(f"\n{indent}TARGET({sup.name}) {{\n")
for i, instr in enumerate(components): for i, instr in enumerate(components):
assert instr.block
if i > 0: if i > 0:
f.write(f"{indent} NEXTOPARG();\n") f.write(f"{indent} NEXTOPARG();\n")
f.write(f"{indent} next_instr++;\n") f.write(f"{indent} next_instr++;\n")
@ -156,6 +192,8 @@ def write_cases(f: TextIO, instrs: list[InstDef], supers: list[parser.Super]):
f.write(f"{indent} DISPATCH();\n") f.write(f"{indent} DISPATCH();\n")
f.write(f"{indent}}}\n") f.write(f"{indent}}}\n")
return effects_table
def main(): def main():
args = arg_parser.parse_args() args = arg_parser.parse_args()
@ -176,12 +214,28 @@ def main():
file=sys.stderr, file=sys.stderr,
) )
with eopen(args.output, "w") as f: with eopen(args.output, "w") as f:
write_cases(f, instrs, supers) effects_table = write_cases(f, instrs, supers, families)
if not args.quiet: if not args.quiet:
print( print(
f"Wrote {ninstrs + nsupers} instructions to {args.output}", f"Wrote {ninstrs + nsupers} instructions to {args.output}",
file=sys.stderr, file=sys.stderr,
) )
# Check that families have consistent effects
errors = 0
for family in families:
head = effects_table[family.members[0]]
for member in family.members:
if effects_table[member] != head:
errors += 1
print(
f"Family {family.name!r} has inconsistent effects (inputs, outputs, cache units):",
file=sys.stderr,
)
print(
f" {family.members[0]} = {head}; {member} = {effects_table[member]}",
)
if errors:
sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -56,11 +56,28 @@ class Block(Node):
tokens: list[lx.Token] tokens: list[lx.Token]
@dataclass
class Effect(Node):
pass
@dataclass
class StackEffect(Effect):
name: str
# TODO: type, condition
@dataclass
class CacheEffect(Effect):
name: str
size: int
@dataclass @dataclass
class InstHeader(Node): class InstHeader(Node):
name: str name: str
inputs: list[str] inputs: list[Effect]
outputs: list[str] outputs: list[Effect]
@dataclass @dataclass
@ -69,16 +86,17 @@ class InstDef(Node):
block: Block block: Block
@property @property
def name(self): def name(self) -> str:
return self.header.name return self.header.name
@property @property
def inputs(self): def inputs(self) -> list[Effect]:
return self.header.inputs return self.header.inputs
@property @property
def outputs(self): def outputs(self) -> list[StackEffect]:
return self.header.outputs # This is always true
return [x for x in self.header.outputs if isinstance(x, StackEffect)]
@dataclass @dataclass
@ -90,6 +108,7 @@ class Super(Node):
@dataclass @dataclass
class Family(Node): class Family(Node):
name: str name: str
size: str # Variable giving the cache size in code units
members: list[str] members: list[str]
@ -123,18 +142,16 @@ class Parser(PLexer):
return InstHeader(name, [], []) return InstHeader(name, [], [])
return None return None
def check_overlaps(self, inp: list[str], outp: list[str]): def check_overlaps(self, inp: list[Effect], outp: list[Effect]):
for i, name in enumerate(inp): for i, name in enumerate(inp):
try: for j, name2 in enumerate(outp):
j = outp.index(name) if name == name2:
except ValueError: if i != j:
continue raise self.make_syntax_error(
else: f"Input {name!r} at pos {i} repeated in output at different pos {j}")
if i != j: break
raise self.make_syntax_error(
f"Input {name!r} at pos {i} repeated in output at different pos {j}")
def stack_effect(self) -> tuple[list[str], list[str]]: def stack_effect(self) -> tuple[list[Effect], list[Effect]]:
# '(' [inputs] '--' [outputs] ')' # '(' [inputs] '--' [outputs] ')'
if self.expect(lx.LPAREN): if self.expect(lx.LPAREN):
inp = self.inputs() or [] inp = self.inputs() or []
@ -144,8 +161,8 @@ class Parser(PLexer):
return inp, outp return inp, outp
raise self.make_syntax_error("Expected stack effect") raise self.make_syntax_error("Expected stack effect")
def inputs(self) -> list[str] | None: def inputs(self) -> list[Effect] | None:
# input (, input)* # input (',' input)*
here = self.getpos() here = self.getpos()
if inp := self.input(): if inp := self.input():
near = self.getpos() near = self.getpos()
@ -157,27 +174,25 @@ class Parser(PLexer):
self.setpos(here) self.setpos(here)
return None return None
def input(self) -> str | None: @contextual
# IDENTIFIER def input(self) -> Effect | None:
# IDENTIFIER '/' INTEGER (CacheEffect)
# IDENTIFIER (StackEffect)
if (tkn := self.expect(lx.IDENTIFIER)): if (tkn := self.expect(lx.IDENTIFIER)):
if self.expect(lx.LBRACKET): if self.expect(lx.DIVIDE):
if arg := self.expect(lx.IDENTIFIER): if num := self.expect(lx.NUMBER):
if self.expect(lx.RBRACKET): try:
return f"{tkn.text}[{arg.text}]" size = int(num.text)
if self.expect(lx.TIMES): except ValueError:
if num := self.expect(lx.NUMBER): raise self.make_syntax_error(
if self.expect(lx.RBRACKET): f"Expected integer, got {num.text!r}")
return f"{tkn.text}[{arg.text}*{num.text}]" else:
raise self.make_syntax_error("Expected argument in brackets", tkn) return CacheEffect(tkn.text, size)
raise self.make_syntax_error("Expected integer")
else:
return StackEffect(tkn.text)
return tkn.text def outputs(self) -> list[Effect] | None:
if self.expect(lx.CONDOP):
while self.expect(lx.CONDOP):
pass
return "??"
return None
def outputs(self) -> list[str] | None:
# output (, output)* # output (, output)*
here = self.getpos() here = self.getpos()
if outp := self.output(): if outp := self.output():
@ -190,8 +205,10 @@ class Parser(PLexer):
self.setpos(here) self.setpos(here)
return None return None
def output(self) -> str | None: @contextual
return self.input() # TODO: They're not quite the same. def output(self) -> Effect | None:
if (tkn := self.expect(lx.IDENTIFIER)):
return StackEffect(tkn.text)
@contextual @contextual
def super_def(self) -> Super | None: def super_def(self) -> Super | None:
@ -216,24 +233,35 @@ class Parser(PLexer):
@contextual @contextual
def family_def(self) -> Family | None: def family_def(self) -> Family | None:
if (tkn := self.expect(lx.IDENTIFIER)) and tkn.text == "family": if (tkn := self.expect(lx.IDENTIFIER)) and tkn.text == "family":
size = None
if self.expect(lx.LPAREN): if self.expect(lx.LPAREN):
if (tkn := self.expect(lx.IDENTIFIER)): if (tkn := self.expect(lx.IDENTIFIER)):
if self.expect(lx.COMMA):
if not (size := self.expect(lx.IDENTIFIER)):
raise self.make_syntax_error(
"Expected identifier")
if self.expect(lx.RPAREN): if self.expect(lx.RPAREN):
if self.expect(lx.EQUALS): if self.expect(lx.EQUALS):
if not self.expect(lx.LBRACE):
raise self.make_syntax_error("Expected {")
if members := self.members(): if members := self.members():
if self.expect(lx.SEMI): if self.expect(lx.RBRACE) and self.expect(lx.SEMI):
return Family(tkn.text, members) return Family(tkn.text, size.text if size else "", members)
return None return None
def members(self) -> list[str] | None: def members(self) -> list[str] | None:
here = self.getpos() here = self.getpos()
if tkn := self.expect(lx.IDENTIFIER): if tkn := self.expect(lx.IDENTIFIER):
near = self.getpos() members = [tkn.text]
if self.expect(lx.COMMA): while self.expect(lx.COMMA):
if rest := self.members(): if tkn := self.expect(lx.IDENTIFIER):
return [tkn.text] + rest members.append(tkn.text)
self.setpos(near) else:
return [tkn.text] break
peek = self.peek()
if not peek or peek.kind != lx.RBRACE:
raise self.make_syntax_error("Expected comma or right paren")
return members
self.setpos(here) self.setpos(here)
return None return None
@ -274,5 +302,5 @@ if __name__ == "__main__":
filename = None filename = None
src = "if (x) { x.foo; // comment\n}" src = "if (x) { x.foo; // comment\n}"
parser = Parser(src, filename) parser = Parser(src, filename)
x = parser.inst_def() x = parser.inst_def() or parser.super_def() or parser.family_def()
print(x) print(x)