"""Helpers for introspecting and wrapping annotations.""" import ast import enum import functools import sys import types __all__ = ["Format", "ForwardRef", "call_annotate_function", "get_annotations"] class Format(enum.IntEnum): VALUE = 1 FORWARDREF = 2 SOURCE = 3 _Union = None _sentinel = object() # Slots shared by ForwardRef and _Stringifier. The __forward__ names must be # preserved for compatibility with the old typing.ForwardRef class. The remaining # names are private. _SLOTS = ( "__forward_evaluated__", "__forward_value__", "__forward_is_argument__", "__forward_is_class__", "__forward_module__", "__weakref__", "__arg__", "__ast_node__", "__code__", "__globals__", "__owner__", "__cell__", ) class ForwardRef: """Wrapper that holds a forward reference.""" __slots__ = _SLOTS def __init__( self, arg, *, module=None, owner=None, is_argument=True, is_class=False, _globals=None, _cell=None, ): if not isinstance(arg, str): raise TypeError(f"Forward reference must be a string -- got {arg!r}") self.__arg__ = arg self.__forward_evaluated__ = False self.__forward_value__ = None self.__forward_is_argument__ = is_argument self.__forward_is_class__ = is_class self.__forward_module__ = module self.__code__ = None self.__ast_node__ = None self.__globals__ = _globals self.__cell__ = _cell self.__owner__ = owner def __init_subclass__(cls, /, *args, **kwds): raise TypeError("Cannot subclass ForwardRef") def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): """Evaluate the forward reference and return the value. If the forward reference is not evaluatable, raise an exception. """ if self.__forward_evaluated__: return self.__forward_value__ if self.__cell__ is not None: try: value = self.__cell__.cell_contents except ValueError: pass else: self.__forward_evaluated__ = True self.__forward_value__ = value return value if owner is None: owner = self.__owner__ if type_params is None and owner is None: raise TypeError("Either 'type_params' or 'owner' must be provided") if self.__forward_module__ is not None: globals = getattr( sys.modules.get(self.__forward_module__, None), "__dict__", globals ) if globals is None: globals = self.__globals__ if globals is None: if isinstance(owner, type): module_name = getattr(owner, "__module__", None) if module_name: module = sys.modules.get(module_name, None) if module: globals = getattr(module, "__dict__", None) elif isinstance(owner, types.ModuleType): globals = getattr(owner, "__dict__", None) elif callable(owner): globals = getattr(owner, "__globals__", None) if locals is None: locals = {} if isinstance(self.__owner__, type): locals.update(vars(self.__owner__)) if type_params is None and self.__owner__ is not None: # "Inject" type parameters into the local namespace # (unless they are shadowed by assignments *in* the local namespace), # as a way of emulating annotation scopes when calling `eval()` type_params = getattr(self.__owner__, "__type_params__", None) # type parameters require some special handling, # as they exist in their own scope # but `eval()` does not have a dedicated parameter for that scope. # For classes, names in type parameter scopes should override # names in the global scope (which here are called `localns`!), # but should in turn be overridden by names in the class scope # (which here are called `globalns`!) if type_params is not None: globals, locals = dict(globals), dict(locals) for param in type_params: param_name = param.__name__ if not self.__forward_is_class__ or param_name not in globals: globals[param_name] = param locals.pop(param_name, None) code = self.__forward_code__ value = eval(code, globals=globals, locals=locals) self.__forward_evaluated__ = True self.__forward_value__ = value return value def _evaluate(self, globalns, localns, type_params=_sentinel, *, recursive_guard): import typing import warnings if type_params is _sentinel: typing._deprecation_warning_for_no_type_params_passed( "typing.ForwardRef._evaluate" ) type_params = () warnings._deprecated( "ForwardRef._evaluate", "{name} is a private API and is retained for compatibility, but will be removed" " in Python 3.16. Use ForwardRef.evaluate() or typing.evaluate_forward_ref() instead.", remove=(3, 16), ) return typing.evaluate_forward_ref( self, globals=globalns, locals=localns, type_params=type_params, _recursive_guard=recursive_guard, ) @property def __forward_arg__(self): if self.__arg__ is not None: return self.__arg__ if self.__ast_node__ is not None: self.__arg__ = ast.unparse(self.__ast_node__) return self.__arg__ raise AssertionError( "Attempted to access '__forward_arg__' on an uninitialized ForwardRef" ) @property def __forward_code__(self): if self.__code__ is not None: return self.__code__ arg = self.__forward_arg__ # If we do `def f(*args: *Ts)`, then we'll have `arg = '*Ts'`. # Unfortunately, this isn't a valid expression on its own, so we # do the unpacking manually. if arg.startswith("*"): arg_to_compile = f"({arg},)[0]" # E.g. (*Ts,)[0] or (*tuple[int, int],)[0] else: arg_to_compile = arg try: self.__code__ = compile(arg_to_compile, "", "eval") except SyntaxError: raise SyntaxError(f"Forward reference must be an expression -- got {arg!r}") return self.__code__ def __eq__(self, other): if not isinstance(other, ForwardRef): return NotImplemented if self.__forward_evaluated__ and other.__forward_evaluated__: return ( self.__forward_arg__ == other.__forward_arg__ and self.__forward_value__ == other.__forward_value__ ) return ( self.__forward_arg__ == other.__forward_arg__ and self.__forward_module__ == other.__forward_module__ ) def __hash__(self): return hash((self.__forward_arg__, self.__forward_module__)) def __or__(self, other): global _Union if _Union is None: from typing import Union as _Union return _Union[self, other] def __ror__(self, other): global _Union if _Union is None: from typing import Union as _Union return _Union[other, self] def __repr__(self): if self.__forward_module__ is None: module_repr = "" else: module_repr = f", module={self.__forward_module__!r}" return f"ForwardRef({self.__forward_arg__!r}{module_repr})" class _Stringifier: # Must match the slots on ForwardRef, so we can turn an instance of one into an # instance of the other in place. __slots__ = _SLOTS def __init__(self, node, globals=None, owner=None, is_class=False, cell=None): assert isinstance(node, ast.AST) self.__arg__ = None self.__forward_evaluated__ = False self.__forward_value__ = None self.__forward_is_argument__ = False self.__forward_is_class__ = is_class self.__forward_module__ = None self.__code__ = None self.__ast_node__ = node self.__globals__ = globals self.__cell__ = cell self.__owner__ = owner def __convert(self, other): if isinstance(other, _Stringifier): return other.__ast_node__ elif isinstance(other, slice): return ast.Slice( lower=self.__convert(other.start) if other.start is not None else None, upper=self.__convert(other.stop) if other.stop is not None else None, step=self.__convert(other.step) if other.step is not None else None, ) else: return ast.Constant(value=other) def __make_new(self, node): return _Stringifier( node, self.__globals__, self.__owner__, self.__forward_is_class__ ) # Must implement this since we set __eq__. We hash by identity so that # stringifiers in dict keys are kept separate. def __hash__(self): return id(self) def __getitem__(self, other): # Special case, to avoid stringifying references to class-scoped variables # as '__classdict__["x"]'. if ( isinstance(self.__ast_node__, ast.Name) and self.__ast_node__.id == "__classdict__" ): raise KeyError if isinstance(other, tuple): elts = [self.__convert(elt) for elt in other] other = ast.Tuple(elts) else: other = self.__convert(other) assert isinstance(other, ast.AST), repr(other) return self.__make_new(ast.Subscript(self.__ast_node__, other)) def __getattr__(self, attr): return self.__make_new(ast.Attribute(self.__ast_node__, attr)) def __call__(self, *args, **kwargs): return self.__make_new( ast.Call( self.__ast_node__, [self.__convert(arg) for arg in args], [ ast.keyword(key, self.__convert(value)) for key, value in kwargs.items() ], ) ) def __iter__(self): yield self.__make_new(ast.Starred(self.__ast_node__)) def __repr__(self): return ast.unparse(self.__ast_node__) def __format__(self, format_spec): raise TypeError("Cannot stringify annotation containing string formatting") def _make_binop(op: ast.AST): def binop(self, other): return self.__make_new( ast.BinOp(self.__ast_node__, op, self.__convert(other)) ) return binop __add__ = _make_binop(ast.Add()) __sub__ = _make_binop(ast.Sub()) __mul__ = _make_binop(ast.Mult()) __matmul__ = _make_binop(ast.MatMult()) __truediv__ = _make_binop(ast.Div()) __mod__ = _make_binop(ast.Mod()) __lshift__ = _make_binop(ast.LShift()) __rshift__ = _make_binop(ast.RShift()) __or__ = _make_binop(ast.BitOr()) __xor__ = _make_binop(ast.BitXor()) __and__ = _make_binop(ast.BitAnd()) __floordiv__ = _make_binop(ast.FloorDiv()) __pow__ = _make_binop(ast.Pow()) del _make_binop def _make_rbinop(op: ast.AST): def rbinop(self, other): return self.__make_new( ast.BinOp(self.__convert(other), op, self.__ast_node__) ) return rbinop __radd__ = _make_rbinop(ast.Add()) __rsub__ = _make_rbinop(ast.Sub()) __rmul__ = _make_rbinop(ast.Mult()) __rmatmul__ = _make_rbinop(ast.MatMult()) __rtruediv__ = _make_rbinop(ast.Div()) __rmod__ = _make_rbinop(ast.Mod()) __rlshift__ = _make_rbinop(ast.LShift()) __rrshift__ = _make_rbinop(ast.RShift()) __ror__ = _make_rbinop(ast.BitOr()) __rxor__ = _make_rbinop(ast.BitXor()) __rand__ = _make_rbinop(ast.BitAnd()) __rfloordiv__ = _make_rbinop(ast.FloorDiv()) __rpow__ = _make_rbinop(ast.Pow()) del _make_rbinop def _make_compare(op): def compare(self, other): return self.__make_new( ast.Compare( left=self.__ast_node__, ops=[op], comparators=[self.__convert(other)], ) ) return compare __lt__ = _make_compare(ast.Lt()) __le__ = _make_compare(ast.LtE()) __eq__ = _make_compare(ast.Eq()) __ne__ = _make_compare(ast.NotEq()) __gt__ = _make_compare(ast.Gt()) __ge__ = _make_compare(ast.GtE()) del _make_compare def _make_unary_op(op): def unary_op(self): return self.__make_new(ast.UnaryOp(op, self.__ast_node__)) return unary_op __invert__ = _make_unary_op(ast.Invert()) __pos__ = _make_unary_op(ast.UAdd()) __neg__ = _make_unary_op(ast.USub()) del _make_unary_op class _StringifierDict(dict): def __init__(self, namespace, globals=None, owner=None, is_class=False): super().__init__(namespace) self.namespace = namespace self.globals = globals self.owner = owner self.is_class = is_class self.stringifiers = [] def __missing__(self, key): fwdref = _Stringifier( ast.Name(id=key), globals=self.globals, owner=self.owner, is_class=self.is_class, ) self.stringifiers.append(fwdref) return fwdref def call_evaluate_function(evaluate, format, *, owner=None): """Call an evaluate function. Evaluate functions are normally generated for the value of type aliases and the bounds, constraints, and defaults of type parameter objects. """ return call_annotate_function(evaluate, format, owner=owner, _is_evaluate=True) def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): """Call an __annotate__ function. __annotate__ functions are normally generated by the compiler to defer the evaluation of annotations. They can be called with any of the format arguments in the Format enum, but compiler-generated __annotate__ functions only support the VALUE format. This function provides additional functionality to call __annotate__ functions with the FORWARDREF and SOURCE formats. *annotate* must be an __annotate__ function, which takes a single argument and returns a dict of annotations. *format* must be a member of the Format enum or one of the corresponding integer values. *owner* can be the object that owns the annotations (i.e., the module, class, or function that the __annotate__ function derives from). With the FORWARDREF format, it is used to provide better evaluation capabilities on the generated ForwardRef objects. """ try: return annotate(format) except NotImplementedError: pass if format == Format.SOURCE: # SOURCE is implemented by calling the annotate function in a special # environment where every name lookup results in an instance of _Stringifier. # _Stringifier supports every dunder operation and returns a new _Stringifier. # At the end, we get a dictionary that mostly contains _Stringifier objects (or # possibly constants if the annotate function uses them directly). We then # convert each of those into a string to get an approximation of the # original source. globals = _StringifierDict({}) if annotate.__closure__: freevars = annotate.__code__.co_freevars new_closure = [] for i, cell in enumerate(annotate.__closure__): if i < len(freevars): name = freevars[i] else: name = "__cell__" fwdref = _Stringifier(ast.Name(id=name)) new_closure.append(types.CellType(fwdref)) closure = tuple(new_closure) else: closure = None func = types.FunctionType(annotate.__code__, globals, closure=closure, argdefs=annotate.__defaults__, kwdefaults=annotate.__kwdefaults__) annos = func(Format.VALUE) if _is_evaluate: return annos if isinstance(annos, str) else repr(annos) return { key: val if isinstance(val, str) else repr(val) for key, val in annos.items() } elif format == Format.FORWARDREF: # FORWARDREF is implemented similarly to SOURCE, but there are two changes, # at the beginning and the end of the process. # First, while SOURCE uses an empty dictionary as the namespace, so that all # name lookups result in _Stringifier objects, FORWARDREF uses the globals # and builtins, so that defined names map to their real values. # Second, instead of returning strings, we want to return either real values # or ForwardRef objects. To do this, we keep track of all _Stringifier objects # created while the annotation is being evaluated, and at the end we convert # them all to ForwardRef objects by assigning to __class__. To make this # technique work, we have to ensure that the _Stringifier and ForwardRef # classes share the same attributes. # We use this technique because while the annotations are being evaluated, # we want to support all operations that the language allows, including even # __getattr__ and __eq__, and return new _Stringifier objects so we can accurately # reconstruct the source. But in the dictionary that we eventually return, we # want to return objects with more user-friendly behavior, such as an __eq__ # that returns a bool and an defined set of attributes. namespace = {**annotate.__builtins__, **annotate.__globals__} is_class = isinstance(owner, type) globals = _StringifierDict(namespace, annotate.__globals__, owner, is_class) if annotate.__closure__: freevars = annotate.__code__.co_freevars new_closure = [] for i, cell in enumerate(annotate.__closure__): try: cell.cell_contents except ValueError: if i < len(freevars): name = freevars[i] else: name = "__cell__" fwdref = _Stringifier( ast.Name(id=name), cell=cell, owner=owner, globals=annotate.__globals__, is_class=is_class, ) globals.stringifiers.append(fwdref) new_closure.append(types.CellType(fwdref)) else: new_closure.append(cell) closure = tuple(new_closure) else: closure = None func = types.FunctionType(annotate.__code__, globals, closure=closure, argdefs=annotate.__defaults__, kwdefaults=annotate.__kwdefaults__) result = func(Format.VALUE) for obj in globals.stringifiers: obj.__class__ = ForwardRef return result elif format == Format.VALUE: # Should be impossible because __annotate__ functions must not raise # NotImplementedError for this format. raise RuntimeError("annotate function does not support VALUE format") else: raise ValueError(f"Invalid format: {format!r}") # We use the descriptors from builtins.type instead of accessing # .__annotations__ and .__annotate__ directly on class objects, because # otherwise we could get wrong results in some cases involving metaclasses. # See PEP 749. _BASE_GET_ANNOTATE = type.__dict__["__annotate__"].__get__ _BASE_GET_ANNOTATIONS = type.__dict__["__annotations__"].__get__ def get_annotate_function(obj): """Get the __annotate__ function for an object. obj may be a function, class, or module, or a user-defined type with an `__annotate__` attribute. Returns the __annotate__ function or None. """ if isinstance(obj, type): return _BASE_GET_ANNOTATE(obj) return getattr(obj, "__annotate__", None) def get_annotations( obj, *, globals=None, locals=None, eval_str=False, format=Format.VALUE ): """Compute the annotations dict for an object. obj may be a callable, class, or module. Passing in an object of any other type raises TypeError. Returns a dict. get_annotations() returns a new dict every time it's called; calling it twice on the same object will return two different but equivalent dicts. This function handles several details for you: * If eval_str is true, values of type str will be un-stringized using eval(). This is intended for use with stringized annotations ("from __future__ import annotations"). * If obj doesn't have an annotations dict, returns an empty dict. (Functions and methods always have an annotations dict; classes, modules, and other types of callables may not.) * Ignores inherited annotations on classes. If a class doesn't have its own annotations dict, returns an empty dict. * All accesses to object members and dict values are done using getattr() and dict.get() for safety. * Always, always, always returns a freshly-created dict. eval_str controls whether or not values of type str are replaced with the result of calling eval() on those values: * If eval_str is true, eval() is called on values of type str. * If eval_str is false (the default), values of type str are unchanged. globals and locals are passed in to eval(); see the documentation for eval() for more information. If either globals or locals is None, this function may replace that value with a context-specific default, contingent on type(obj): * If obj is a module, globals defaults to obj.__dict__. * If obj is a class, globals defaults to sys.modules[obj.__module__].__dict__ and locals defaults to the obj class namespace. * If obj is a callable, globals defaults to obj.__globals__, although if obj is a wrapped function (using functools.update_wrapper()) it is first unwrapped. """ if eval_str and format != Format.VALUE: raise ValueError("eval_str=True is only supported with format=Format.VALUE") # For VALUE format, we look at __annotations__ directly. if format != Format.VALUE: annotate = get_annotate_function(obj) if annotate is not None: ann = call_annotate_function(annotate, format, owner=obj) if not isinstance(ann, dict): raise ValueError(f"{obj!r}.__annotate__ returned a non-dict") return dict(ann) if isinstance(obj, type): try: ann = _BASE_GET_ANNOTATIONS(obj) except AttributeError: # For static types, the descriptor raises AttributeError. return {} else: ann = getattr(obj, "__annotations__", None) if ann is None: return {} if not isinstance(ann, dict): raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None") if not ann: return {} if not eval_str: return dict(ann) if isinstance(obj, type): # class obj_globals = None module_name = getattr(obj, "__module__", None) if module_name: module = sys.modules.get(module_name, None) if module: obj_globals = getattr(module, "__dict__", None) obj_locals = dict(vars(obj)) unwrap = obj elif isinstance(obj, types.ModuleType): # module obj_globals = getattr(obj, "__dict__") obj_locals = None unwrap = None elif callable(obj): # this includes types.Function, types.BuiltinFunctionType, # types.BuiltinMethodType, functools.partial, functools.singledispatch, # "class funclike" from Lib/test/test_inspect... on and on it goes. obj_globals = getattr(obj, "__globals__", None) obj_locals = None unwrap = obj elif ann is not None: obj_globals = obj_locals = unwrap = None else: raise TypeError(f"{obj!r} is not a module, class, or callable.") if unwrap is not None: while True: if hasattr(unwrap, "__wrapped__"): unwrap = unwrap.__wrapped__ continue if isinstance(unwrap, functools.partial): unwrap = unwrap.func continue break if hasattr(unwrap, "__globals__"): obj_globals = unwrap.__globals__ if globals is None: globals = obj_globals if locals is None: locals = obj_locals # "Inject" type parameters into the local namespace # (unless they are shadowed by assignments *in* the local namespace), # as a way of emulating annotation scopes when calling `eval()` if type_params := getattr(obj, "__type_params__", ()): if locals is None: locals = {} locals = {param.__name__: param for param in type_params} | locals return_value = { key: value if not isinstance(value, str) else eval(value, globals, locals) for key, value in ann.items() } return return_value