aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniƫl van Noord <13665637+DanielNoord@users.noreply.github.com>2021-10-25 21:28:17 +0200
committerGitHub <noreply@github.com>2021-10-25 21:28:17 +0200
commit0a43490e7906216bac9e5adb16e1885ff4f04fcd (patch)
treebf0a47fc1254caa25ebecf693d546b5364159922
parentd3e285417bbd1f595a40891450ac553bcaf460b9 (diff)
downloadastroid-0a43490e7906216bac9e5adb16e1885ff4f04fcd.tar.gz
Refactor and add typing to ``NodeNG.frame()`` (#1225)
* Refactor and add typing to ``NodeNG.frame()``
-rw-r--r--astroid/nodes/node_ng.py11
-rw-r--r--astroid/nodes/scoped_nodes.py60
-rw-r--r--tests/unittest_scoped_nodes.py51
3 files changed, 96 insertions, 26 deletions
diff --git a/astroid/nodes/node_ng.py b/astroid/nodes/node_ng.py
index c9aa0e0e..e6d0d50b 100644
--- a/astroid/nodes/node_ng.py
+++ b/astroid/nodes/node_ng.py
@@ -25,7 +25,7 @@ from astroid.nodes.as_string import AsStringVisitor
from astroid.nodes.const import OP_PRECEDENCE
if TYPE_CHECKING:
- from astroid.nodes import LocalsDictNodeNG
+ from astroid import nodes
# Types for 'NodeNG.nodes_of_class()'
T_Nodes = TypeVar("T_Nodes", bound="NodeNG")
@@ -258,18 +258,19 @@ class NodeNG:
return self
return self.parent.statement()
- def frame(self):
+ def frame(
+ self,
+ ) -> Union["nodes.FunctionDef", "nodes.Module", "nodes.ClassDef", "nodes.Lambda"]:
"""The first parent frame node.
A frame node is a :class:`Module`, :class:`FunctionDef`,
- or :class:`ClassDef`.
+ :class:`ClassDef` or :class:`Lambda`.
:returns: The first parent frame node.
- :rtype: Module or FunctionDef or ClassDef
"""
return self.parent.frame()
- def scope(self) -> "LocalsDictNodeNG":
+ def scope(self) -> "nodes.LocalsDictNodeNG":
"""The first parent node defining a new scope.
These can be Module, FunctionDef, ClassDef, Lambda, or GeneratorExp nodes.
diff --git a/astroid/nodes/scoped_nodes.py b/astroid/nodes/scoped_nodes.py
index e9ccd4a2..b8da0ac0 100644
--- a/astroid/nodes/scoped_nodes.py
+++ b/astroid/nodes/scoped_nodes.py
@@ -230,17 +230,6 @@ class LocalsDictNodeNG(node_classes.LookupMixIn, node_classes.NodeNG):
return self.name
return f"{self.parent.frame().qname()}.{self.name}"
- def frame(self):
- """The first parent frame node.
-
- A frame node is a :class:`Module`, :class:`FunctionDef`,
- or :class:`ClassDef`.
-
- :returns: The first parent frame node.
- :rtype: Module or FunctionDef or ClassDef
- """
- return self
-
def scope(self: T) -> T:
"""The first parent node defining a new scope.
@@ -826,20 +815,19 @@ class Module(LocalsDictNodeNG):
def get_children(self):
yield from self.body
-
-class ComprehensionScope(LocalsDictNodeNG):
- """Scoping for different types of comprehensions."""
-
- def frame(self):
- """The first parent frame node.
+ def frame(self: T) -> T:
+ """The node's frame node.
A frame node is a :class:`Module`, :class:`FunctionDef`,
- or :class:`ClassDef`.
+ :class:`ClassDef` or :class:`Lambda`.
- :returns: The first parent frame node.
- :rtype: Module or FunctionDef or ClassDef
+ :returns: The node itself.
"""
- return self.parent.frame()
+ return self
+
+
+class ComprehensionScope(LocalsDictNodeNG):
+ """Scoping for different types of comprehensions."""
scope_lookup = LocalsDictNodeNG._scope_lookup
@@ -1344,6 +1332,16 @@ class Lambda(mixins.FilterStmtsMixin, LocalsDictNodeNG):
yield self.args
yield self.body
+ def frame(self: T) -> T:
+ """The node's frame node.
+
+ A frame node is a :class:`Module`, :class:`FunctionDef`,
+ :class:`ClassDef` or :class:`Lambda`.
+
+ :returns: The node itself.
+ """
+ return self
+
class FunctionDef(mixins.MultiLineBlockMixin, node_classes.Statement, Lambda):
"""Class representing an :class:`ast.FunctionDef`.
@@ -1839,6 +1837,16 @@ class FunctionDef(mixins.MultiLineBlockMixin, node_classes.Statement, Lambda):
return self, [frame]
return super().scope_lookup(node, name, offset)
+ def frame(self: T) -> T:
+ """The node's frame node.
+
+ A frame node is a :class:`Module`, :class:`FunctionDef`,
+ :class:`ClassDef` or :class:`Lambda`.
+
+ :returns: The node itself.
+ """
+ return self
+
class AsyncFunctionDef(FunctionDef):
"""Class representing an :class:`ast.FunctionDef` node.
@@ -3054,3 +3062,13 @@ class ClassDef(mixins.FilterStmtsMixin, LocalsDictNodeNG, node_classes.Statement
child_node._get_assign_nodes() for child_node in self.body
)
return list(itertools.chain.from_iterable(children_assign_nodes))
+
+ def frame(self: T) -> T:
+ """The node's frame node.
+
+ A frame node is a :class:`Module`, :class:`FunctionDef`,
+ :class:`ClassDef` or :class:`Lambda`.
+
+ :returns: The node itself.
+ """
+ return self
diff --git a/tests/unittest_scoped_nodes.py b/tests/unittest_scoped_nodes.py
index 327ed445..7a44b105 100644
--- a/tests/unittest_scoped_nodes.py
+++ b/tests/unittest_scoped_nodes.py
@@ -44,6 +44,7 @@ import pytest
from astroid import MANAGER, builder, nodes, objects, test_utils, util
from astroid.bases import BoundMethod, Generator, Instance, UnboundMethod
+from astroid.const import PY38_PLUS
from astroid.exceptions import (
AttributeInferenceError,
DuplicateBasesError,
@@ -2289,5 +2290,55 @@ def test_slots_duplicate_bases_issue_1089() -> None:
astroid["First"].slots()
+class TestFrameNodes:
+ @pytest.mark.skipif(not PY38_PLUS, reason="needs assignment expressions")
+ @staticmethod
+ def test_frame_node():
+ """Test if the frame of FunctionDef, ClassDef and Module is correctly set"""
+ module = builder.parse(
+ """
+ def func():
+ var_1 = x
+ return var_1
+
+ class MyClass:
+
+ attribute = 1
+
+ def method():
+ pass
+
+ VAR = lambda y = (named_expr := "walrus"): print(y)
+ """
+ )
+ function = module.body[0]
+ assert function.frame() == function
+ assert function.body[0].frame() == function
+
+ class_node = module.body[1]
+ assert class_node.frame() == class_node
+ assert class_node.body[0].frame() == class_node
+ assert class_node.body[1].frame() == class_node.body[1]
+
+ lambda_assignment = module.body[2].value
+ assert lambda_assignment.args.args[0].frame() == lambda_assignment
+
+ assert module.frame() == module
+
+ @staticmethod
+ def test_non_frame_node():
+ """Test if the frame of non frame nodes is set correctly"""
+ module = builder.parse(
+ """
+ VAR_ONE = 1
+
+ VAR_TWO = [x for x in range(1)]
+ """
+ )
+ assert module.body[0].frame() == module
+
+ assert module.body[1].value.locals["x"][0].frame() == module
+
+
if __name__ == "__main__":
unittest.main()