aboutsummaryrefslogtreecommitdiff
path: root/src/python/grpcio_tests/tests/_loader.py
blob: b9fc3ccf0f600db2d864024cdf7e87e3effb2d44 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import importlib
import importlib.util
import logging
import os
import pkgutil
import re
import sys
import unittest

import coverage

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Pattern for module names that end in "_test"; Loader compiles it and
# matches it against each visited module's __name__.
TEST_MODULE_REGEX = r"^.*_test$"


# Determines the part of a given path that is relative to the first matching
# path on sys.path. Useful for determining what a directory's module
# path will be.
def _relativize_to_sys_path(path):
    """Return `path` relative to the first sys.path entry that contains it.

    A sys.path entry contains `path` only when it equals `path` or is a
    prefix of it that ends on a path-component boundary, so an entry such as
    `/foo` does not match `/foobar/tests`. An empty entry only matches
    non-absolute paths, so an absolute path is never "relativized" to its
    full self.

    Args:
      path (str): Filesystem path to relativize.

    Returns:
      str: The empty string when `path` is the sys.path entry itself
        (with or without a trailing separator); otherwise the remainder of
        `path` without a leading separator and with a trailing separator.

    Raises:
      AssertionError: If no entry on sys.path contains `path`.
    """
    sep = os.path.sep
    for sys_path in sys.path:
        if not path.startswith(sys_path):
            continue
        # Every string starts with "", so an empty entry would otherwise
        # swallow absolute paths that belong to a later entry.
        if not sys_path and os.path.isabs(path):
            continue
        relative = path[len(sys_path) :]
        if not relative:
            return ""
        if relative.startswith(sep):
            relative = relative[len(sep) :]
            if not relative:
                # `path` was just the entry plus a trailing separator.
                return ""
        elif sys_path and not sys_path.endswith(sep):
            # The entry matched only part of a path component
            # (e.g. "/foo" against "/foobar"); try the next entry.
            continue
        if not relative.endswith(sep):
            relative += sep
        return relative
    raise AssertionError("Failed to relativize {} to sys.path.".format(path))


def _relative_path_to_module_prefix(path):
    """Turn a relative filesystem path into a dotted module prefix.

    Every occurrence of the platform path separator in `path` becomes a dot.
    """
    components = path.split(os.path.sep)
    return ".".join(components)


class Loader(object):
    """Test loader for setuptools test suite support.

    Attributes:
      suite (unittest.TestSuite): All tests collected by the loader.
      loader (unittest.TestLoader): Standard Python unittest loader to be run
        per module discovered.
      module_matcher (re.Pattern): A compiled regular expression matched
        against module names to determine whether or not the discovered
        module contributes to the test suite.
    """

    def __init__(self):
        self.suite = unittest.TestSuite()
        self.loader = unittest.TestLoader()
        self.module_matcher = re.compile(TEST_MODULE_REGEX)

    def loadTestsFromNames(self, names, module=None):
        """Function mirroring TestLoader::loadTestsFromNames, as expected by
        setuptools.setup argument `test_loader`.

        Args:
          names (iterable of str): Importable module or package names to
            collect tests from.
          module: Unused by this implementation; accepted so the signature
            matches unittest.TestLoader.loadTestsFromNames.

        Returns:
          unittest.TestSuite: `self.suite`, holding every test collected so
            far by this loader.
        """
        # Ensure that we capture decorators and definitions (else our coverage
        # measure unnecessarily suffers).
        coverage_context = coverage.Coverage(data_suffix=True)
        coverage_context.start()
        try:
            imported_modules = tuple(
                importlib.import_module(name) for name in names
            )
            for imported_module in imported_modules:
                self.visit_module(imported_module)
            for imported_module in imported_modules:
                try:
                    package_paths = imported_module.__path__
                except AttributeError:
                    # Plain modules have no __path__; nothing to walk.
                    continue
                self.walk_packages(package_paths)
        finally:
            # Always stop measurement, even when an import or visit raises,
            # so coverage does not keep running after a failed load.
            coverage_context.stop()
        coverage_context.save()
        return self.suite

    def walk_packages(self, package_paths):
        """Walks over the packages, dispatching `visit_module` calls.

        Args:
          package_paths (list): A list of paths over which to walk through
            modules along.
        """
        for path in package_paths:
            self._walk_package(path)

    def _walk_package(self, package_path):
        """Visits every module found under a single package directory.

        Modules already present in sys.modules are visited as-is; others are
        loaded from the spec reported by their importer and then visited.

        Args:
          package_path (str): Directory of the package to walk; it must be
            located under an entry of sys.path.
        """
        prefix = _relative_path_to_module_prefix(
            _relativize_to_sys_path(package_path)
        )
        for importer, module_name, is_package in pkgutil.walk_packages(
            [package_path], prefix
        ):
            if module_name in sys.modules:
                self.visit_module(sys.modules[module_name])
                continue
            try:
                spec = importer.find_spec(module_name)
                if spec is None or spec.loader is None:
                    # Nothing loadable was found for this name; skip it
                    # instead of failing on a missing spec or loader.
                    logger.debug("Skip loading %s", module_name)
                    continue
                # NOTE(review): the module is executed here without being
                # added to sys.modules by this code.
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
                self.visit_module(module)
            except ModuleNotFoundError:
                logger.debug("Skip loading %s", module_name)

    def visit_module(self, module):
        """Visits the module, adding discovered tests to the test suite.

        Args:
          module (module): Module to match against self.module_matcher; if
            matched it has its tests loaded via self.loader into self.suite.
        """
        if self.module_matcher.match(module.__name__):
            module_suite = self.loader.loadTestsFromModule(module)
            self.suite.addTest(module_suite)


def iterate_suite_cases(suite):
    """Yield every unittest.TestCase contained in a unittest.TestSuite.

    Nested suites are flattened depth-first, in iteration order.

    Args:
      suite (unittest.TestSuite): Suite whose test cases are generated.

    Returns:
      generator: A generator over all unittest.TestCases in `suite`.

    Raises:
      ValueError: If an item is neither a TestSuite nor a TestCase.
    """
    for entry in suite:
        if isinstance(entry, unittest.TestSuite):
            # Descend into the nested suite and relay its cases.
            yield from iterate_suite_cases(entry)
            continue
        if not isinstance(entry, unittest.TestCase):
            message = "unexpected suite item of type {}".format(type(entry))
            raise ValueError(message)
        yield entry