diff options
Diffstat (limited to 'tests/unittest_brain_numpy_ndarray.py')
-rw-r--r-- | tests/unittest_brain_numpy_ndarray.py | 188 |
1 files changed, 188 insertions, 0 deletions
diff --git a/tests/unittest_brain_numpy_ndarray.py b/tests/unittest_brain_numpy_ndarray.py new file mode 100644 index 00000000..1a417b85 --- /dev/null +++ b/tests/unittest_brain_numpy_ndarray.py @@ -0,0 +1,188 @@ +# Copyright (c) 2017-2021 hippo91 <guillaume.peillex@gmail.com> +# Copyright (c) 2017-2018, 2020 Claudiu Popa <pcmanticore@gmail.com> +# Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com> +# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk> +# Copyright (c) 2021 Pierre Sassoulas <pierre.sassoulas@gmail.com> +# Copyright (c) 2021 Daniƫl van Noord <13665637+DanielNoord@users.noreply.github.com> +# Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com> + +# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html +# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE +import unittest + +try: + import numpy # pylint: disable=unused-import + + HAS_NUMPY = True +except ImportError: + HAS_NUMPY = False + +from astroid import builder, nodes +from astroid.brain.brain_numpy_utils import ( + NUMPY_VERSION_TYPE_HINTS_SUPPORT, + numpy_supports_type_hints, +) + + +@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.") +class NumpyBrainNdarrayTest(unittest.TestCase): + """ + Test that calls to numpy functions returning arrays are correctly inferred + """ + + ndarray_returning_ndarray_methods = ( + "__abs__", + "__add__", + "__and__", + "__array__", + "__array_wrap__", + "__copy__", + "__deepcopy__", + "__eq__", + "__floordiv__", + "__ge__", + "__gt__", + "__iadd__", + "__iand__", + "__ifloordiv__", + "__ilshift__", + "__imod__", + "__imul__", + "__invert__", + "__ior__", + "__ipow__", + "__irshift__", + "__isub__", + "__itruediv__", + "__ixor__", + "__le__", + "__lshift__", + "__lt__", + "__matmul__", + "__mod__", + "__mul__", + "__ne__", + "__neg__", + "__or__", + "__pos__", + "__pow__", + "__rshift__", + "__sub__", + "__truediv__", + "__xor__", + "all", + "any", + "argmax", + "argmin", + "argpartition", + "argsort", + "astype", + "byteswap", + "choose", + "clip", + "compress", + "conj", + "conjugate", + "copy", + "cumprod", + "cumsum", + "diagonal", + "dot", + "flatten", + "getfield", + "max", + "mean", + "min", + "newbyteorder", + "prod", + "ptp", + "ravel", + "repeat", + "reshape", + "round", + "searchsorted", + "squeeze", + "std", + "sum", + "swapaxes", + "take", + "trace", + "transpose", + "var", + "view", + ) + + def _inferred_ndarray_method_call(self, func_name): + node = builder.extract_node( + f""" + import numpy as np + test_array = np.ndarray((2, 2)) + test_array.{func_name:s}() + """ + ) + return node.infer() + + def _inferred_ndarray_attribute(self, attr_name): + node = builder.extract_node( + f""" + import numpy as np + test_array = np.ndarray((2, 2)) + test_array.{attr_name:s} + """ + ) + return node.infer() + + def test_numpy_function_calls_inferred_as_ndarray(self): + """ + Test that some calls to numpy functions are inferred as numpy.ndarray + """ + licit_array_types = ".ndarray" + for func_ in self.ndarray_returning_ndarray_methods: + with self.subTest(typ=func_): + inferred_values = list(self._inferred_ndarray_method_call(func_)) + self.assertTrue( + len(inferred_values) == 1, + msg=f"Too much inferred value for {func_:s}", + ) + self.assertTrue( + inferred_values[-1].pytype() in licit_array_types, + msg=f"Illicit type for {func_:s} ({inferred_values[-1].pytype()})", + ) + + def test_numpy_ndarray_attribute_inferred_as_ndarray(self): + """ + Test that some numpy ndarray attributes are inferred as numpy.ndarray + """ + licit_array_types = ".ndarray" + for attr_ in ("real", "imag", "shape", "T"): + with self.subTest(typ=attr_): + inferred_values = list(self._inferred_ndarray_attribute(attr_)) + self.assertTrue( + len(inferred_values) == 1, + msg=f"Too much inferred value for {attr_:s}", + ) + self.assertTrue( + inferred_values[-1].pytype() in licit_array_types, + msg=f"Illicit type for {attr_:s} ({inferred_values[-1].pytype()})", + ) + + @unittest.skipUnless( + HAS_NUMPY and numpy_supports_type_hints(), + f"This test requires the numpy library with a version above {NUMPY_VERSION_TYPE_HINTS_SUPPORT}", + ) + def test_numpy_ndarray_class_support_type_indexing(self): + """ + Test that numpy ndarray class can be subscripted (type hints) + """ + src = """ + import numpy as np + np.ndarray[int] + """ + node = builder.extract_node(src) + cls_node = node.inferred()[0] + self.assertIsInstance(cls_node, nodes.ClassDef) + self.assertEqual(cls_node.name, "ndarray") + + +if __name__ == "__main__": + unittest.main() |