diff options
author | David Euresti <david@euresti.com> | 2021-02-18 21:50:41 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-02-19 06:50:41 +0100 |
commit | 1499c77e79315ef8c5c98c088dc3ef99bdba72b4 (patch) | |
tree | 5d93313283435d13fbb3ac14a78b4d22f01786b4 | |
parent | 78335e9b49eff2eaf2dc31d0bcc3caa0169cfc60 (diff) | |
download | attrs-1499c77e79315ef8c5c98c088dc3ef99bdba72b4.tar.gz |
Fix issue with get_type_hints(cls.__init__) and refactor (#760)
* Fix issue with get_type_hints(cls.__init__)
* Refactor
* Improve coverage
-rw-r--r-- | src/attr/_make.py | 84 | ||||
-rw-r--r-- | tests/test_annotations.py | 29 |
2 files changed, 70 insertions, 43 deletions
diff --git a/src/attr/_make.py b/src/attr/_make.py index 8bc8634..76b1c62 100644 --- a/src/attr/_make.py +++ b/src/attr/_make.py @@ -286,6 +286,36 @@ def attrib( ) +def _compile_and_eval(script, globs, locs=None, filename=""): + """ + "Exec" the script with the given global (globs) and local (locs) variables. + """ + bytecode = compile(script, filename, "exec") + eval(bytecode, globs, locs) + + +def _make_method(name, script, filename, globs=None): + """ + Create the method with the script given and return the method object. + """ + locs = {} + if globs is None: + globs = {} + + _compile_and_eval(script, globs, locs, filename) + + # In order of debuggers like PDB being able to step through the code, + # we add a fake linecache entry. + linecache.cache[filename] = ( + len(script), + None, + script.splitlines(True), + filename, + ) + + return locs[name] + + def _make_attr_tuple_class(cls_name, attr_names): """ Create a tuple subclass to hold `Attribute`s for an `attrs` class. @@ -309,8 +339,7 @@ def _make_attr_tuple_class(cls_name, attr_names): else: attr_class_template.append(" pass") globs = {"_attrs_itemgetter": itemgetter, "_attrs_property": property} - eval(compile("\n".join(attr_class_template), "", "exec"), globs) - + _compile_and_eval("\n".join(attr_class_template), globs) return globs[attr_class_name] @@ -1591,21 +1620,7 @@ def _make_hash(cls, attrs, frozen, cache_hash): append_hash_computation_lines("return ", tab) script = "\n".join(method_lines) - globs = {} - locs = {} - bytecode = compile(script, unique_filename, "exec") - eval(bytecode, globs, locs) - - # In order of debuggers like PDB being able to step through the code, - # we add a fake linecache entry. - linecache.cache[unique_filename] = ( - len(script), - None, - script.splitlines(True), - unique_filename, - ) - - return locs["__hash__"] + return _make_method("__hash__", script, unique_filename) def _add_hash(cls, attrs): @@ -1661,20 +1676,7 @@ def _make_eq(cls, attrs): lines.append(" return True") script = "\n".join(lines) - globs = {} - locs = {} - bytecode = compile(script, unique_filename, "exec") - eval(bytecode, globs, locs) - - # In order of debuggers like PDB being able to step through the code, - # we add a fake linecache entry. - linecache.cache[unique_filename] = ( - len(script), - None, - script.splitlines(True), - unique_filename, - ) - return locs["__eq__"] + return _make_method("__eq__", script, unique_filename) def _make_order(cls, attrs): @@ -1949,8 +1951,10 @@ def _make_init( has_global_on_setattr, attrs_init, ) - locs = {} - bytecode = compile(script, unique_filename, "exec") + if cls.__module__ in sys.modules: + # This makes typing.get_type_hints(CLS.__init__) resolve string types. + globs.update(sys.modules[cls.__module__].__dict__) + globs.update({"NOTHING": NOTHING, "attr_dict": attr_dict}) if needs_cached_setattr: @@ -1958,18 +1962,12 @@ def _make_init( # setattr hooks. globs["_cached_setattr"] = _obj_setattr - eval(bytecode, globs, locs) - - # In order of debuggers like PDB being able to step through the code, - # we add a fake linecache entry. - linecache.cache[unique_filename] = ( - len(script), - None, - script.splitlines(True), + init = _make_method( + "__attrs_init__" if attrs_init else "__init__", + script, unique_filename, + globs, ) - - init = locs["__attrs_init__"] if attrs_init else locs["__init__"] init.__annotations__ = annotations return init diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 3e19aa1..0b27099 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -578,3 +578,32 @@ class TestAnnotations: assert typing.List[B] == attr.fields(A).a.type assert A == attr.fields(B).a.type + + def test_init_type_hints(self): + """ + Forward references in __init__ can be automatically resolved. + """ + + @attr.s + class C: + x = attr.ib(type="typing.List[int]") + + assert typing.get_type_hints(C.__init__) == { + "return": type(None), + "x": typing.List[int], + } + + def test_init_type_hints_fake_module(self): + """ + If you somehow set the __module__ to something that doesn't exist + you'll lose __init__ resolution. + """ + + class C: + x = attr.ib(type="typing.List[int]") + + C.__module__ = "totally fake" + C = attr.s(C) + + with pytest.raises(NameError): + typing.get_type_hints(C.__init__) |