"""Virtual-stack bookkeeping for generating C code from instruction definitions.

Tracks the values popped from and pushed onto the evaluation stack by
instructions/micro-ops, keeps symbolic offsets relative to ``stack_pointer``,
and emits the C statements (via ``CWriter``) needed to load, store and
adjust the stack.
"""

import re
from analyzer import StackItem, StackEffect, Instruction, Uop, PseudoInstruction
from dataclasses import dataclass
from cwriter import CWriter
from typing import Iterator

# Names that mark a stack slot as intentionally ignored.
UNUSED = {"unused"}


def _is_fully_parenthesized(sym: str) -> bool:
    """Return True if the '(' that opens ``sym`` is closed by its final ')'.

    ``"(a + b)"`` is fully parenthesized; ``"(a) + (b)"`` is not, even though
    it both starts with '(' and ends with ')'.
    """
    if not (sym.startswith("(") and sym.endswith(")")):
        return False
    depth = 0
    last = len(sym) - 1
    for index, char in enumerate(sym):
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            # The outer group closed before the end: more text follows it.
            if depth == 0 and index != last:
                return False
    return depth == 0


def maybe_parenthesize(sym: str) -> str:
    """Add parentheses around a string if it contains an operator
       and is not already parenthesized.

       An exception is made for '*' which is common and harmless
       in the context where the symbolic size is used.
    """
    if _is_fully_parenthesized(sym):
        return sym
    if re.match(r"^[\s\w*]+$", sym):
        return sym
    else:
        return f"({sym})"


def var_size(var: StackItem) -> str:
    """Return the number of stack slots ``var`` occupies, as a C expression.

    Conditional items contribute their size only when the condition holds;
    the conditions "0" and "1" are folded statically.
    """
    if var.condition:
        # Special case simplifications
        if var.condition == "0":
            return "0"
        elif var.condition == "1":
            return var.get_size()
        elif var.condition == "oparg & 1" and not var.size:
            return f"({var.condition})"
        else:
            return f"(({var.condition}) ? {var.get_size()} : 0)"
    elif var.size:
        return var.size
    else:
        return "1"


@dataclass
class Local:
    """Tracking state for one stack item held as a C local during generation."""

    item: StackItem
    # NOTE(review): the meaning of `cached` is not evident from this file; it
    # is only propagated and compared here.
    cached: bool
    # True once the value is known to be stored in its stack slot.
    in_memory: bool
    # True once the C variable for this item has been assigned.
    defined: bool

    def __repr__(self) -> str:
        return f"Local('{self.item.name}', mem={self.in_memory}, defined={self.defined}, array={self.is_array()})"

    def compact_str(self) -> str:
        """Short form for comments: name plus M(emory)/D(efined)/A(rray) tags."""
        mtag = "M" if self.in_memory else ""
        dtag = "D" if self.defined else ""
        atag = "A" if self.is_array() else ""
        return f"'{self.item.name}'{mtag}{dtag}{atag}"

    @staticmethod
    def unused(defn: StackItem) -> "Local":
        """A never-defined local; arrays are considered to live in memory."""
        return Local(defn, False, defn.is_array(), False)

    @staticmethod
    def undefined(defn: StackItem) -> "Local":
        """A not-yet-defined output; arrays are considered to live in memory."""
        array = defn.is_array()
        return Local(defn, not array, array, False)

    @staticmethod
    def redefinition(var: StackItem, prev: "Local") -> "Local":
        """A defined local for ``var`` that inherits ``prev``'s storage state."""
        assert var.is_array() == prev.is_array()
        return Local(var, prev.cached, prev.in_memory, True)

    @staticmethod
    def from_memory(defn: StackItem) -> "Local":
        """A local that was just loaded from its stack slot."""
        return Local(defn, True, True, True)

    def copy(self) -> "Local":
        """Return a shallow copy (the StackItem is shared)."""
        return Local(
            self.item,
            self.cached,
            self.in_memory,
            self.defined,
        )

    @property
    def size(self) -> str:
        return self.item.size

    @property
    def name(self) -> str:
        return self.item.name

    @property
    def condition(self) -> str | None:
        return self.item.condition

    def is_array(self) -> bool:
        return self.item.is_array()

    def __eq__(self, other: object) -> bool:
        # Identity comparison on the item; flag values are bools.
        if not isinstance(other, Local):
            return NotImplemented
        return (
            self.item is other.item
            and self.cached is other.cached
            and self.in_memory is other.in_memory
            and self.defined is other.defined
        )


@dataclass
class StackOffset:
    "The stack offset of the virtual base of the stack from the physical stack pointer"

    # Sizes (C expressions) of items popped; each counts negatively.
    popped: list[str]
    # Sizes (C expressions) of items pushed; each counts positively.
    pushed: list[str]

    @staticmethod
    def empty() -> "StackOffset":
        return StackOffset([], [])

    def copy(self) -> "StackOffset":
        return StackOffset(self.popped[:], self.pushed[:])

    def pop(self, item: StackItem) -> None:
        self.popped.append(var_size(item))

    def push(self, item: StackItem) -> None:
        self.pushed.append(var_size(item))

    def __sub__(self, other: "StackOffset") -> "StackOffset":
        return StackOffset(self.popped + other.pushed, self.pushed + other.popped)

    def __neg__(self) -> "StackOffset":
        # Copy the lists so the negated offset does not alias this one.
        return StackOffset(self.pushed[:], self.popped[:])

    def simplify(self) -> None:
        "Remove matching values from both the popped and pushed list"
        if not self.popped:
            self.pushed.sort()
            return
        if not self.pushed:
            self.popped.sort()
            return
        # Sort the list so the lexically largest element is last.
        popped = sorted(self.popped)
        pushed = sorted(self.pushed)
        self.popped = []
        self.pushed = []
        while popped and pushed:
            pop = popped.pop()
            push = pushed.pop()
            if pop == push:
                pass
            elif pop > push:
                # if pop > push, there can be no element in pushed matching pop.
                self.popped.append(pop)
                pushed.append(push)
            else:
                self.pushed.append(push)
                popped.append(pop)
        self.popped.extend(popped)
        self.pushed.extend(pushed)
        self.pushed.sort()
        self.popped.sort()

    def to_c(self) -> str:
        """Render the offset as a C expression, folding integer terms together."""
        self.simplify()
        int_offset = 0
        symbol_offset = ""
        for item in self.popped:
            try:
                int_offset -= int(item)
            except ValueError:
                symbol_offset += f" - {maybe_parenthesize(item)}"
        for item in self.pushed:
            try:
                int_offset += int(item)
            except ValueError:
                symbol_offset += f" + {maybe_parenthesize(item)}"
        if symbol_offset and not int_offset:
            res = symbol_offset
        else:
            res = f"{int_offset}{symbol_offset}"
        if res.startswith(" + "):
            res = res[3:]
        if res.startswith(" - "):
            res = "-" + res[3:]
        return res

    def as_int(self) -> int | None:
        """Return the offset as an int, or None if any term is symbolic."""
        self.simplify()
        int_offset = 0
        for item in self.popped:
            try:
                int_offset -= int(item)
            except ValueError:
                return None
        for item in self.pushed:
            try:
                int_offset += int(item)
            except ValueError:
                return None
        return int_offset

    def clear(self) -> None:
        self.popped = []
        self.pushed = []

    def __bool__(self) -> bool:
        self.simplify()
        return bool(self.popped) or bool(self.pushed)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, StackOffset):
            return NotImplemented
        return self.to_c() == other.to_c()


class StackError(Exception):
    pass


def array_or_scalar(var: StackItem | Local) -> str:
    return "array" if var.is_array() else "scalar"


class Stack:
    """Model of the evaluation stack during code generation.

    ``variables`` holds locals pushed but possibly not yet written back.
    ``top_offset`` is the offset of the virtual top relative to
    ``stack_pointer``, and ``base_offset`` is the offset of the first entry in
    ``variables``.
    """

    def __init__(self) -> None:
        self.top_offset = StackOffset.empty()
        self.base_offset = StackOffset.empty()
        self.variables: list[Local] = []
        # Names of variables that have been given a value so far.
        self.defined: set[str] = set()

    def pop(self, var: StackItem, extract_bits: bool = True) -> tuple[str, Local]:
        """Pop ``var`` off the stack.

        Returns ``(code, local)``. ``code`` is the C needed to assign
        ``var`` (possibly ""), and ``local`` is its tracking state.

        Raises StackError on array/size mismatches, or if an "unused" slot
        holds an already-defined value.
        """
        self.top_offset.pop(var)
        indirect = "&" if var.is_array() else ""
        if self.variables:
            popped = self.variables.pop()
            if var.is_array() ^ popped.is_array():
                raise StackError(
                    f"Array mismatch when popping '{popped.name}' from stack to assign to '{var.name}'. "
                    f"Expected {array_or_scalar(var)} got {array_or_scalar(popped)}"
                )
            if popped.size != var.size:
                raise StackError(
                    f"Size mismatch when popping '{popped.name}' from stack to assign to '{var.name}'. "
                    f"Expected {var_size(var)} got {var_size(popped.item)}"
                )
            if var.name in UNUSED:
                if popped.name not in UNUSED and popped.name in self.defined:
                    raise StackError(
                        f"Value is declared unused, but is already cached by prior operation as '{popped.name}'"
                    )
                return "", popped
            if not var.used:
                return "", popped
            self.defined.add(var.name)
            if popped.defined:
                if popped.name == var.name:
                    return "", popped
                else:
                    defn = f"{var.name} = {popped.name};\n"
            else:
                if var.is_array():
                    defn = f"{var.name} = &stack_pointer[{self.top_offset.to_c()}];\n"
                else:
                    defn = f"{var.name} = stack_pointer[{self.top_offset.to_c()}];\n"
                # The value is read from its stack slot, so it is in memory.
                popped.in_memory = True
            return defn, Local.redefinition(var, popped)
        # Nothing cached: the value lies below the virtual base, in memory.
        self.base_offset.pop(var)
        if var.name in UNUSED or not var.used:
            return "", Local.unused(var)
        self.defined.add(var.name)
        cast = f"({var.type})" if (not indirect and var.type) else ""
        bits = ".bits" if cast and extract_bits else ""
        assign = f"{var.name} = {cast}{indirect}stack_pointer[{self.base_offset.to_c()}]{bits};"
        if var.condition:
            if var.condition == "1":
                assign = f"{assign}\n"
            elif var.condition == "0":
                return "", Local.unused(var)
            else:
                assign = f"if ({var.condition}) {{ {assign} }}\n"
        else:
            assign = f"{assign}\n"
        return assign, Local.from_memory(var)

    def push(self, var: Local) -> None:
        """Push ``var`` onto the virtual stack (no code is emitted)."""
        assert var not in self.variables
        self.variables.append(var)
        self.top_offset.push(var.item)
        if var.item.used:
            self.defined.add(var.name)

    @staticmethod
    def _do_emit(
        out: CWriter,
        var: StackItem,
        base_offset: StackOffset,
        cast_type: str = "uintptr_t",
        extract_bits: bool = True,
    ) -> None:
        """Emit the C store of ``var`` into ``stack_pointer[base_offset]``."""
        cast = f"({cast_type})" if var.type else ""
        bits = ".bits" if cast and extract_bits else ""
        if var.condition == "0":
            return
        if var.condition and var.condition != "1":
            out.emit(f"if ({var.condition}) ")
        out.emit(f"stack_pointer[{base_offset.to_c()}]{bits} = {cast}{var.name};\n")

    def _adjust_stack_pointer(self, out: CWriter, number: str) -> None:
        """Emit ``stack_pointer += number`` plus a bounds assert, unless number is "0"."""
        if number != "0":
            out.start_line()
            out.emit(f"stack_pointer += {number};\n")
            out.emit("assert(WITHIN_STACK_BOUNDS());\n")

    def flush(
        self, out: CWriter, cast_type: str = "uintptr_t", extract_bits: bool = True
    ) -> None:
        """Write defined-but-unstored locals to memory and move stack_pointer to the top."""
        out.start_line()
        var_offset = self.base_offset.copy()
        for var in self.variables:
            if var.defined and not var.in_memory:
                Stack._do_emit(out, var.item, var_offset, cast_type, extract_bits)
                var.in_memory = True
            var_offset.push(var.item)
        number = self.top_offset.to_c()
        self._adjust_stack_pointer(out, number)
        # stack_pointer moved by top_offset; rebase the remaining offsets.
        self.base_offset -= self.top_offset
        self.top_offset.clear()
        out.start_line()

    def is_flushed(self) -> bool:
        return not self.variables and not self.base_offset and not self.top_offset

    def peek_offset(self) -> str:
        return self.top_offset.to_c()

    def as_comment(self) -> str:
        variables = ", ".join([v.compact_str() for v in self.variables])
        return (
            f"/* Variables: {variables}. base: {self.base_offset.to_c()}. top: {self.top_offset.to_c()} */"
        )

    def copy(self) -> "Stack":
        """Return an independent copy: offsets, locals and the defined set are copied."""
        other = Stack()
        other.top_offset = self.top_offset.copy()
        other.base_offset = self.base_offset.copy()
        other.variables = [var.copy() for var in self.variables]
        other.defined = set(self.defined)
        return other

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Stack):
            return NotImplemented
        return (
            self.top_offset == other.top_offset
            and self.base_offset == other.base_offset
            and self.variables == other.variables
        )

    def align(self, other: "Stack", out: CWriter) -> None:
        """Emit a stack-pointer adjustment so this stack's top matches ``other``'s."""
        if len(self.variables) != len(other.variables):
            raise StackError("Cannot align stacks: differing variables")
        if self.top_offset == other.top_offset:
            return
        diff = self.top_offset - other.top_offset
        try:
            self.top_offset -= diff
            self.base_offset -= diff
            self._adjust_stack_pointer(out, diff.to_c())
        except ValueError as err:
            raise StackError("Cannot align stacks: cannot adjust stack pointer") from err

    def merge(self, other: "Stack", out: CWriter) -> None:
        """Merge the state of ``other`` (another control-flow path) into this stack.

        A local stays defined/in-memory only if it is on both paths.
        """
        if len(self.variables) != len(other.variables):
            raise StackError("Cannot merge stacks: differing variables")
        for self_var, other_var in zip(self.variables, other.variables):
            if self_var.name != other_var.name:
                raise StackError(f"Mismatched variables on stack: {self_var.name} and {other_var.name}")
            self_var.defined = self_var.defined and other_var.defined
            self_var.in_memory = self_var.in_memory and other_var.in_memory
        self.align(other, out)


def get_stack_effect(inst: Instruction | PseudoInstruction) -> Stack:
    """Simulate ``inst``'s uops (or pseudo-instruction effect) on an empty Stack.

    Outputs that reuse an input's name keep that input's Local; other outputs
    are pushed as unused locals.
    """
    stack = Stack()

    def stacks(inst: Instruction | PseudoInstruction) -> Iterator[StackEffect]:
        if isinstance(inst, Instruction):
            for uop in inst.parts:
                if isinstance(uop, Uop):
                    yield uop.stack
        else:
            assert isinstance(inst, PseudoInstruction)
            yield inst.stack

    for s in stacks(inst):
        named_locals: dict[str, Local] = {}
        for var in reversed(s.inputs):
            _, local = stack.pop(var)
            if var.name != "unused":
                named_locals[local.name] = local
        for var in s.outputs:
            if var.name in named_locals:
                local = named_locals[var.name]
            else:
                local = Local.unused(var)
            stack.push(local)
    return stack


@dataclass
class Storage:
    """Inputs, outputs and peeks of one micro-op, together with the stack state."""

    stack: Stack
    inputs: list[Local]
    outputs: list[Local]
    peeks: list[Local]
    # Nesting depth of save()/reload() pairs; 0 means stack_pointer is live.
    spilled: int = 0

    @staticmethod
    def needs_defining(var: Local) -> bool:
        """True for a scalar, named local that has not been assigned yet."""
        return (
            not var.defined
            and not var.is_array()
            and var.name != "unused"
        )

    @staticmethod
    def is_live(var: Local) -> bool:
        return (
            var.defined
            and var.name != "unused"
        )

    def first_input_not_cleared(self) -> str:
        """Name of the first still-defined input, or "" if there is none."""
        for input in self.inputs:
            if input.defined:
                return input.name
        return ""

    def clear_inputs(self, reason: str) -> None:
        """Pop every input off the stack; raise if a scalar input is still live."""
        while self.inputs:
            tos = self.inputs.pop()
            if self.is_live(tos) and not tos.is_array():
                raise StackError(
                    f"Input '{tos.name}' is still live {reason}"
                )
            self.stack.pop(tos.item)

    def clear_dead_inputs(self) -> None:
        """Pop dead inputs from the top down to the first live one.

        Raises StackError if an input under that live one still needs defining.
        """
        live = ""
        while self.inputs:
            tos = self.inputs[-1]
            if self.is_live(tos):
                live = tos.name
                break
            self.inputs.pop()
            self.stack.pop(tos.item)
        for var in self.inputs:
            if self.needs_defining(var):
                raise StackError(
                    f"Input '{var.name}' is not live, but '{live}' is"
                )

    def _push_defined_outputs(self) -> None:
        """If any output is defined but not in memory, clear the inputs and
        push the leading outputs that no longer need defining."""
        defined_output = ""
        for output in self.outputs:
            if output.defined and not output.in_memory:
                defined_output = output.name
        if not defined_output:
            return
        self.clear_inputs(f"when output '{defined_output}' is defined")
        undefined = ""
        for out in self.outputs:
            if out.defined:
                if undefined:
                    # NOTE(review): this message is built but never raised, so
                    # out-of-order definition is silently tolerated. Raising
                    # StackError here would be a behavior change; confirm it
                    # against all instruction definitions before enabling it.
                    f"Locals not defined in stack order. "
                    f"Expected '{undefined}' to be defined before '{out.name}'"
            else:
                undefined = out.name
        while self.outputs and not self.needs_defining(self.outputs[0]):
            out = self.outputs.pop(0)
            self.stack.push(out)

    def locals_cached(self) -> bool:
        """True if any output has been defined."""
        for out in self.outputs:
            if out.defined:
                return True
        return False

    def flush(self, out: CWriter, cast_type: str = "uintptr_t", extract_bits: bool = True) -> None:
        """Drop dead inputs, push defined outputs, then flush the stack to memory."""
        self.clear_dead_inputs()
        self._push_defined_outputs()
        self.stack.flush(out, cast_type, extract_bits)

    def save(self, out: CWriter) -> None:
        """Flush and emit _PyFrame_SetStackPointer on the outermost save only."""
        assert self.spilled >= 0
        if self.spilled == 0:
            self.flush(out)
            out.start_line()
            out.emit("_PyFrame_SetStackPointer(frame, stack_pointer);\n")
        self.spilled += 1

    def reload(self, out: CWriter) -> None:
        """Undo one save(); emit _PyFrame_GetStackPointer when the outermost one ends."""
        if self.spilled == 0:
            raise StackError("Cannot reload stack as it hasn't been saved")
        assert self.spilled > 0
        self.spilled -= 1
        if self.spilled == 0:
            out.start_line()
            out.emit("stack_pointer = _PyFrame_GetStackPointer(frame);\n")

    @staticmethod
    def for_uop(stack: Stack, uop: Uop, extract_bits: bool = True) -> tuple[list[str], "Storage"]:
        """Pop ``uop``'s inputs from ``stack`` and build its Storage.

        Returns the C snippets that assign the inputs, plus pointer setup for
        array outputs.
        """
        code_list: list[str] = []
        inputs: list[Local] = []
        peeks: list[Local] = []
        for input in reversed(uop.stack.inputs):
            code, local = stack.pop(input, extract_bits)
            code_list.append(code)
            if input.peek:
                peeks.append(local)
            else:
                inputs.append(local)
        inputs.reverse()
        peeks.reverse()
        for peek in peeks:
            stack.push(peek)
        # Array outputs point straight into the stack; compute their offsets.
        top_offset = stack.top_offset.copy()
        for output in uop.stack.outputs:
            if output.is_array() and output.used and not output.peek:
                c_offset = top_offset.to_c()
                top_offset.push(output)
                code_list.append(f"{output.name} = &stack_pointer[{c_offset}];\n")
            else:
                top_offset.push(output)
        for var in inputs:
            stack.push(var)
        outputs = [Local.undefined(var) for var in uop.stack.outputs if not var.peek]
        return code_list, Storage(stack, inputs, outputs, peeks)

    @staticmethod
    def copy_list(arg: list[Local]) -> list[Local]:
        return [local.copy() for local in arg]

    def copy(self) -> "Storage":
        """Return an independent copy whose inputs alias the copied stack's locals."""
        new_stack = self.stack.copy()
        variables = {var.name: var for var in new_stack.variables}
        inputs = [variables[var.name] for var in self.inputs]
        assert [v.name for v in inputs] == [v.name for v in self.inputs], (inputs, self.inputs)
        # Preserve the spill depth so a copy made between save() and reload()
        # stays balanced.
        return Storage(
            new_stack,
            inputs,
            self.copy_list(self.outputs),
            self.copy_list(self.peeks),
            self.spilled,
        )

    @staticmethod
    def _check_unique_names(variables: list[Local]) -> None:
        """Raise StackError if two entries of ``variables`` share a name."""
        names: set[str] = set()
        for var in variables:
            if var.name in names:
                raise StackError(f"Duplicate name {var.name}")
            names.add(var.name)

    def sanity_check(self) -> None:
        """Check names are unique within inputs, outputs and stack variables."""
        for variables in (self.inputs, self.outputs, self.stack.variables):
            self._check_unique_names(variables)

    def is_flushed(self) -> bool:
        for var in self.outputs:
            if var.defined and not var.in_memory:
                return False
        return self.stack.is_flushed()

    def merge(self, other: "Storage", out: CWriter) -> None:
        """Merge another control-flow path's Storage into this one.

        Raises StackError if the paths disagree on inputs or outputs.
        """
        self.sanity_check()
        if len(self.inputs) != len(other.inputs):
            self.clear_dead_inputs()
            other.clear_dead_inputs()
        if len(self.inputs) != len(other.inputs):
            diff = self.inputs[-1] if len(self.inputs) > len(other.inputs) else other.inputs[-1]
            raise StackError(f"Unmergeable inputs. Differing state of '{diff.name}'")
        for var, other_var in zip(self.inputs, other.inputs):
            if var.defined != other_var.defined:
                raise StackError(f"'{var.name}' is cleared on some paths, but not all")
        if len(self.outputs) != len(other.outputs):
            self._push_defined_outputs()
            other._push_defined_outputs()
        if len(self.outputs) != len(other.outputs):
            var = self.outputs[0] if len(self.outputs) > len(other.outputs) else other.outputs[0]
            raise StackError(f"'{var.name}' is set on some paths, but not all")
        self.stack.merge(other.stack, out)
        self.sanity_check()

    def push_outputs(self) -> None:
        """Finish the micro-op: clear inputs and push all outputs onto the stack.

        Raises StackError on unbalanced spills, live inputs or undefined outputs.
        """
        if self.spilled:
            raise StackError("Unbalanced stack spills")
        self.clear_inputs("at the end of the micro-op")
        if self.inputs:
            raise StackError(f"Input variable '{self.inputs[-1].name}' is still live")
        self._push_defined_outputs()
        if self.outputs:
            for out in self.outputs:
                if self.needs_defining(out):
                    # Report the output that is actually undefined.
                    raise StackError(f"Output variable '{out.name}' is not defined")
                self.stack.push(out)
            self.outputs = []

    def as_comment(self) -> str:
        stack_comment = self.stack.as_comment()
        next_line = "\n "
        inputs = ", ".join([var.compact_str() for var in self.inputs])
        outputs = ", ".join([var.compact_str() for var in self.outputs])
        peeks = ", ".join([var.name for var in self.peeks])
        return f"{stack_comment[:-2]}{next_line}inputs: {inputs}{next_line}outputs: {outputs}{next_line}peeks: {peeks} */"