aboutsummaryrefslogtreecommitdiff
path: root/tests/unittest_brain_numpy_core_umath.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest_brain_numpy_core_umath.py')
-rw-r--r--tests/unittest_brain_numpy_core_umath.py269
1 files changed, 269 insertions, 0 deletions
diff --git a/tests/unittest_brain_numpy_core_umath.py b/tests/unittest_brain_numpy_core_umath.py
new file mode 100644
index 00000000..c80c391a
--- /dev/null
+++ b/tests/unittest_brain_numpy_core_umath.py
@@ -0,0 +1,269 @@
+# Copyright (c) 2019-2021 hippo91 <guillaume.peillex@gmail.com>
+# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
+# Copyright (c) 2020 Claudiu Popa <pcmanticore@gmail.com>
+# Copyright (c) 2021 Pierre Sassoulas <pierre.sassoulas@gmail.com>
+# Copyright (c) 2021 Daniƫl van Noord <13665637+DanielNoord@users.noreply.github.com>
+# Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com>
+# Copyright (c) 2021 Andrew Haigh <hello@nelf.in>
+
+# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
+# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
+import unittest
+
+try:
+ import numpy # pylint: disable=unused-import
+
+ HAS_NUMPY = True
+except ImportError:
+ HAS_NUMPY = False
+
+from astroid import bases, builder, nodes
+
+
+@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
+class NumpyBrainCoreUmathTest(unittest.TestCase):
+ """
+ Test of all members of numpy.core.umath module
+ """
+
+ one_arg_ufunc = (
+ "arccos",
+ "arccosh",
+ "arcsin",
+ "arcsinh",
+ "arctan",
+ "arctanh",
+ "cbrt",
+ "conj",
+ "conjugate",
+ "cosh",
+ "deg2rad",
+ "degrees",
+ "exp2",
+ "expm1",
+ "fabs",
+ "frexp",
+ "isfinite",
+ "isinf",
+ "log",
+ "log1p",
+ "log2",
+ "logical_not",
+ "modf",
+ "negative",
+ "positive",
+ "rad2deg",
+ "radians",
+ "reciprocal",
+ "rint",
+ "sign",
+ "signbit",
+ "spacing",
+ "square",
+ "tan",
+ "tanh",
+ "trunc",
+ )
+
+ two_args_ufunc = (
+ "add",
+ "bitwise_and",
+ "bitwise_or",
+ "bitwise_xor",
+ "copysign",
+ "divide",
+ "divmod",
+ "equal",
+ "float_power",
+ "floor_divide",
+ "fmax",
+ "fmin",
+ "fmod",
+ "gcd",
+ "greater",
+ "heaviside",
+ "hypot",
+ "lcm",
+ "ldexp",
+ "left_shift",
+ "less",
+ "logaddexp",
+ "logaddexp2",
+ "logical_and",
+ "logical_or",
+ "logical_xor",
+ "maximum",
+ "minimum",
+ "multiply",
+ "nextafter",
+ "not_equal",
+ "power",
+ "remainder",
+ "right_shift",
+ "subtract",
+ "true_divide",
+ )
+
+ all_ufunc = one_arg_ufunc + two_args_ufunc
+
+ constants = ("e", "euler_gamma")
+
+ def _inferred_numpy_attribute(self, func_name):
+ node = builder.extract_node(
+ f"""
+ import numpy.core.umath as tested_module
+ func = tested_module.{func_name:s}
+ func"""
+ )
+ return next(node.infer())
+
+ def test_numpy_core_umath_constants(self):
+ """
+ Test that constants have Const type.
+ """
+ for const in self.constants:
+ with self.subTest(const=const):
+ inferred = self._inferred_numpy_attribute(const)
+ self.assertIsInstance(inferred, nodes.Const)
+
+ def test_numpy_core_umath_constants_values(self):
+ """
+ Test the values of the constants.
+ """
+ exact_values = {"e": 2.718281828459045, "euler_gamma": 0.5772156649015329}
+ for const in self.constants:
+ with self.subTest(const=const):
+ inferred = self._inferred_numpy_attribute(const)
+ self.assertEqual(inferred.value, exact_values[const])
+
+ def test_numpy_core_umath_functions(self):
+ """
+ Test that functions have FunctionDef type.
+ """
+ for func in self.all_ufunc:
+ with self.subTest(func=func):
+ inferred = self._inferred_numpy_attribute(func)
+ self.assertIsInstance(inferred, bases.Instance)
+
+ def test_numpy_core_umath_functions_one_arg(self):
+ """
+ Test the arguments names of functions.
+ """
+ exact_arg_names = [
+ "self",
+ "x",
+ "out",
+ "where",
+ "casting",
+ "order",
+ "dtype",
+ "subok",
+ ]
+ for func in self.one_arg_ufunc:
+ with self.subTest(func=func):
+ inferred = self._inferred_numpy_attribute(func)
+ self.assertEqual(
+ inferred.getattr("__call__")[0].argnames(), exact_arg_names
+ )
+
+ def test_numpy_core_umath_functions_two_args(self):
+ """
+ Test the arguments names of functions.
+ """
+ exact_arg_names = [
+ "self",
+ "x1",
+ "x2",
+ "out",
+ "where",
+ "casting",
+ "order",
+ "dtype",
+ "subok",
+ ]
+ for func in self.two_args_ufunc:
+ with self.subTest(func=func):
+ inferred = self._inferred_numpy_attribute(func)
+ self.assertEqual(
+ inferred.getattr("__call__")[0].argnames(), exact_arg_names
+ )
+
+ def test_numpy_core_umath_functions_kwargs_default_values(self):
+ """
+ Test the default values for keyword arguments.
+ """
+ exact_kwargs_default_values = [None, True, "same_kind", "K", None, True]
+ for func in self.one_arg_ufunc + self.two_args_ufunc:
+ with self.subTest(func=func):
+ inferred = self._inferred_numpy_attribute(func)
+ default_args_values = [
+ default.value
+ for default in inferred.getattr("__call__")[0].args.defaults
+ ]
+ self.assertEqual(default_args_values, exact_kwargs_default_values)
+
+ def _inferred_numpy_func_call(self, func_name, *func_args):
+ node = builder.extract_node(
+ f"""
+ import numpy as np
+ func = np.{func_name:s}
+ func()
+ """
+ )
+ return node.infer()
+
+ def test_numpy_core_umath_functions_return_type(self):
+ """
+ Test that functions which should return a ndarray do return it
+ """
+ ndarray_returning_func = [
+ f for f in self.all_ufunc if f not in ("frexp", "modf")
+ ]
+ for func_ in ndarray_returning_func:
+ with self.subTest(typ=func_):
+ inferred_values = list(self._inferred_numpy_func_call(func_))
+ self.assertTrue(
+ len(inferred_values) == 1,
+ msg="Too much inferred values ({}) for {:s}".format(
+ inferred_values[-1].pytype(), func_
+ ),
+ )
+ self.assertTrue(
+ inferred_values[0].pytype() == ".ndarray",
+ msg=f"Illicit type for {func_:s} ({inferred_values[-1].pytype()})",
+ )
+
+ def test_numpy_core_umath_functions_return_type_tuple(self):
+ """
+ Test that functions which should return a pair of ndarray do return it
+ """
+ ndarray_returning_func = ("frexp", "modf")
+
+ for func_ in ndarray_returning_func:
+ with self.subTest(typ=func_):
+ inferred_values = list(self._inferred_numpy_func_call(func_))
+ self.assertTrue(
+ len(inferred_values) == 1,
+ msg=f"Too much inferred values ({inferred_values}) for {func_:s}",
+ )
+ self.assertTrue(
+ inferred_values[-1].pytype() == "builtins.tuple",
+ msg=f"Illicit type for {func_:s} ({inferred_values[-1].pytype()})",
+ )
+ self.assertTrue(
+ len(inferred_values[0].elts) == 2,
+ msg=f"{func_} should return a pair of values. That's not the case.",
+ )
+ for array in inferred_values[-1].elts:
+ effective_infer = [m.pytype() for m in array.inferred()]
+ self.assertTrue(
+ ".ndarray" in effective_infer,
+ msg=(
+ f"Each item in the return of {func_} should be inferred"
+ f" as a ndarray and not as {effective_infer}"
+ ),
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()