gh-119180: Improvements to ForwardRef.evaluate (#122210)

Noticed some issues while writing documentation for this method.
This commit is contained in:
Jelle Zijlstra 2024-08-11 16:42:57 -07:00 committed by GitHub
parent a6644d4464
commit 016f4b5975
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 11 deletions

View File

@ -74,7 +74,7 @@ class ForwardRef:
def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
"""Evaluate the forward reference and return the value.
If the forward reference is not evaluatable, raise an exception.
If the forward reference cannot be evaluated, raise an exception.
"""
if self.__forward_evaluated__:
return self.__forward_value__
@ -89,12 +89,10 @@ class ForwardRef:
return value
if owner is None:
owner = self.__owner__
if type_params is None and owner is None:
raise TypeError("Either 'type_params' or 'owner' must be provided")
if self.__forward_module__ is not None:
if globals is None and self.__forward_module__ is not None:
globals = getattr(
sys.modules.get(self.__forward_module__, None), "__dict__", globals
sys.modules.get(self.__forward_module__, None), "__dict__", None
)
if globals is None:
globals = self.__globals__
@ -112,14 +110,14 @@ class ForwardRef:
if locals is None:
locals = {}
if isinstance(self.__owner__, type):
locals.update(vars(self.__owner__))
if isinstance(owner, type):
locals.update(vars(owner))
if type_params is None and self.__owner__ is not None:
if type_params is None and owner is not None:
# "Inject" type parameters into the local namespace
# (unless they are shadowed by assignments *in* the local namespace),
# as a way of emulating annotation scopes when calling `eval()`
type_params = getattr(self.__owner__, "__type_params__", None)
type_params = getattr(owner, "__type_params__", None)
# type parameters require some special handling,
# as they exist in their own scope
@ -129,7 +127,14 @@ class ForwardRef:
# but should in turn be overridden by names in the class scope
# (which here are called `globalns`!)
if type_params is not None:
globals, locals = dict(globals), dict(locals)
if globals is None:
globals = {}
else:
globals = dict(globals)
if locals is None:
locals = {}
else:
locals = dict(locals)
for param in type_params:
param_name = param.__name__
if not self.__forward_is_class__ or param_name not in globals:

View File

@ -5,7 +5,7 @@ import functools
import itertools
import pickle
import unittest
from annotationlib import Format, get_annotations, get_annotate_function
from annotationlib import Format, ForwardRef, get_annotations, get_annotate_function
from typing import Unpack
from test.test_inspect import inspect_stock_annotations
@ -250,6 +250,46 @@ class TestForwardRefClass(unittest.TestCase):
with self.assertRaises(TypeError):
pickle.dumps(fr, proto)
def test_evaluate_with_type_params(self):
class Gen[T]:
alias = int
with self.assertRaises(NameError):
ForwardRef("T").evaluate()
with self.assertRaises(NameError):
ForwardRef("T").evaluate(type_params=())
with self.assertRaises(NameError):
ForwardRef("T").evaluate(owner=int)
T, = Gen.__type_params__
self.assertIs(ForwardRef("T").evaluate(type_params=Gen.__type_params__), T)
self.assertIs(ForwardRef("T").evaluate(owner=Gen), T)
with self.assertRaises(NameError):
ForwardRef("alias").evaluate(type_params=Gen.__type_params__)
self.assertIs(ForwardRef("alias").evaluate(owner=Gen), int)
# If you pass custom locals, we don't look at the owner's locals
with self.assertRaises(NameError):
ForwardRef("alias").evaluate(owner=Gen, locals={})
# But if the name exists in the locals, it works
self.assertIs(
ForwardRef("alias").evaluate(owner=Gen, locals={"alias": str}), str
)
def test_fwdref_with_module(self):
self.assertIs(ForwardRef("Format", module=annotationlib).evaluate(), Format)
with self.assertRaises(NameError):
# If globals are passed explicitly, we don't look at the module dict
ForwardRef("Format", module=annotationlib).evaluate(globals={})
def test_fwdref_value_is_cached(self):
fr = ForwardRef("hello")
with self.assertRaises(NameError):
fr.evaluate()
self.assertIs(fr.evaluate(globals={"hello": str}), str)
self.assertIs(fr.evaluate(), str)
class TestGetAnnotations(unittest.TestCase):
def test_builtin_type(self):

View File

@ -474,6 +474,10 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f
_deprecation_warning_for_no_type_params_passed("typing._eval_type")
type_params = ()
if isinstance(t, ForwardRef):
# If the forward_ref has __forward_module__ set, evaluate() infers the globals
# from the module, and it will probably pick better than the globals we have here.
if t.__forward_module__ is not None:
globalns = None
return evaluate_forward_ref(t, globals=globalns, locals=localns,
type_params=type_params, owner=owner,
_recursive_guard=recursive_guard, format=format)