summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Euresti <david@euresti.com>2021-02-18 21:50:41 -0800
committerGitHub <noreply@github.com>2021-02-19 06:50:41 +0100
commit1499c77e79315ef8c5c98c088dc3ef99bdba72b4 (patch)
tree5d93313283435d13fbb3ac14a78b4d22f01786b4
parent78335e9b49eff2eaf2dc31d0bcc3caa0169cfc60 (diff)
downloadattrs-1499c77e79315ef8c5c98c088dc3ef99bdba72b4.tar.gz
Fix issue with get_type_hints(cls.__init__) and refactor (#760)
* Fix issue with get_type_hints(cls.__init__) * Refactor * Improve coverage
-rw-r--r--src/attr/_make.py84
-rw-r--r--tests/test_annotations.py29
2 files changed, 70 insertions, 43 deletions
diff --git a/src/attr/_make.py b/src/attr/_make.py
index 8bc8634..76b1c62 100644
--- a/src/attr/_make.py
+++ b/src/attr/_make.py
@@ -286,6 +286,36 @@ def attrib(
)
+def _compile_and_eval(script, globs, locs=None, filename=""):
+ """
+ "Exec" the script with the given global (globs) and local (locs) variables.
+ """
+ bytecode = compile(script, filename, "exec")
+ eval(bytecode, globs, locs)
+
+
+def _make_method(name, script, filename, globs=None):
+ """
+ Create the method with the script given and return the method object.
+ """
+ locs = {}
+ if globs is None:
+ globs = {}
+
+ _compile_and_eval(script, globs, locs, filename)
+
+ # In order of debuggers like PDB being able to step through the code,
+ # we add a fake linecache entry.
+ linecache.cache[filename] = (
+ len(script),
+ None,
+ script.splitlines(True),
+ filename,
+ )
+
+ return locs[name]
+
+
def _make_attr_tuple_class(cls_name, attr_names):
"""
Create a tuple subclass to hold `Attribute`s for an `attrs` class.
@@ -309,8 +339,7 @@ def _make_attr_tuple_class(cls_name, attr_names):
else:
attr_class_template.append(" pass")
globs = {"_attrs_itemgetter": itemgetter, "_attrs_property": property}
- eval(compile("\n".join(attr_class_template), "", "exec"), globs)
-
+ _compile_and_eval("\n".join(attr_class_template), globs)
return globs[attr_class_name]
@@ -1591,21 +1620,7 @@ def _make_hash(cls, attrs, frozen, cache_hash):
append_hash_computation_lines("return ", tab)
script = "\n".join(method_lines)
- globs = {}
- locs = {}
- bytecode = compile(script, unique_filename, "exec")
- eval(bytecode, globs, locs)
-
- # In order of debuggers like PDB being able to step through the code,
- # we add a fake linecache entry.
- linecache.cache[unique_filename] = (
- len(script),
- None,
- script.splitlines(True),
- unique_filename,
- )
-
- return locs["__hash__"]
+ return _make_method("__hash__", script, unique_filename)
def _add_hash(cls, attrs):
@@ -1661,20 +1676,7 @@ def _make_eq(cls, attrs):
lines.append(" return True")
script = "\n".join(lines)
- globs = {}
- locs = {}
- bytecode = compile(script, unique_filename, "exec")
- eval(bytecode, globs, locs)
-
- # In order of debuggers like PDB being able to step through the code,
- # we add a fake linecache entry.
- linecache.cache[unique_filename] = (
- len(script),
- None,
- script.splitlines(True),
- unique_filename,
- )
- return locs["__eq__"]
+ return _make_method("__eq__", script, unique_filename)
def _make_order(cls, attrs):
@@ -1949,8 +1951,10 @@ def _make_init(
has_global_on_setattr,
attrs_init,
)
- locs = {}
- bytecode = compile(script, unique_filename, "exec")
+ if cls.__module__ in sys.modules:
+ # This makes typing.get_type_hints(CLS.__init__) resolve string types.
+ globs.update(sys.modules[cls.__module__].__dict__)
+
globs.update({"NOTHING": NOTHING, "attr_dict": attr_dict})
if needs_cached_setattr:
@@ -1958,18 +1962,12 @@ def _make_init(
# setattr hooks.
globs["_cached_setattr"] = _obj_setattr
- eval(bytecode, globs, locs)
-
- # In order of debuggers like PDB being able to step through the code,
- # we add a fake linecache entry.
- linecache.cache[unique_filename] = (
- len(script),
- None,
- script.splitlines(True),
+ init = _make_method(
+ "__attrs_init__" if attrs_init else "__init__",
+ script,
unique_filename,
+ globs,
)
-
- init = locs["__attrs_init__"] if attrs_init else locs["__init__"]
init.__annotations__ = annotations
return init
diff --git a/tests/test_annotations.py b/tests/test_annotations.py
index 3e19aa1..0b27099 100644
--- a/tests/test_annotations.py
+++ b/tests/test_annotations.py
@@ -578,3 +578,32 @@ class TestAnnotations:
assert typing.List[B] == attr.fields(A).a.type
assert A == attr.fields(B).a.type
+
+ def test_init_type_hints(self):
+ """
+ Forward references in __init__ can be automatically resolved.
+ """
+
+ @attr.s
+ class C:
+ x = attr.ib(type="typing.List[int]")
+
+ assert typing.get_type_hints(C.__init__) == {
+ "return": type(None),
+ "x": typing.List[int],
+ }
+
+ def test_init_type_hints_fake_module(self):
+ """
+ If you somehow set the __module__ to something that doesn't exist
+ you'll lose __init__ resolution.
+ """
+
+ class C:
+ x = attr.ib(type="typing.List[int]")
+
+ C.__module__ = "totally fake"
+ C = attr.s(C)
+
+ with pytest.raises(NameError):
+ typing.get_type_hints(C.__init__)