"""Parser for bytecodes.inst."""
from dataclasses import dataclass, field
from typing import NamedTuple, Callable, TypeVar, Literal, cast
import lexer as lx
from plexer import PLexer
# Type variables used to annotate the `contextual` decorator:
# P is the parser type (bound to Parser), N the node type (bound to Node).
P = TypeVar("P", bound="Parser")
N = TypeVar("N", bound="Node")
def contextual(func: Callable[[P], N | None]) -> Callable[[P], N | None]:
    """Decorator to wrap grammar methods.

    The wrapped method is called with the parser as its only argument.
    If it returns None, the parser position is reset to where it was
    before the call and None is returned.  Otherwise the returned node
    gets a `Context` recording the begin/end positions and the parser
    that produced it.
    """

    def contextual_wrapper(self: P) -> N | None:
        begin = self.getpos()
        res = func(self)
        if res is None:
            # Failed match: rewind so the caller can try an alternative.
            # NOTE(review): relies on PLexer.setpos() being the counterpart
            # of getpos() -- confirm against plexer.
            self.setpos(begin)
            return None
        end = self.getpos()
        res.context = Context(begin, end, self)
        return res

    return contextual_wrapper
class Context(NamedTuple):
    """Token span of a parsed node.

    `begin` and `end` are parser positions (as returned by `getpos()`)
    taken before and after the node was parsed; `owner` is the parser
    whose token list those positions index into.
    """

    begin: int
    end: int
    owner: PLexer

    def __repr__(self) -> str:
        return f"<{self.owner.filename}: {self.begin}-{self.end}>"
@dataclass
class Node:
    """Base class for the AST nodes built by the parser.

    A node optionally carries a `Context` describing the token span it was
    parsed from; the accessors below expose that span as tokens or text.
    """

    # Set after construction by the `contextual` decorator, hence excluded
    # from the generated __init__ and from equality comparison.
    context: Context | None = field(init=False, compare=False, default=None)

    @property
    def text(self) -> str:
        """Source text of this node ("" when no context is attached)."""
        return self.to_text()

    def to_text(self, dedent: int = 0) -> str:
        """Return the node's tokens rendered by `lx.to_text`.

        `dedent` is passed through to `lx.to_text`.  Returns "" when the
        node has no context.
        """
        context = self.context
        if not context:
            return ""
        return lx.to_text(self.tokens, dedent)

    @property
    def tokens(self) -> list[lx.Token]:
        """Tokens in `[context.begin, context.end)`; [] without a context."""
        context = self.context
        if not context:
            return []
        tokens = context.owner.tokens
        begin = context.begin
        end = context.end
        return tokens[begin:end]

    # NOTE(review): exposed as a property for consistency with `text` and
    # `tokens`; confirm no caller invokes it as a method.
    @property
    def first_token(self) -> lx.Token:
        """First token of the node; the node must have a context."""
        context = self.context
        assert context is not None
        return context.owner.tokens[context.begin]
@dataclass
class Block(Node):
    """The code block of an instruction definition."""

    # This just holds a context which has the list of tokens.
    pass
@dataclass
class StackEffect(Node):
    """One named stack input or output, with optional type, condition or size."""

    name: str = field(compare=False)  # __eq__ only uses type, cond, size
    type: str = ""  # Optional `:type`
    cond: str = ""  # Optional `if (cond)`
    size: str = ""  # Optional `[size]`
    # Note: size cannot be combined with type or cond

    def __repr__(self) -> str:
        # Show only the leading fields up to the last non-empty one.
        items = [self.name, self.type, self.cond, self.size]
        while items and items[-1] == "":
            del items[-1]
        return f"StackEffect({', '.join(repr(item) for item in items)})"
@dataclass
class Expression(Node):
    """An expression captured as source text."""

    # The parser stores the stripped text of the expression's tokens here.
    size: str
@dataclass
class CacheEffect(Node):
    """A `NAME/SIZE` item: a name followed by an integer size."""

    name: str
    size: int
@dataclass
class OpName(Node):
    """A bare identifier naming an op."""

    name: str
# An input is either a stack effect or a cache effect; an output is
# always a stack effect.
InputEffect = StackEffect | CacheEffect
OutputEffect = StackEffect
# One component of a macro: an op referenced by name, or a cache effect.
UOp = OpName | CacheEffect
@dataclass
class InstHeader(Node):
    """Header of an `inst(...)` / `op(...)` definition, without its block."""

    annotations: list[str]  # Annotations preceding the `inst`/`op` keyword
    kind: Literal["inst", "op"]
    name: str
    inputs: list[InputEffect]
    outputs: list[OutputEffect]
@dataclass
class InstDef(Node):
    """A complete `inst`/`op` definition: the header fields plus the block."""

    annotations: list[str]  # Annotations preceding the `inst`/`op` keyword
    kind: Literal["inst", "op"]
    name: str
    inputs: list[InputEffect]
    outputs: list[OutputEffect]
    block: Block
@dataclass
class Macro(Node):
    """A `macro(NAME) = uop + uop + ...` definition."""

    name: str
    uops: list[UOp]  # Components in the order they were written
@dataclass
class Family(Node):
    """A `family(NAME[, SIZE]) = { member, ... };` definition."""

    name: str
    size: str  # Variable giving the cache size in code units
    members: list[str]
@dataclass
class Pseudo(Node):
    """A `pseudo(NAME, ...) = { ... };` or `= [ ... ];` definition."""

    name: str
    inputs: list[InputEffect]
    outputs: list[OutputEffect]
    flags: list[str]  # instr flags to set on the pseudo instruction
    targets: list[str]  # opcodes this can be replaced by
    as_sequence: bool  # True when the targets were written in [...] rather than {...}
# The node types that Parser.definition() can return.
AstNode = InstDef | Macro | Pseudo | Family
class Parser(PLexer):
    """Recursive-descent parser for instruction definitions.

    Each grammar method tries to match one construct at the current
    position.  Methods that build a single `Node` are wrapped in
    `@contextual`, which rewinds the parser when the method returns None
    and otherwise records the matched token span on the node.  Methods
    that return lists or tuples restore the position themselves.
    """

    @contextual
    def definition(self) -> AstNode | None:
        """Parse one definition: macro, family, pseudo, or inst/op."""
        if macro := self.macro_def():
            return macro
        if family := self.family_def():
            return family
        if pseudo := self.pseudo_def():
            return pseudo
        if inst := self.inst_def():
            return inst
        return None

    @contextual
    def inst_def(self) -> InstDef | None:
        """Parse an instruction header followed by its block.

        Raises a syntax error if a header is matched but no block follows.
        """
        if hdr := self.inst_header():
            if block := self.block():
                return InstDef(
                    hdr.annotations, hdr.kind, hdr.name, hdr.inputs, hdr.outputs, block
                )
            raise self.make_syntax_error("Expected block")
        return None

    @contextual
    def inst_header(self) -> InstHeader | None:
        # annotation* inst(NAME, (inputs -- outputs))
        # | annotation* op(NAME, (inputs -- outputs))
        annotations = []
        while anno := self.expect(lx.ANNOTATION):
            if anno.text == "replicate":
                # `replicate` carries a count, recorded as "replicate(N)".
                # NOTE(review): the parentheses around the count follow the
                # DSL grammar documentation -- confirm against the lexer.
                self.require(lx.LPAREN)
                times = self.require(lx.NUMBER)
                self.require(lx.RPAREN)
                annotations.append(f"replicate({times.text})")
            else:
                annotations.append(anno.text)
        tkn = self.expect(lx.INST)
        if not tkn:
            tkn = self.expect(lx.OP)
        if tkn:
            kind = cast(Literal["inst", "op"], tkn.text)
            if self.expect(lx.LPAREN) and (tkn := self.expect(lx.IDENTIFIER)):
                name = tkn.text
                if self.expect(lx.COMMA):
                    inp, outp = self.io_effect()
                    if self.expect(lx.RPAREN):
                        # Only a header if a block follows; the '{' is
                        # peeked, not consumed, so block() can read it.
                        if (tkn := self.peek()) and tkn.kind == lx.LBRACE:
                            return InstHeader(annotations, kind, name, inp, outp)
        return None

    def io_effect(self) -> tuple[list[InputEffect], list[OutputEffect]]:
        # '(' [inputs] '--' [outputs] ')'
        # Raises a syntax error when the construct does not match.
        if self.expect(lx.LPAREN):
            inputs = self.inputs() or []
            if self.expect(lx.MINUSMINUS):
                outputs = self.outputs() or []
                if self.expect(lx.RPAREN):
                    return inputs, outputs
        raise self.make_syntax_error("Expected stack effect")

    def inputs(self) -> list[InputEffect] | None:
        # input (',' input)*
        here = self.getpos()
        if inp := self.input():
            inp = cast(InputEffect, inp)
            near = self.getpos()
            if self.expect(lx.COMMA):
                if rest := self.inputs():
                    return [inp] + rest
            # No further input: un-consume a trailing comma, if any.
            self.setpos(near)
            return [inp]
        self.setpos(here)
        return None

    @contextual
    def input(self) -> InputEffect | None:
        """Parse a single input: a cache effect or a stack effect."""
        return self.cache_effect() or self.stack_effect()

    def outputs(self) -> list[OutputEffect] | None:
        # output (, output)*
        here = self.getpos()
        if outp := self.output():
            near = self.getpos()
            if self.expect(lx.COMMA):
                if rest := self.outputs():
                    return [outp] + rest
            # No further output: un-consume a trailing comma, if any.
            self.setpos(near)
            return [outp]
        self.setpos(here)
        return None

    @contextual
    def output(self) -> OutputEffect | None:
        """Parse a single output, which is always a stack effect."""
        return self.stack_effect()

    @contextual
    def cache_effect(self) -> CacheEffect | None:
        # IDENTIFIER '/' NUMBER
        if tkn := self.expect(lx.IDENTIFIER):
            if self.expect(lx.DIVIDE):
                num = self.require(lx.NUMBER).text
                try:
                    size = int(num)
                except ValueError:
                    raise self.make_syntax_error(f"Expected integer, got {num!r}")
                else:
                    return CacheEffect(tkn.text, size)
        return None

    @contextual
    def stack_effect(self) -> StackEffect | None:
        # IDENTIFIER [':' IDENTIFIER [TIMES]] ['if' '(' expression ')']
        # | IDENTIFIER '[' expression ']'
        if tkn := self.expect(lx.IDENTIFIER):
            type_text = ""
            if self.expect(lx.COLON):
                type_text = self.require(lx.IDENTIFIER).text.strip()
                if self.expect(lx.TIMES):
                    type_text += " *"
            cond_text = ""
            if self.expect(lx.IF):
                # expression() stops before the closing bracket and assumes
                # the opening one was already consumed.
                self.require(lx.LPAREN)
                if not (cond := self.expression()):
                    raise self.make_syntax_error("Expected condition")
                self.require(lx.RPAREN)
                cond_text = cond.text.strip()
            size_text = ""
            if self.expect(lx.LBRACKET):
                if type_text or cond_text:
                    raise self.make_syntax_error("Unexpected [")
                if not (size := self.expression()):
                    raise self.make_syntax_error("Expected expression")
                self.require(lx.RBRACKET)
                size_text = size.text.strip()
            return StackEffect(tkn.text, type_text, cond_text, size_text)
        return None

    @contextual
    def expression(self) -> Expression | None:
        """Collect tokens up to the bracket/paren closing the current group.

        The caller has already consumed the opening bracket (hence the
        nesting level starts at 1); the matching closing bracket is left
        unconsumed.  Returns None if no tokens were collected.
        """
        tokens: list[lx.Token] = []
        level = 1
        while tkn := self.peek():
            if tkn.kind in (lx.LBRACKET, lx.LPAREN):
                level += 1
            elif tkn.kind in (lx.RBRACKET, lx.RPAREN):
                level -= 1
                if level == 0:
                    break
            tokens.append(tkn)
            self.next()
        if not tokens:
            return None
        return Expression(lx.to_text(tokens).strip())

    # def ops(self) -> list[OpName] | None:
    #     if op := self.op():
    #         ops = [op]
    #         while self.expect(lx.PLUS):
    #             if op := self.op():
    #                 ops.append(op)
    #         return ops

    @contextual
    def op(self) -> OpName | None:
        """Parse a bare identifier as an op name."""
        if tkn := self.expect(lx.IDENTIFIER):
            return OpName(tkn.text)
        return None

    @contextual
    def macro_def(self) -> Macro | None:
        # 'macro' '(' IDENTIFIER ')' '=' uops ';'
        if tkn := self.expect(lx.MACRO):
            if self.expect(lx.LPAREN):
                if tkn := self.expect(lx.IDENTIFIER):
                    if self.expect(lx.RPAREN):
                        if self.expect(lx.EQUALS):
                            if uops := self.uops():
                                # Terminated by ';' like family and pseudo
                                # definitions.
                                self.require(lx.SEMI)
                                res = Macro(tkn.text, uops)
                                return res
        return None

    def uops(self) -> list[UOp] | None:
        # uop ('+' uop)*
        if uop := self.uop():
            uop = cast(UOp, uop)
            uops = [uop]
            while self.expect(lx.PLUS):
                if uop := self.uop():
                    uop = cast(UOp, uop)
                    uops.append(uop)
                else:
                    raise self.make_syntax_error("Expected op name or cache effect")
            return uops
        return None

    @contextual
    def uop(self) -> UOp | None:
        # IDENTIFIER '/' NUMBER  (cache effect)
        # | IDENTIFIER           (op name)
        if tkn := self.expect(lx.IDENTIFIER):
            if self.expect(lx.DIVIDE):
                if num := self.expect(lx.NUMBER):
                    try:
                        size = int(num.text)
                    except ValueError:
                        raise self.make_syntax_error(
                            f"Expected integer, got {num.text!r}"
                        )
                    else:
                        return CacheEffect(tkn.text, size)
                raise self.make_syntax_error("Expected integer")
            else:
                return OpName(tkn.text)
        return None

    @contextual
    def family_def(self) -> Family | None:
        # 'family' '(' IDENTIFIER [',' (IDENTIFIER | NUMBER)] ')' '=' '{' members '}' ';'
        if (tkn := self.expect(lx.IDENTIFIER)) and tkn.text == "family":
            size = None
            if self.expect(lx.LPAREN):
                if tkn := self.expect(lx.IDENTIFIER):
                    if self.expect(lx.COMMA):
                        if not (size := self.expect(lx.IDENTIFIER)):
                            if not (size := self.expect(lx.NUMBER)):
                                raise self.make_syntax_error(
                                    "Expected identifier or number"
                                )
                    if self.expect(lx.RPAREN):
                        if self.expect(lx.EQUALS):
                            if not self.expect(lx.LBRACE):
                                raise self.make_syntax_error("Expected {")
                            if members := self.members():
                                if self.expect(lx.RBRACE) and self.expect(lx.SEMI):
                                    return Family(
                                        tkn.text, size.text if size else "", members
                                    )
        return None

    def flags(self) -> list[str]:
        # '(' IDENTIFIER (',' IDENTIFIER)* ')'
        # Returns [] (with the position restored) when no flag list is present.
        here = self.getpos()
        if self.expect(lx.LPAREN):
            if tkn := self.expect(lx.IDENTIFIER):
                flags = [tkn.text]
                while self.expect(lx.COMMA):
                    if tkn := self.expect(lx.IDENTIFIER):
                        flags.append(tkn.text)
                    else:
                        break
                if not self.expect(lx.RPAREN):
                    raise self.make_syntax_error("Expected comma or right paren")
                return flags
        self.setpos(here)
        return []

    @contextual
    def pseudo_def(self) -> Pseudo | None:
        # 'pseudo' '(' IDENTIFIER [',' io_effect [',' flags]] ')' '='
        #     ('{' members '}' | '[' members ']') ';'
        if (tkn := self.expect(lx.IDENTIFIER)) and tkn.text == "pseudo":
            if self.expect(lx.LPAREN):
                if tkn := self.expect(lx.IDENTIFIER):
                    # Defaults keep these bound when the optional parts
                    # after the name are absent.
                    inp: list[InputEffect] = []
                    outp: list[OutputEffect] = []
                    flags: list[str] = []
                    if self.expect(lx.COMMA):
                        inp, outp = self.io_effect()
                        if self.expect(lx.COMMA):
                            flags = self.flags()
                    if self.expect(lx.RPAREN):
                        if self.expect(lx.EQUALS):
                            if self.expect(lx.LBRACE):
                                as_sequence = False
                                closing = lx.RBRACE
                            elif self.expect(lx.LBRACKET):
                                as_sequence = True
                                closing = lx.RBRACKET
                            else:
                                raise self.make_syntax_error("Expected { or [")
                            if members := self.members(allow_sequence=True):
                                if self.expect(closing) and self.expect(lx.SEMI):
                                    return Pseudo(
                                        tkn.text, inp, outp, flags, members, as_sequence
                                    )
        return None

    def members(self, allow_sequence: bool = False) -> list[str] | None:
        # IDENTIFIER (',' IDENTIFIER)* [',']
        # The list must be followed by '}' (or ']' when allow_sequence is
        # true); the closer is peeked, not consumed.
        here = self.getpos()
        if tkn := self.expect(lx.IDENTIFIER):
            members = [tkn.text]
            while self.expect(lx.COMMA):
                if tkn := self.expect(lx.IDENTIFIER):
                    members.append(tkn.text)
                else:
                    break
            peek = self.peek()
            kinds = [lx.RBRACE, lx.RBRACKET] if allow_sequence else [lx.RBRACE]
            if not peek or peek.kind not in kinds:
                raise self.make_syntax_error(
                    f"Expected comma or right paren{'/bracket' if allow_sequence else ''}")
            return members
        self.setpos(here)
        return None

    @contextual
    def block(self) -> Block | None:
        """Parse a block; the Block node only carries the token span."""
        if self.c_blob():
            return Block()
        return None

    def c_blob(self) -> list[lx.Token]:
        """Consume raw tokens through the bracket that closes the first group.

        Tokens are read with `next(raw=True)` and collected until the
        nesting level of (), [] and {} drops to zero or below, or until
        the input is exhausted.
        """
        tokens: list[lx.Token] = []
        level = 0
        while tkn := self.next(raw=True):
            tokens.append(tkn)
            if tkn.kind in (lx.LBRACE, lx.LPAREN, lx.LBRACKET):
                level += 1
            elif tkn.kind in (lx.RBRACE, lx.RPAREN, lx.RBRACKET):
                level -= 1
                if level <= 0:
                    break
        return tokens
if __name__ == "__main__":
    # Small manual driver:
    #   script FILE     parse the text between the BEGIN/END BYTECODES markers
    #   script -c SRC   parse SRC given on the command line
    #   script          parse a built-in sample string
    import sys

    if sys.argv[1:]:
        filename = sys.argv[1]
        if filename == "-c" and sys.argv[2:]:
            src = sys.argv[2]
            filename = "<string>"
        else:
            with open(filename, "r") as f:
                src = f.read()
            srclines = src.splitlines()
            # list.index raises ValueError if either marker line is missing.
            begin = srclines.index("// BEGIN BYTECODES //")
            end = srclines.index("// END BYTECODES //")
            src = "\n".join(srclines[begin + 1 : end])
    else:
        filename = "<default>"
        src = "if (x) { x.foo; // comment\n}"
    parser = Parser(src, filename)
    x = parser.definition()
    # Show the result; otherwise the driver produces no output at all.
    print(x)