aboutsummaryrefslogtreecommitdiff
path: root/tests/unittest_brain_numpy_ndarray.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest_brain_numpy_ndarray.py')
-rw-r--r--tests/unittest_brain_numpy_ndarray.py188
1 files changed, 188 insertions, 0 deletions
diff --git a/tests/unittest_brain_numpy_ndarray.py b/tests/unittest_brain_numpy_ndarray.py
new file mode 100644
index 00000000..1a417b85
--- /dev/null
+++ b/tests/unittest_brain_numpy_ndarray.py
@@ -0,0 +1,188 @@
+# Copyright (c) 2017-2021 hippo91 <guillaume.peillex@gmail.com>
+# Copyright (c) 2017-2018, 2020 Claudiu Popa <pcmanticore@gmail.com>
+# Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com>
+# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
+# Copyright (c) 2021 Pierre Sassoulas <pierre.sassoulas@gmail.com>
+# Copyright (c) 2021 Daniƫl van Noord <13665637+DanielNoord@users.noreply.github.com>
+# Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com>
+
+# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
+# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
+import unittest
+
+try:
+ import numpy # pylint: disable=unused-import
+
+ HAS_NUMPY = True
+except ImportError:
+ HAS_NUMPY = False
+
+from astroid import builder, nodes
+from astroid.brain.brain_numpy_utils import (
+ NUMPY_VERSION_TYPE_HINTS_SUPPORT,
+ numpy_supports_type_hints,
+)
+
+
+@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
+class NumpyBrainNdarrayTest(unittest.TestCase):
+ """
+ Test that calls to numpy functions returning arrays are correctly inferred
+ """
+
+ ndarray_returning_ndarray_methods = (
+ "__abs__",
+ "__add__",
+ "__and__",
+ "__array__",
+ "__array_wrap__",
+ "__copy__",
+ "__deepcopy__",
+ "__eq__",
+ "__floordiv__",
+ "__ge__",
+ "__gt__",
+ "__iadd__",
+ "__iand__",
+ "__ifloordiv__",
+ "__ilshift__",
+ "__imod__",
+ "__imul__",
+ "__invert__",
+ "__ior__",
+ "__ipow__",
+ "__irshift__",
+ "__isub__",
+ "__itruediv__",
+ "__ixor__",
+ "__le__",
+ "__lshift__",
+ "__lt__",
+ "__matmul__",
+ "__mod__",
+ "__mul__",
+ "__ne__",
+ "__neg__",
+ "__or__",
+ "__pos__",
+ "__pow__",
+ "__rshift__",
+ "__sub__",
+ "__truediv__",
+ "__xor__",
+ "all",
+ "any",
+ "argmax",
+ "argmin",
+ "argpartition",
+ "argsort",
+ "astype",
+ "byteswap",
+ "choose",
+ "clip",
+ "compress",
+ "conj",
+ "conjugate",
+ "copy",
+ "cumprod",
+ "cumsum",
+ "diagonal",
+ "dot",
+ "flatten",
+ "getfield",
+ "max",
+ "mean",
+ "min",
+ "newbyteorder",
+ "prod",
+ "ptp",
+ "ravel",
+ "repeat",
+ "reshape",
+ "round",
+ "searchsorted",
+ "squeeze",
+ "std",
+ "sum",
+ "swapaxes",
+ "take",
+ "trace",
+ "transpose",
+ "var",
+ "view",
+ )
+
+ def _inferred_ndarray_method_call(self, func_name):
+ node = builder.extract_node(
+ f"""
+ import numpy as np
+ test_array = np.ndarray((2, 2))
+ test_array.{func_name:s}()
+ """
+ )
+ return node.infer()
+
+ def _inferred_ndarray_attribute(self, attr_name):
+ node = builder.extract_node(
+ f"""
+ import numpy as np
+ test_array = np.ndarray((2, 2))
+ test_array.{attr_name:s}
+ """
+ )
+ return node.infer()
+
+ def test_numpy_function_calls_inferred_as_ndarray(self):
+ """
+ Test that some calls to numpy functions are inferred as numpy.ndarray
+ """
+ licit_array_types = ".ndarray"
+ for func_ in self.ndarray_returning_ndarray_methods:
+ with self.subTest(typ=func_):
+ inferred_values = list(self._inferred_ndarray_method_call(func_))
+ self.assertTrue(
+ len(inferred_values) == 1,
+ msg=f"Too much inferred value for {func_:s}",
+ )
+ self.assertTrue(
+ inferred_values[-1].pytype() in licit_array_types,
+ msg=f"Illicit type for {func_:s} ({inferred_values[-1].pytype()})",
+ )
+
+ def test_numpy_ndarray_attribute_inferred_as_ndarray(self):
+ """
+ Test that some numpy ndarray attributes are inferred as numpy.ndarray
+ """
+ licit_array_types = ".ndarray"
+ for attr_ in ("real", "imag", "shape", "T"):
+ with self.subTest(typ=attr_):
+ inferred_values = list(self._inferred_ndarray_attribute(attr_))
+ self.assertTrue(
+ len(inferred_values) == 1,
+ msg=f"Too much inferred value for {attr_:s}",
+ )
+ self.assertTrue(
+ inferred_values[-1].pytype() in licit_array_types,
+ msg=f"Illicit type for {attr_:s} ({inferred_values[-1].pytype()})",
+ )
+
+ @unittest.skipUnless(
+ HAS_NUMPY and numpy_supports_type_hints(),
+ f"This test requires the numpy library with a version above {NUMPY_VERSION_TYPE_HINTS_SUPPORT}",
+ )
+ def test_numpy_ndarray_class_support_type_indexing(self):
+ """
+ Test that numpy ndarray class can be subscripted (type hints)
+ """
+ src = """
+ import numpy as np
+ np.ndarray[int]
+ """
+ node = builder.extract_node(src)
+ cls_node = node.inferred()[0]
+ self.assertIsInstance(cls_node, nodes.ClassDef)
+ self.assertEqual(cls_node.name, "ndarray")
+
+
+if __name__ == "__main__":
+ unittest.main()