diff options
author | Yilei "Dolee" Yang <yileiyang@google.com> | 2023-05-24 08:03:47 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-24 08:03:47 -0700 |
commit | 9adb4e7f2ee6a52f1722129461a7dfde01c63fbb (patch) | |
tree | 61b593dfeb5ec89fa16fc4b617c6fe1f18220819 | |
parent | c5c609cf04ea3f46eb620eb1b948ee2294645c4a (diff) | |
parent | 78e3ae4767ab981ce0e13ca555cd95cf6cd9db6d (diff) | |
download | absl-py-9adb4e7f2ee6a52f1722129461a7dfde01c63fbb.tar.gz |
Merge pull request #225 from yilei/push_up_to_532188463
Push up to 532188463
31 files changed, 1794 insertions, 1815 deletions
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..08e5794 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,49 @@ +name: Test + +on: [push, pull_request] + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.run_id }} + cancel-in-progress: true + +jobs: + test: + if: + github.event_name == 'push' || github.event.pull_request.head.repo.full_name != + github.repository + + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + os: [ubuntu-latest, macOS-latest, windows-latest] + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + id: setup_python + with: + python-version: ${{ matrix.python-version }} + + - name: Install virtualenv + run: | + python -m pip install --upgrade pip + python -m pip install --upgrade virtualenv + - name: Run tests + env: + ABSL_EXPECTED_PYTHON_VERSION: ${{ matrix.python-version }} + ABSL_COPY_TESTLOGS_TO: ci-artifacts + shell: bash + run: ci/run_tests.sh + + - name: Upload bazel test logs + uses: actions/upload-artifact@v3 + with: + name: bazel-testlogs-${{ matrix.os }}-${{ matrix.python-version }} + path: ci-artifacts diff --git a/CHANGELOG.md b/CHANGELOG.md index 58fdf6a..ff009c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,16 +9,20 @@ The format is based on [Keep a Changelog](https://keepachangelog.com). ### Changed * `absl-py` no longer supports Python 3.6. It has reached end-of-life for more - than a year now. + than a year now. * (testing) For Python 3.11+, the calls to `absltest.TestCase.enter_context` are forwarded to `unittest.TestCase.enterContext` (when called via instance) or `unittest.TestCase.enterClassContext` (when called via class) now. As a result, on Python 3.11+, the private `_cls_exit_stack` attribute is not defined on `absltest.TestCase` and `_exit_stack` attribute is not defined on its instances. -* `AbslTest.assertSameStructure()` now uses the test case's equality - functions (registered with `TestCase.addTypeEqualityFunc()`) for comparing - leaves of the structure. +* `AbslTest.assertSameStructure()` now uses the test case's equality functions + (registered with `TestCase.addTypeEqualityFunc()`) for comparing leaves of + the structure. +* `DEFINE_enum`, `DEFINE_multi_enum`, and `EnumParser` now raise errors when + `enum_values` is provided as a single string value. Additionally, + `EnumParser.enum_values` is now stored as a list copy of the provided + `enum_values` parameter. ## 1.4.0 (2023-01-11) diff --git a/absl/app.py b/absl/app.py index 79fc1f0..d12397b 100644 --- a/absl/app.py +++ b/absl/app.py @@ -238,6 +238,7 @@ def _run_main(main, argv): elif FLAGS.run_with_profiling or FLAGS.profile_file: # Avoid import overhead since most apps (including performance-sensitive # ones) won't be run with profiling. + # pylint: disable=g-import-not-at-top import atexit if FLAGS.use_cprofile_for_profiling: import cProfile as profile diff --git a/absl/command_name.py b/absl/command_name.py index 1996493..9260fee 100644 --- a/absl/command_name.py +++ b/absl/command_name.py @@ -47,7 +47,7 @@ def set_kernel_process_name(name): proc_comm.write(name[:15]) except EnvironmentError: try: - import ctypes + import ctypes # pylint: disable=g-import-not-at-top except ImportError: return # No ctypes. try: diff --git a/absl/flags/__init__.pyi b/absl/flags/__init__.pyi deleted file mode 100644 index 7bf6842..0000000 --- a/absl/flags/__init__.pyi +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2017 The Abseil Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.flags import _argument_parser -from absl.flags import _defines -from absl.flags import _exceptions -from absl.flags import _flag -from absl.flags import _flagvalues -from absl.flags import _helpers -from absl.flags import _validators - -# DEFINE functions. They are explained in more details in the module doc string. -# pylint: disable=invalid-name -DEFINE = _defines.DEFINE -DEFINE_flag = _defines.DEFINE_flag -DEFINE_string = _defines.DEFINE_string -DEFINE_boolean = _defines.DEFINE_boolean -DEFINE_bool = DEFINE_boolean # Match C++ API. -DEFINE_float = _defines.DEFINE_float -DEFINE_integer = _defines.DEFINE_integer -DEFINE_enum = _defines.DEFINE_enum -DEFINE_enum_class = _defines.DEFINE_enum_class -DEFINE_list = _defines.DEFINE_list -DEFINE_spaceseplist = _defines.DEFINE_spaceseplist -DEFINE_multi = _defines.DEFINE_multi -DEFINE_multi_string = _defines.DEFINE_multi_string -DEFINE_multi_integer = _defines.DEFINE_multi_integer -DEFINE_multi_float = _defines.DEFINE_multi_float -DEFINE_multi_enum = _defines.DEFINE_multi_enum -DEFINE_multi_enum_class = _defines.DEFINE_multi_enum_class -DEFINE_alias = _defines.DEFINE_alias -# pylint: enable=invalid-name - -# Flag validators. -register_validator = _validators.register_validator -validator = _validators.validator -register_multi_flags_validator = _validators.register_multi_flags_validator -multi_flags_validator = _validators.multi_flags_validator -mark_flag_as_required = _validators.mark_flag_as_required -mark_flags_as_required = _validators.mark_flags_as_required -mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive -mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive - -# Flag modifiers. -set_default = _defines.set_default - -# Key flag related functions. -declare_key_flag = _defines.declare_key_flag -adopt_module_key_flags = _defines.adopt_module_key_flags -disclaim_key_flags = _defines.disclaim_key_flags - -# Module exceptions. -# pylint: disable=invalid-name -Error = _exceptions.Error -CantOpenFlagFileError = _exceptions.CantOpenFlagFileError -DuplicateFlagError = _exceptions.DuplicateFlagError -IllegalFlagValueError = _exceptions.IllegalFlagValueError -UnrecognizedFlagError = _exceptions.UnrecognizedFlagError -UnparsedFlagAccessError = _exceptions.UnparsedFlagAccessError -ValidationError = _exceptions.ValidationError -FlagNameConflictsWithMethodError = _exceptions.FlagNameConflictsWithMethodError - -# Public classes. -Flag = _flag.Flag -BooleanFlag = _flag.BooleanFlag -EnumFlag = _flag.EnumFlag -EnumClassFlag = _flag.EnumClassFlag -MultiFlag = _flag.MultiFlag -MultiEnumClassFlag = _flag.MultiEnumClassFlag -FlagHolder = _flagvalues.FlagHolder -FlagValues = _flagvalues.FlagValues -ArgumentParser = _argument_parser.ArgumentParser -BooleanParser = _argument_parser.BooleanParser -EnumParser = _argument_parser.EnumParser -EnumClassParser = _argument_parser.EnumClassParser -ArgumentSerializer = _argument_parser.ArgumentSerializer -FloatParser = _argument_parser.FloatParser -IntegerParser = _argument_parser.IntegerParser -BaseListParser = _argument_parser.BaseListParser -ListParser = _argument_parser.ListParser -ListSerializer = _argument_parser.ListSerializer -CsvListSerializer = _argument_parser.CsvListSerializer -WhitespaceSeparatedListParser = _argument_parser.WhitespaceSeparatedListParser -EnumClassSerializer = _argument_parser.EnumClassSerializer -# pylint: enable=invalid-name - -# Helper functions. -get_help_width = _helpers.get_help_width -text_wrap = _helpers.text_wrap -flag_dict_to_args = _helpers.flag_dict_to_args -doc_to_help = _helpers.doc_to_help - -# The global FlagValues instance. -FLAGS = _flagvalues.FLAGS - diff --git a/absl/flags/_argument_parser.py b/absl/flags/_argument_parser.py index 2c4de9b..ee40d6e 100644 --- a/absl/flags/_argument_parser.py +++ b/absl/flags/_argument_parser.py @@ -20,11 +20,18 @@ aliases defined at the package level instead. import collections import csv +import enum import io import string +from typing import Generic, List, Iterable, Optional, Sequence, Text, Type, TypeVar, Union +from xml.dom import minidom from absl.flags import _helpers +_T = TypeVar('_T') +_ET = TypeVar('_ET', bound=enum.Enum) +_N = TypeVar('_N', int, float) + def _is_integer_type(instance): """Returns True if instance is an integer, and not a bool.""" @@ -72,25 +79,7 @@ class _ArgumentParserCache(type): return type.__call__(cls, *args) -# NOTE about Genericity and Metaclass of ArgumentParser. -# (1) In the .py source (this file) -# - is not declared as Generic -# - has _ArgumentParserCache as a metaclass -# (2) In the .pyi source (type stub) -# - is declared as Generic -# - doesn't have a metaclass -# The reason we need this is due to Generic having a different metaclass -# (for python versions <= 3.7) and a class can have only one metaclass. -# -# * Lack of metaclass in .pyi is not a deal breaker, since the metaclass -# doesn't affect any type information. Also type checkers can check the type -# parameters. -# * However, not declaring ArgumentParser as Generic in the source affects -# runtime annotation processing. In particular this means, subclasses should -# inherit from `ArgumentParser` and not `ArgumentParser[SomeType]`. -# The corresponding DEFINE_someType method (the public API) can be annotated -# to return FlagHolder[SomeType]. -class ArgumentParser(metaclass=_ArgumentParserCache): +class ArgumentParser(Generic[_T], metaclass=_ArgumentParserCache): """Base class used to parse and convert arguments. The :meth:`parse` method checks to make sure that the string argument is a @@ -106,9 +95,9 @@ class ArgumentParser(metaclass=_ArgumentParserCache): member variables must be derived from initializer arguments only. """ - syntactic_help = '' + syntactic_help: Text = '' - def parse(self, argument): + def parse(self, argument: Text) -> Optional[_T]: """Parses the string argument and returns the native value. By default it returns its argument unmodified. @@ -128,11 +117,13 @@ class ArgumentParser(metaclass=_ArgumentParserCache): type(argument))) return argument - def flag_type(self): + def flag_type(self) -> Text: """Returns a string representing the type of the flag.""" return 'string' - def _custom_xml_dom_elements(self, doc): + def _custom_xml_dom_elements( + self, doc: minidom.Document + ) -> List[minidom.Element]: """Returns a list of minidom.Element to add additional flag information. Args: @@ -142,33 +133,38 @@ class ArgumentParser(metaclass=_ArgumentParserCache): return [] -class ArgumentSerializer(object): +class ArgumentSerializer(Generic[_T]): """Base class for generating string representations of a flag value.""" - def serialize(self, value): + def serialize(self, value: _T) -> Text: """Returns a serialized string of the value.""" return str(value) -class NumericParser(ArgumentParser): +class NumericParser(ArgumentParser[_N]): """Parser of numeric values. Parsed value may be bounded to a given upper and lower bound. """ - def is_outside_bounds(self, val): + lower_bound: Optional[_N] + upper_bound: Optional[_N] + + def is_outside_bounds(self, val: _N) -> bool: """Returns whether the value is outside the bounds or not.""" return ((self.lower_bound is not None and val < self.lower_bound) or (self.upper_bound is not None and val > self.upper_bound)) - def parse(self, argument): + def parse(self, argument: Text) -> _N: """See base class.""" val = self.convert(argument) if self.is_outside_bounds(val): raise ValueError('%s is not %s' % (val, self.syntactic_help)) return val - def _custom_xml_dom_elements(self, doc): + def _custom_xml_dom_elements( + self, doc: minidom.Document + ) -> List[minidom.Element]: elements = [] if self.lower_bound is not None: elements.append(_helpers.create_xml_dom_element( @@ -178,7 +174,7 @@ class NumericParser(ArgumentParser): doc, 'upper_bound', self.upper_bound)) return elements - def convert(self, argument): + def convert(self, argument: Text) -> _N: """Returns the correct numeric value of argument. Subclass must implement this method, and raise TypeError if argument is not @@ -194,7 +190,7 @@ class NumericParser(ArgumentParser): raise NotImplementedError -class FloatParser(NumericParser): +class FloatParser(NumericParser[float]): """Parser of floating point values. Parsed value may be bounded to a given upper and lower bound. @@ -203,7 +199,11 @@ class FloatParser(NumericParser): number_name = 'number' syntactic_help = ' '.join((number_article, number_name)) - def __init__(self, lower_bound=None, upper_bound=None): + def __init__( + self, + lower_bound: Optional[float] = None, + upper_bound: Optional[float] = None, + ) -> None: super(FloatParser, self).__init__() self.lower_bound = lower_bound self.upper_bound = upper_bound @@ -220,7 +220,7 @@ class FloatParser(NumericParser): sh = '%s >= %s' % (self.number_name, lower_bound) self.syntactic_help = sh - def convert(self, argument): + def convert(self, argument: Union[int, float, str]) -> float: """Returns the float value of argument.""" if (_is_integer_type(argument) or isinstance(argument, float) or isinstance(argument, str)): @@ -230,12 +230,12 @@ class FloatParser(NumericParser): 'Expect argument to be a string, int, or float, found {}'.format( type(argument))) - def flag_type(self): + def flag_type(self) -> Text: """See base class.""" return 'float' -class IntegerParser(NumericParser): +class IntegerParser(NumericParser[int]): """Parser of an integer value. Parsed value may be bounded to a given upper and lower bound. @@ -244,7 +244,9 @@ class IntegerParser(NumericParser): number_name = 'integer' syntactic_help = ' '.join((number_article, number_name)) - def __init__(self, lower_bound=None, upper_bound=None): + def __init__( + self, lower_bound: Optional[int] = None, upper_bound: Optional[int] = None + ) -> None: super(IntegerParser, self).__init__() self.lower_bound = lower_bound self.upper_bound = upper_bound @@ -265,7 +267,7 @@ class IntegerParser(NumericParser): sh = '%s >= %s' % (self.number_name, lower_bound) self.syntactic_help = sh - def convert(self, argument): + def convert(self, argument: Union[int, Text]) -> int: """Returns the int value of argument.""" if _is_integer_type(argument): return argument @@ -281,15 +283,15 @@ class IntegerParser(NumericParser): raise TypeError('Expect argument to be a string or int, found {}'.format( type(argument))) - def flag_type(self): + def flag_type(self) -> Text: """See base class.""" return 'int' -class BooleanParser(ArgumentParser): +class BooleanParser(ArgumentParser[bool]): """Parser of boolean values.""" - def parse(self, argument): + def parse(self, argument: Union[Text, int]) -> bool: """See base class.""" if isinstance(argument, str): if argument.lower() in ('true', 't', '1'): @@ -309,15 +311,17 @@ class BooleanParser(ArgumentParser): raise TypeError('Non-boolean argument to boolean flag', argument) - def flag_type(self): + def flag_type(self) -> Text: """See base class.""" return 'bool' -class EnumParser(ArgumentParser): +class EnumParser(ArgumentParser[Text]): """Parser of a string enum value (a string value from a given set).""" - def __init__(self, enum_values, case_sensitive=True): + def __init__( + self, enum_values: Iterable[Text], case_sensitive: bool = True + ) -> None: """Initializes EnumParser. Args: @@ -330,11 +334,15 @@ class EnumParser(ArgumentParser): if not enum_values: raise ValueError( 'enum_values cannot be empty, found "{}"'.format(enum_values)) + if isinstance(enum_values, str): + raise ValueError( + 'enum_values cannot be a str, found "{}"'.format(enum_values) + ) super(EnumParser, self).__init__() - self.enum_values = enum_values + self.enum_values = list(enum_values) self.case_sensitive = case_sensitive - def parse(self, argument): + def parse(self, argument: Text) -> Text: """Determines validity of argument and returns the correct element of enum. Args: @@ -360,15 +368,17 @@ class EnumParser(ArgumentParser): return [value for value in self.enum_values if value.upper() == argument.upper()][0] - def flag_type(self): + def flag_type(self) -> Text: """See base class.""" return 'string enum' -class EnumClassParser(ArgumentParser): +class EnumClassParser(ArgumentParser[_ET]): """Parser of an Enum class member.""" - def __init__(self, enum_class, case_sensitive=True): + def __init__( + self, enum_class: Type[_ET], case_sensitive: bool = True + ) -> None: """Initializes EnumParser. Args: @@ -380,10 +390,6 @@ class EnumClassParser(ArgumentParser): TypeError: When enum_class is not a subclass of Enum. ValueError: When enum_class is empty. """ - # Users must have an Enum class defined before using EnumClass flag. - # Therefore this dependency is guaranteed. - import enum - if not issubclass(enum_class, enum.Enum): raise TypeError('{} is not a subclass of Enum.'.format(enum_class)) if not enum_class.__members__: @@ -410,11 +416,11 @@ class EnumClassParser(ArgumentParser): name.lower() for name in enum_class.__members__) @property - def member_names(self): + def member_names(self) -> Sequence[Text]: """The accepted enum names, in lowercase if not case sensitive.""" return self._member_names - def parse(self, argument): + def parse(self, argument: Union[_ET, Text]) -> _ET: """Determines validity of argument and returns the correct element of enum. Args: @@ -427,7 +433,7 @@ class EnumClassParser(ArgumentParser): ValueError: Raised when argument didn't match anything in enum. """ if isinstance(argument, self.enum_class): - return argument + return argument # pytype: disable=bad-return-type elif not isinstance(argument, str): raise ValueError( '{} is not an enum member or a name of a member in {}'.format( @@ -442,29 +448,29 @@ class EnumClassParser(ArgumentParser): return next(value for name, value in self.enum_class.__members__.items() if name.lower() == key.lower()) - def flag_type(self): + def flag_type(self) -> Text: """See base class.""" return 'enum class' -class ListSerializer(ArgumentSerializer): +class ListSerializer(Generic[_T], ArgumentSerializer[List[_T]]): - def __init__(self, list_sep): + def __init__(self, list_sep: Text) -> None: self.list_sep = list_sep - def serialize(self, value): + def serialize(self, value: List[_T]) -> Text: """See base class.""" return self.list_sep.join([str(x) for x in value]) -class EnumClassListSerializer(ListSerializer): +class EnumClassListSerializer(ListSerializer[_ET]): """A serializer for :class:`MultiEnumClass` flags. This serializer simply joins the output of `EnumClassSerializer` using a provided separator. """ - def __init__(self, list_sep, **kwargs): + def __init__(self, list_sep: Text, **kwargs) -> None: """Initializes EnumClassListSerializer. Args: @@ -475,7 +481,7 @@ class EnumClassListSerializer(ListSerializer): super(EnumClassListSerializer, self).__init__(list_sep) self._element_serializer = EnumClassSerializer(**kwargs) - def serialize(self, value): + def serialize(self, value: Union[_ET, List[_ET]]) -> Text: """See base class.""" if isinstance(value, list): return self.list_sep.join( @@ -484,12 +490,9 @@ class EnumClassListSerializer(ListSerializer): return self._element_serializer.serialize(value) -class CsvListSerializer(ArgumentSerializer): - - def __init__(self, list_sep): - self.list_sep = list_sep +class CsvListSerializer(ListSerializer[Text]): - def serialize(self, value): + def serialize(self, value: List[Text]) -> Text: """Serializes a list as a CSV string or unicode.""" output = io.StringIO() writer = csv.writer(output, delimiter=self.list_sep) @@ -504,7 +507,7 @@ class CsvListSerializer(ArgumentSerializer): class EnumClassSerializer(ArgumentSerializer): """Class for generating string representations of an enum class flag value.""" - def __init__(self, lowercase): + def __init__(self, lowercase: bool) -> None: """Initializes EnumClassSerializer. Args: @@ -512,7 +515,7 @@ class EnumClassSerializer(ArgumentSerializer): """ self._lowercase = lowercase - def serialize(self, value): + def serialize(self, value: _ET) -> Text: """Returns a serialized string of the Enum class value.""" as_string = str(value.name) return as_string.lower() if self._lowercase else as_string @@ -529,14 +532,16 @@ class BaseListParser(ArgumentParser): of the separator. """ - def __init__(self, token=None, name=None): + def __init__( + self, token: Optional[Text] = None, name: Optional[Text] = None + ) -> None: assert name super(BaseListParser, self).__init__() self._token = token self._name = name self.syntactic_help = 'a %s separated list' % self._name - def parse(self, argument): + def parse(self, argument: Text) -> List[Text]: """See base class.""" if isinstance(argument, list): return argument @@ -545,7 +550,7 @@ class BaseListParser(ArgumentParser): else: return [s.strip() for s in argument.split(self._token)] - def flag_type(self): + def flag_type(self) -> Text: """See base class.""" return '%s separated list of strings' % self._name @@ -553,10 +558,10 @@ class BaseListParser(ArgumentParser): class ListParser(BaseListParser): """Parser for a comma-separated list of strings.""" - def __init__(self): + def __init__(self) -> None: super(ListParser, self).__init__(',', 'comma') - def parse(self, argument): + def parse(self, argument: Union[Text, List[Text]]) -> List[Text]: """Parses argument as comma-separated list of strings.""" if isinstance(argument, list): return argument @@ -574,7 +579,9 @@ class ListParser(BaseListParser): raise ValueError('Unable to parse the value %r as a %s: %s' % (argument, self.flag_type(), e)) - def _custom_xml_dom_elements(self, doc): + def _custom_xml_dom_elements( + self, doc: minidom.Document + ) -> List[minidom.Element]: elements = super(ListParser, self)._custom_xml_dom_elements(doc) elements.append(_helpers.create_xml_dom_element( doc, 'list_separator', repr(','))) @@ -584,7 +591,7 @@ class ListParser(BaseListParser): class WhitespaceSeparatedListParser(BaseListParser): """Parser for a whitespace-separated list of strings.""" - def __init__(self, comma_compat=False): + def __init__(self, comma_compat: bool = False) -> None: """Initializer. Args: @@ -596,7 +603,7 @@ class WhitespaceSeparatedListParser(BaseListParser): name = 'whitespace or comma' if self._comma_compat else 'whitespace' super(WhitespaceSeparatedListParser, self).__init__(None, name) - def parse(self, argument): + def parse(self, argument: Union[Text, List[Text]]) -> List[Text]: """Parses argument as whitespace-separated list of strings. It also parses argument as comma-separated list of strings if requested. @@ -616,7 +623,9 @@ class WhitespaceSeparatedListParser(BaseListParser): argument = argument.replace(',', ' ') return argument.split() - def _custom_xml_dom_elements(self, doc): + def _custom_xml_dom_elements( + self, doc: minidom.Document + ) -> List[minidom.Element]: elements = super(WhitespaceSeparatedListParser, self )._custom_xml_dom_elements(doc) separators = list(string.whitespace) diff --git a/absl/flags/_argument_parser.pyi b/absl/flags/_argument_parser.pyi deleted file mode 100644 index 7e78d7d..0000000 --- a/absl/flags/_argument_parser.pyi +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2020 The Abseil Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Contains type annotations for _argument_parser.py.""" - - -from typing import Text, TypeVar, Generic, Iterable, Type, List, Optional, Sequence, Any - -import enum - -_T = TypeVar('_T') -_ET = TypeVar('_ET', bound=enum.Enum) - - -class ArgumentSerializer(Generic[_T]): - def serialize(self, value: _T) -> Text: ... - - -# The metaclass of ArgumentParser is not reflected here, because it does not -# affect the provided API. -class ArgumentParser(Generic[_T]): - - syntactic_help: Text - - def parse(self, argument: Text) -> Optional[_T]: ... - - def flag_type(self) -> Text: ... - - -# Using bound=numbers.Number results in an error: b/153268436 -_N = TypeVar('_N', int, float) - - -class NumericParser(ArgumentParser[_N]): - - def is_outside_bounds(self, val: _N) -> bool: ... - - def parse(self, argument: Text) -> _N: ... - - def convert(self, argument: Text) -> _N: ... - - -class FloatParser(NumericParser[float]): - - def __init__(self, lower_bound:Optional[float]=None, - upper_bound:Optional[float]=None) -> None: - ... - - -class IntegerParser(NumericParser[int]): - - def __init__(self, lower_bound:Optional[int]=None, - upper_bound:Optional[int]=None) -> None: - ... - - -class BooleanParser(ArgumentParser[bool]): - ... - - -class EnumParser(ArgumentParser[Text]): - def __init__(self, enum_values: Sequence[Text], case_sensitive: bool=...) -> None: - ... - - - -class EnumClassParser(ArgumentParser[_ET]): - - def __init__(self, enum_class: Type[_ET], case_sensitive: bool=...) -> None: - ... - - @property - def member_names(self) -> Sequence[Text]: ... - - -class BaseListParser(ArgumentParser[List[Text]]): - def __init__(self, token: Text, name:Text) -> None: ... - - # Unlike baseclass BaseListParser never returns None. - def parse(self, argument: Text) -> List[Text]: ... - - - -class ListParser(BaseListParser): - def __init__(self) -> None: - ... - - - -class WhitespaceSeparatedListParser(BaseListParser): - def __init__(self, comma_compat: bool=False) -> None: - ... - - - -class ListSerializer(ArgumentSerializer[List[Text]]): - list_sep = ... # type: Text - - def __init__(self, list_sep: Text) -> None: - ... - - -class EnumClassListSerializer(ArgumentSerializer[List[Text]]): - def __init__(self, list_sep: Text, **kwargs: Any) -> None: - ... - - -class CsvListSerializer(ArgumentSerializer[List[Any]]): - - def __init__(self, list_sep: Text) -> None: - ... - - -class EnumClassSerializer(ArgumentSerializer[_ET]): - def __init__(self, lowercase: bool) -> None: - ... diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py index 61354e9..6dd87d6 100644 --- a/absl/flags/_defines.py +++ b/absl/flags/_defines.py @@ -17,8 +17,11 @@ Do NOT import this module directly. Import the flags package and use the aliases defined at the package level instead. """ +import enum import sys import types +import typing +from typing import Text, List, Any, TypeVar, Optional, Union, Type, Iterable, overload from absl.flags import _argument_parser from absl.flags import _exceptions @@ -27,20 +30,11 @@ from absl.flags import _flagvalues from absl.flags import _helpers from absl.flags import _validators -# pylint: disable=unused-import -try: - from typing import Text, List, Any -except ImportError: - pass - -try: - import enum -except ImportError: - pass -# pylint: enable=unused-import - _helpers.disclaim_module_ids.add(id(sys.modules[__name__])) +_T = TypeVar('_T') +_ET = TypeVar('_ET', bound=enum.Enum) + def _register_bounds_validator_if_needed(parser, name, flag_values): """Enforces lower and upper bounds for numeric flags. @@ -62,6 +56,36 @@ def _register_bounds_validator_if_needed(parser, name, flag_values): _validators.register_validator(name, checker, flag_values=flag_values) +@overload +def DEFINE( # pylint: disable=invalid-name + parser: _argument_parser.ArgumentParser[_T], + name: Text, + default: Any, + help: Optional[Text], # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + serializer: Optional[_argument_parser.ArgumentSerializer[_T]] = ..., + module_name: Optional[Text] = ..., + required: 'typing.Literal[True]' = ..., + **args: Any +) -> _flagvalues.FlagHolder[_T]: + ... + + +@overload +def DEFINE( # pylint: disable=invalid-name + parser: _argument_parser.ArgumentParser[_T], + name: Text, + default: Optional[Any], + help: Optional[Text], # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + serializer: Optional[_argument_parser.ArgumentSerializer[_T]] = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Optional[_T]]: + ... + + def DEFINE( # pylint: disable=invalid-name parser, name, @@ -98,8 +122,31 @@ def DEFINE( # pylint: disable=invalid-name a handle to defined flag. """ return DEFINE_flag( - _flag.Flag(parser, serializer, name, default, help, **args), flag_values, - module_name, required) + _flag.Flag(parser, serializer, name, default, help, **args), + flag_values, + module_name, + required=True if required else False, + ) + + +@overload +def DEFINE_flag( # pylint: disable=invalid-name + flag: _flag.Flag[_T], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: 'typing.Literal[True]' = ..., +) -> _flagvalues.FlagHolder[_T]: + ... + + +@overload +def DEFINE_flag( # pylint: disable=invalid-name + flag: _flag.Flag[_T], + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., +) -> _flagvalues.FlagHolder[Optional[_T]]: + ... def DEFINE_flag( # pylint: disable=invalid-name @@ -148,7 +195,7 @@ def DEFINE_flag( # pylint: disable=invalid-name fv, flag, ensure_non_none_value=ensure_non_none_value) -def set_default(flag_holder, value): +def set_default(flag_holder: _flagvalues.FlagHolder[_T], value: _T) -> None: """Changes the default value of the provided flag object. The flag's current value is also updated if the flag is currently using @@ -165,9 +212,11 @@ def set_default(flag_holder, value): flag_holder._flagvalues.set_default(flag_holder.name, value) # pylint: disable=protected-access -def _internal_declare_key_flags(flag_names, - flag_values=_flagvalues.FLAGS, - key_flag_values=None): +def _internal_declare_key_flags( + flag_names: List[str], + flag_values: _flagvalues.FlagValues = _flagvalues.FLAGS, + key_flag_values: Optional[_flagvalues.FlagValues] = None, +) -> None: """Declares a flag as key for the calling module. Internal function. User code should call declare_key_flag or @@ -195,7 +244,10 @@ def _internal_declare_key_flags(flag_names, key_flag_values.register_key_flag_for_module(module, flag_values[flag_name]) -def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS): +def declare_key_flag( + flag_name: Union[Text, _flagvalues.FlagHolder], + flag_values: _flagvalues.FlagValues = _flagvalues.FLAGS, +) -> None: """Declares one flag as key to the current module. Key flags are flags that are deemed really important for a module. @@ -237,7 +289,9 @@ def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS): 'first define it in Python.' % flag_name) -def adopt_module_key_flags(module, flag_values=_flagvalues.FLAGS): +def adopt_module_key_flags( + module: Any, flag_values: _flagvalues.FlagValues = _flagvalues.FLAGS +) -> None: """Declares that all flags key to a module are key to the current module. Args: @@ -269,7 +323,7 @@ def adopt_module_key_flags(module, flag_values=_flagvalues.FLAGS): key_flag_values=flag_values) -def disclaim_key_flags(): +def disclaim_key_flags() -> None: """Declares that the current module will not define any more key flags. Normally, the module that calls the DEFINE_xxx functions claims the @@ -288,6 +342,43 @@ def disclaim_key_flags(): _helpers.disclaim_module_ids.add(id(module)) +@overload +def DEFINE_string( # pylint: disable=invalid-name + name: Text, + default: Optional[Text], + help: Optional[Text], # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[Text]: + ... + + +@overload +def DEFINE_string( # pylint: disable=invalid-name + name: Text, + default: None, + help: Optional[Text], # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Optional[Text]]: + ... + + +@overload +def DEFINE_string( # pylint: disable=invalid-name + name: Text, + default: Text, + help: Optional[Text], # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Text]: + ... + + def DEFINE_string( # pylint: disable=invalid-name,redefined-builtin name, default, @@ -296,8 +387,8 @@ def DEFINE_string( # pylint: disable=invalid-name,redefined-builtin required=False, **args): """Registers a flag whose value can be any string.""" - parser = _argument_parser.ArgumentParser() - serializer = _argument_parser.ArgumentSerializer() + parser = _argument_parser.ArgumentParser[str]() + serializer = _argument_parser.ArgumentSerializer[str]() return DEFINE( parser, name, @@ -305,8 +396,49 @@ def DEFINE_string( # pylint: disable=invalid-name,redefined-builtin help, flag_values, serializer, - required=required, - **args) + required=True if required else False, + **args, + ) + + +@overload +def DEFINE_boolean( # pylint: disable=invalid-name + name: Text, + default: Union[None, Text, bool, int], + help: Optional[Text], # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[bool]: + ... + + +@overload +def DEFINE_boolean( # pylint: disable=invalid-name + name: Text, + default: None, + help: Optional[Text], # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Optional[bool]]: + ... + + +@overload +def DEFINE_boolean( # pylint: disable=invalid-name + name: Text, + default: Union[Text, bool, int], + help: Optional[Text], # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[bool]: + ... def DEFINE_boolean( # pylint: disable=invalid-name,redefined-builtin @@ -343,8 +475,54 @@ def DEFINE_boolean( # pylint: disable=invalid-name,redefined-builtin a handle to defined flag. """ return DEFINE_flag( - _flag.BooleanFlag(name, default, help, **args), flag_values, module_name, - required) + _flag.BooleanFlag(name, default, help, **args), + flag_values, + module_name, + required=True if required else False, + ) + + +@overload +def DEFINE_float( # pylint: disable=invalid-name + name: Text, + default: Union[None, float, Text], + help: Optional[Text], # pylint: disable=redefined-builtin + lower_bound: Optional[float] = ..., + upper_bound: Optional[float] = ..., + flag_values: _flagvalues.FlagValues = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[float]: + ... + + +@overload +def DEFINE_float( # pylint: disable=invalid-name + name: Text, + default: None, + help: Optional[Text], # pylint: disable=redefined-builtin + lower_bound: Optional[float] = ..., + upper_bound: Optional[float] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Optional[float]]: + ... + + +@overload +def DEFINE_float( # pylint: disable=invalid-name + name: Text, + default: Union[float, Text], + help: Optional[Text], # pylint: disable=redefined-builtin + lower_bound: Optional[float] = ..., + upper_bound: Optional[float] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[float]: + ... def DEFINE_float( # pylint: disable=invalid-name,redefined-builtin @@ -385,12 +563,56 @@ def DEFINE_float( # pylint: disable=invalid-name,redefined-builtin help, flag_values, serializer, - required=required, - **args) + required=True if required else False, + **args, + ) _register_bounds_validator_if_needed(parser, name, flag_values=flag_values) return result +@overload +def DEFINE_integer( # pylint: disable=invalid-name + name: Text, + default: Union[None, int, Text], + help: Optional[Text], # pylint: disable=redefined-builtin + lower_bound: Optional[int] = ..., + upper_bound: Optional[int] = ..., + flag_values: _flagvalues.FlagValues = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[int]: + ... + + +@overload +def DEFINE_integer( # pylint: disable=invalid-name + name: Text, + default: None, + help: Optional[Text], # pylint: disable=redefined-builtin + lower_bound: Optional[int] = ..., + upper_bound: Optional[int] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Optional[int]]: + ... + + +@overload +def DEFINE_integer( # pylint: disable=invalid-name + name: Text, + default: Union[int, Text], + help: Optional[Text], # pylint: disable=redefined-builtin + lower_bound: Optional[int] = ..., + upper_bound: Optional[int] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[int]: + ... + + def DEFINE_integer( # pylint: disable=invalid-name,redefined-builtin name, default, @@ -429,12 +651,56 @@ def DEFINE_integer( # pylint: disable=invalid-name,redefined-builtin help, flag_values, serializer, - required=required, - **args) + required=True if required else False, + **args, + ) _register_bounds_validator_if_needed(parser, name, flag_values=flag_values) return result +@overload +def DEFINE_enum( # pylint: disable=invalid-name + name: Text, + default: Optional[Text], + enum_values: Iterable[Text], + help: Optional[Text], # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[Text]: + ... + + +@overload +def DEFINE_enum( # pylint: disable=invalid-name + name: Text, + default: None, + enum_values: Iterable[Text], + help: Optional[Text], # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Optional[Text]]: + ... + + +@overload +def DEFINE_enum( # pylint: disable=invalid-name + name: Text, + default: Text, + enum_values: Iterable[Text], + help: Optional[Text], # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Text]: + ... + + def DEFINE_enum( # pylint: disable=invalid-name,redefined-builtin name, default, @@ -466,9 +732,59 @@ def DEFINE_enum( # pylint: disable=invalid-name,redefined-builtin Returns: a handle to defined flag. """ - return DEFINE_flag( - _flag.EnumFlag(name, default, help, enum_values, **args), flag_values, - module_name, required) + result = DEFINE_flag( + _flag.EnumFlag(name, default, help, enum_values, **args), + flag_values, + module_name, + required=True if required else False, + ) + return result + + +@overload +def DEFINE_enum_class( # pylint: disable=invalid-name + name: Text, + default: Union[None, _ET, Text], + enum_class: Type[_ET], + help: Optional[Text], # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + case_sensitive: bool = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[_ET]: + ... + + +@overload +def DEFINE_enum_class( # pylint: disable=invalid-name + name: Text, + default: None, + enum_class: Type[_ET], + help: Optional[Text], # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + case_sensitive: bool = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Optional[_ET]]: + ... + + +@overload +def DEFINE_enum_class( # pylint: disable=invalid-name + name: Text, + default: Union[_ET, Text], + enum_class: Type[_ET], + help: Optional[Text], # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + case_sensitive: bool = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[_ET]: + ... def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin @@ -501,14 +817,53 @@ def DEFINE_enum_class( # pylint: disable=invalid-name,redefined-builtin Returns: a handle to defined flag. """ - return DEFINE_flag( + # NOTE: pytype fails if this is a direct return. + result = DEFINE_flag( _flag.EnumClassFlag( - name, - default, - help, - enum_class, - case_sensitive=case_sensitive, - **args), flag_values, module_name, required) + name, default, help, enum_class, case_sensitive=case_sensitive, **args + ), + flag_values, + module_name, + required=True if required else False, + ) + return result + + +@overload +def DEFINE_list( # pylint: disable=invalid-name + name: Text, + default: Union[None, Iterable[Text], Text], + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[List[Text]]: + ... + + +@overload +def DEFINE_list( # pylint: disable=invalid-name + name: Text, + default: None, + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Optional[List[Text]]]: + ... + + +@overload +def DEFINE_list( # pylint: disable=invalid-name + name: Text, + default: Union[Iterable[Text], Text], + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[List[Text]]: + ... def DEFINE_list( # pylint: disable=invalid-name,redefined-builtin @@ -545,8 +900,49 @@ def DEFINE_list( # pylint: disable=invalid-name,redefined-builtin help, flag_values, serializer, - required=required, - **args) + required=True if required else False, + **args, + ) + + +@overload +def DEFINE_spaceseplist( # pylint: disable=invalid-name + name: Text, + default: Union[None, Iterable[Text], Text], + help: Text, # pylint: disable=redefined-builtin + comma_compat: bool = ..., + flag_values: _flagvalues.FlagValues = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[List[Text]]: + ... + + +@overload +def DEFINE_spaceseplist( # pylint: disable=invalid-name + name: Text, + default: None, + help: Text, # pylint: disable=redefined-builtin + comma_compat: bool = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Optional[List[Text]]]: + ... + + +@overload +def DEFINE_spaceseplist( # pylint: disable=invalid-name + name: Text, + default: Union[Iterable[Text], Text], + help: Text, # pylint: disable=redefined-builtin + comma_compat: bool = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[List[Text]]: + ... def DEFINE_spaceseplist( # pylint: disable=invalid-name,redefined-builtin @@ -588,8 +984,86 @@ def DEFINE_spaceseplist( # pylint: disable=invalid-name,redefined-builtin help, flag_values, serializer, - required=required, - **args) + required=True if required else False, + **args, + ) + + +@overload +def DEFINE_multi( # pylint: disable=invalid-name + parser: _argument_parser.ArgumentParser[_T], + serializer: _argument_parser.ArgumentSerializer[_T], + name: Text, + default: Iterable[_T], + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[List[_T]]: + ... + + +@overload +def DEFINE_multi( # pylint: disable=invalid-name + parser: _argument_parser.ArgumentParser[_T], + serializer: _argument_parser.ArgumentSerializer[_T], + name: Text, + default: Union[None, _T], + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[List[_T]]: + ... + + +@overload +def DEFINE_multi( # pylint: disable=invalid-name + parser: _argument_parser.ArgumentParser[_T], + serializer: _argument_parser.ArgumentSerializer[_T], + name: Text, + default: None, + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Optional[List[_T]]]: + ... + + +@overload +def DEFINE_multi( # pylint: disable=invalid-name + parser: _argument_parser.ArgumentParser[_T], + serializer: _argument_parser.ArgumentSerializer[_T], + name: Text, + default: Iterable[_T], + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[List[_T]]: + ... + + +@overload +def DEFINE_multi( # pylint: disable=invalid-name + parser: _argument_parser.ArgumentParser[_T], + serializer: _argument_parser.ArgumentSerializer[_T], + name: Text, + default: _T, + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[List[_T]]: + ... def DEFINE_multi( # pylint: disable=invalid-name,redefined-builtin @@ -632,9 +1106,50 @@ def DEFINE_multi( # pylint: disable=invalid-name,redefined-builtin Returns: a handle to defined flag. """ - return DEFINE_flag( + result = DEFINE_flag( _flag.MultiFlag(parser, serializer, name, default, help, **args), - flag_values, module_name, required) + flag_values, + module_name, + required=True if required else False, + ) + return result + + +@overload +def DEFINE_multi_string( # pylint: disable=invalid-name + name: Text, + default: Union[None, Iterable[Text], Text], + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[List[Text]]: + ... + + +@overload +def DEFINE_multi_string( # pylint: disable=invalid-name + name: Text, + default: None, + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Optional[List[Text]]]: + ... + + +@overload +def DEFINE_multi_string( # pylint: disable=invalid-name + name: Text, + default: Union[Iterable[Text], Text], + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[List[Text]]: + ... def DEFINE_multi_string( # pylint: disable=invalid-name,redefined-builtin @@ -676,8 +1191,52 @@ def DEFINE_multi_string( # pylint: disable=invalid-name,redefined-builtin default, help, flag_values, - required=required, - **args) + required=True if required else False, + **args, + ) + + +@overload +def DEFINE_multi_integer( # pylint: disable=invalid-name + name: Text, + default: Union[None, Iterable[int], int, Text], + help: Text, # pylint: disable=redefined-builtin + lower_bound: Optional[int] = ..., + upper_bound: Optional[int] = ..., + flag_values: _flagvalues.FlagValues = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[List[int]]: + ... + + +@overload +def DEFINE_multi_integer( # pylint: disable=invalid-name + name: Text, + default: None, + help: Text, # pylint: disable=redefined-builtin + lower_bound: Optional[int] = ..., + upper_bound: Optional[int] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Optional[List[int]]]: + ... + + +@overload +def DEFINE_multi_integer( # pylint: disable=invalid-name + name: Text, + default: Union[Iterable[int], int, Text], + help: Text, # pylint: disable=redefined-builtin + lower_bound: Optional[int] = ..., + upper_bound: Optional[int] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[List[int]]: + ... def DEFINE_multi_integer( # pylint: disable=invalid-name,redefined-builtin @@ -722,8 +1281,52 @@ def DEFINE_multi_integer( # pylint: disable=invalid-name,redefined-builtin default, help, flag_values, - required=required, - **args) + required=True if required else False, + **args, + ) + + +@overload +def DEFINE_multi_float( # pylint: disable=invalid-name + name: Text, + default: Union[None, Iterable[float], float, Text], + help: Text, # pylint: disable=redefined-builtin + lower_bound: Optional[float] = ..., + upper_bound: Optional[float] = ..., + flag_values: _flagvalues.FlagValues = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[List[float]]: + ... + + +@overload +def DEFINE_multi_float( # pylint: disable=invalid-name + name: Text, + default: None, + help: Text, # pylint: disable=redefined-builtin + lower_bound: Optional[float] = ..., + upper_bound: Optional[float] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Optional[List[float]]]: + ... + + +@overload +def DEFINE_multi_float( # pylint: disable=invalid-name + name: Text, + default: Union[Iterable[float], float, Text], + help: Text, # pylint: disable=redefined-builtin + lower_bound: Optional[float] = ..., + upper_bound: Optional[float] = ..., + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[List[float]]: + ... def DEFINE_multi_float( # pylint: disable=invalid-name,redefined-builtin @@ -768,8 +1371,49 @@ def DEFINE_multi_float( # pylint: disable=invalid-name,redefined-builtin default, help, flag_values, - required=required, - **args) + required=True if required else False, + **args, + ) + + +@overload +def DEFINE_multi_enum( # pylint: disable=invalid-name + name: Text, + default: Union[None, Iterable[Text], Text], + enum_values: Iterable[Text], + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[List[Text]]: + ... + + +@overload +def DEFINE_multi_enum( # pylint: disable=invalid-name + name: Text, + default: None, + enum_values: Iterable[Text], + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Optional[List[Text]]]: + ... + + +@overload +def DEFINE_multi_enum( # pylint: disable=invalid-name + name: Text, + default: Union[Iterable[Text], Text], + enum_values: Iterable[Text], + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[List[Text]]: + ... def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin @@ -815,8 +1459,89 @@ def DEFINE_multi_enum( # pylint: disable=invalid-name,redefined-builtin default, '<%s>: %s' % ('|'.join(enum_values), help), flag_values, - required=required, - **args) + required=True if required else False, + **args, + ) + + +@overload +def DEFINE_multi_enum_class( # pylint: disable=invalid-name + name: Text, + # This is separate from `Union[None, _ET, Iterable[Text], Text]` to avoid a + # Pytype issue inferring the return value to + # FlagHolder[List[Union[_ET, enum.Enum]]] when an iterable of concrete enum + # subclasses are used. + default: Iterable[_ET], + enum_class: Type[_ET], + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[List[_ET]]: + ... + + +@overload +def DEFINE_multi_enum_class( # pylint: disable=invalid-name + name: Text, + default: Union[None, _ET, Iterable[Text], Text], + enum_class: Type[_ET], + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + *, + required: 'typing.Literal[True]', + **args: Any +) -> _flagvalues.FlagHolder[List[_ET]]: + ... + + +@overload +def DEFINE_multi_enum_class( # pylint: disable=invalid-name + name: Text, + default: None, + enum_class: Type[_ET], + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[Optional[List[_ET]]]: + ... + + +@overload +def DEFINE_multi_enum_class( # pylint: disable=invalid-name + name: Text, + # This is separate from `Union[None, _ET, Iterable[Text], Text]` to avoid a + # Pytype issue inferring the return value to + # FlagHolder[List[Union[_ET, enum.Enum]]] when an iterable of concrete enum + # subclasses are used. + default: Iterable[_ET], + enum_class: Type[_ET], + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[List[_ET]]: + ... + + +@overload +def DEFINE_multi_enum_class( # pylint: disable=invalid-name + name: Text, + default: Union[_ET, Iterable[Text], Text], + enum_class: Type[_ET], + help: Text, # pylint: disable=redefined-builtin + flag_values: _flagvalues.FlagValues = ..., + module_name: Optional[Text] = ..., + required: bool = ..., + **args: Any +) -> _flagvalues.FlagHolder[List[_ET]]: + ... def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin @@ -857,7 +1582,8 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin Returns: a handle to defined flag. """ - return DEFINE_flag( + # NOTE: pytype fails if this is a direct return. + result = DEFINE_flag( _flag.MultiEnumClassFlag( name, default, @@ -868,15 +1594,17 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin ), flag_values, module_name, - required=required, + required=True if required else False, ) + return result def DEFINE_alias( # pylint: disable=invalid-name - name, - original_name, - flag_values=_flagvalues.FLAGS, - module_name=None): + name: Text, + original_name: Text, + flag_values: _flagvalues.FlagValues = _flagvalues.FLAGS, + module_name: Optional[Text] = None, +) -> _flagvalues.FlagHolder[Any]: """Defines an alias flag for an existing one. Args: diff --git a/absl/flags/_defines.pyi b/absl/flags/_defines.pyi deleted file mode 100644 index 9bc8067..0000000 --- a/absl/flags/_defines.pyi +++ /dev/null @@ -1,670 +0,0 @@ -# Copyright 2020 The Abseil Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""This modules contains type annotated stubs for DEFINE functions.""" - - -from absl.flags import _argument_parser -from absl.flags import _flag -from absl.flags import _flagvalues - -import enum - -from typing import Text, List, Any, TypeVar, Optional, Union, Type, Iterable, overload, Literal - -_T = TypeVar('_T') -_ET = TypeVar('_ET', bound=enum.Enum) - - -@overload -def DEFINE( - parser: _argument_parser.ArgumentParser[_T], - name: Text, - default: Any, - help: Optional[Text], - flag_values : _flagvalues.FlagValues = ..., - serializer: Optional[_argument_parser.ArgumentSerializer[_T]] = ..., - module_name: Optional[Text] = ..., - required: Literal[True] = ..., - **args: Any) -> _flagvalues.FlagHolder[_T]: - ... - - -@overload -def DEFINE( - parser: _argument_parser.ArgumentParser[_T], - name: Text, - default: Any, - help: Optional[Text], - flag_values : _flagvalues.FlagValues = ..., - serializer: Optional[_argument_parser.ArgumentSerializer[_T]] = ..., - module_name: Optional[Text] = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Optional[_T]]: - ... - - -@overload -def DEFINE_flag( - flag: _flag.Flag[_T], - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - required: Literal[True] = ... -) -> _flagvalues.FlagHolder[_T]: - ... - -@overload -def DEFINE_flag( - flag: _flag.Flag[_T], - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - required: bool = ...) -> _flagvalues.FlagHolder[Optional[_T]]: - ... - -# typing overloads for DEFINE_* methods... -# -# - DEFINE_* method return FlagHolder[Optional[T]] or FlagHolder[T] depending -# on the arguments. -# - If the flag value is guaranteed to be not None, the return type is -# FlagHolder[T]. -# - If the flag is required OR has a non-None default, the flag value i -# guaranteed to be not None after flag parsing has finished. -# The information above is captured with three overloads as follows. -# -# (if required=True and passed in as a keyword argument, -# return type is FlagHolder[Y]) -# @overload -# def DEFINE_xxx( -# ... arguments... -# default: Union[None, X] = ..., -# *, -# required: Literal[True]) -> _flagvalues.FlagHolder[Y]: -# ... -# -# (if default=None, return type is FlagHolder[Optional[Y]]) -# @overload -# def DEFINE_xxx( -# ... arguments... -# default: None, -# required: bool = ...) -> _flagvalues.FlagHolder[Optional[Y]]: -# ... -# -# (if default!=None, return type is FlagHolder[Y]): -# @overload -# def DEFINE_xxx( -# ... arguments... -# default: X, -# required: bool = ...) -> _flagvalues.FlagHolder[Y]: -# ... -# -# where X = type of non-None default values for the flag -# and Y = non-None type for flag value - -@overload -def DEFINE_string( - name: Text, - default: Optional[Text], - help: Optional[Text], - flag_values: _flagvalues.FlagValues = ..., - *, - required: Literal[True], - **args: Any) -> _flagvalues.FlagHolder[Text]: - ... - -@overload -def DEFINE_string( - name: Text, - default: None, - help: Optional[Text], - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Optional[Text]]: - ... - -@overload -def DEFINE_string( - name: Text, - default: Text, - help: Optional[Text], - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Text]: - ... - -@overload -def DEFINE_boolean( - name : Text, - default: Union[None, Text, bool, int], - help: Optional[Text], - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - *, - required: Literal[True], - **args: Any) -> _flagvalues.FlagHolder[bool]: - ... - -@overload -def DEFINE_boolean( - name : Text, - default: None, - help: Optional[Text], - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Optional[bool]]: - ... - -@overload -def DEFINE_boolean( - name : Text, - default: Union[Text, bool, int], - help: Optional[Text], - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[bool]: - ... - -@overload -def DEFINE_float( - name: Text, - default: Union[None, float, Text], - help: Optional[Text], - lower_bound: Optional[float] = ..., - upper_bound: Optional[float] = ..., - flag_values: _flagvalues.FlagValues = ..., - *, - required: Literal[True], - **args: Any) -> _flagvalues.FlagHolder[float]: - ... - -@overload -def DEFINE_float( - name: Text, - default: None, - help: Optional[Text], - lower_bound: Optional[float] = ..., - upper_bound: Optional[float] = ..., - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Optional[float]]: - ... - -@overload -def DEFINE_float( - name: Text, - default: Union[float, Text], - help: Optional[Text], - lower_bound: Optional[float] = ..., - upper_bound: Optional[float] = ..., - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[float]: - ... - - -@overload -def DEFINE_integer( - name: Text, - default: Union[None, int, Text], - help: Optional[Text], - lower_bound: Optional[int] = ..., - upper_bound: Optional[int] = ..., - flag_values: _flagvalues.FlagValues = ..., - *, - required: Literal[True], - **args: Any) -> _flagvalues.FlagHolder[int]: - ... - -@overload -def DEFINE_integer( - name: Text, - default: None, - help: Optional[Text], - lower_bound: Optional[int] = ..., - upper_bound: Optional[int] = ..., - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Optional[int]]: - ... - -@overload -def DEFINE_integer( - name: Text, - default: Union[int, Text], - help: Optional[Text], - lower_bound: Optional[int] = ..., - upper_bound: Optional[int] = ..., - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[int]: - ... - -@overload -def DEFINE_enum( - name : Text, - default: Optional[Text], - enum_values: Iterable[Text], - help: Optional[Text], - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - *, - required: Literal[True], - **args: Any) -> _flagvalues.FlagHolder[Text]: - ... - -@overload -def DEFINE_enum( - name : Text, - default: None, - enum_values: Iterable[Text], - help: Optional[Text], - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Optional[Text]]: - ... - -@overload -def DEFINE_enum( - name : Text, - default: Text, - enum_values: Iterable[Text], - help: Optional[Text], - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Text]: - ... - -@overload -def DEFINE_enum_class( - name: Text, - default: Union[None, _ET, Text], - enum_class: Type[_ET], - help: Optional[Text], - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - case_sensitive: bool = ..., - *, - required: Literal[True], - **args: Any) -> _flagvalues.FlagHolder[_ET]: - ... - -@overload -def DEFINE_enum_class( - name: Text, - default: None, - enum_class: Type[_ET], - help: Optional[Text], - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - case_sensitive: bool = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Optional[_ET]]: - ... - -@overload -def DEFINE_enum_class( - name: Text, - default: Union[_ET, Text], - enum_class: Type[_ET], - help: Optional[Text], - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - case_sensitive: bool = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[_ET]: - ... - - -@overload -def DEFINE_list( - name: Text, - default: Union[None, Iterable[Text], Text], - help: Text, - flag_values: _flagvalues.FlagValues = ..., - *, - required: Literal[True], - **args: Any) -> _flagvalues.FlagHolder[List[Text]]: - ... - -@overload -def DEFINE_list( - name: Text, - default: None, - help: Text, - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]: - ... - -@overload -def DEFINE_list( - name: Text, - default: Union[Iterable[Text], Text], - help: Text, - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[List[Text]]: - ... - -@overload -def DEFINE_spaceseplist( - name: Text, - default: Union[None, Iterable[Text], Text], - help: Text, - comma_compat: bool = ..., - flag_values: _flagvalues.FlagValues = ..., - *, - required: Literal[True], - **args: Any) -> _flagvalues.FlagHolder[List[Text]]: - ... - -@overload -def DEFINE_spaceseplist( - name: Text, - default: None, - help: Text, - comma_compat: bool = ..., - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]: - ... - -@overload -def DEFINE_spaceseplist( - name: Text, - default: Union[Iterable[Text], Text], - help: Text, - comma_compat: bool = ..., - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[List[Text]]: - ... - -@overload -def DEFINE_multi( - parser : _argument_parser.ArgumentParser[_T], - serializer: _argument_parser.ArgumentSerializer[_T], - name: Text, - default: Union[None, Iterable[_T], _T, Text], - help: Text, - flag_values:_flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - *, - required: Literal[True], - **args: Any) -> _flagvalues.FlagHolder[List[_T]]: - ... - -@overload -def DEFINE_multi( - parser : _argument_parser.ArgumentParser[_T], - serializer: _argument_parser.ArgumentSerializer[_T], - name: Text, - default: None, - help: Text, - flag_values:_flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Optional[List[_T]]]: - ... - -@overload -def DEFINE_multi( - parser : _argument_parser.ArgumentParser[_T], - serializer: _argument_parser.ArgumentSerializer[_T], - name: Text, - default: Union[Iterable[_T], _T, Text], - help: Text, - flag_values:_flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[List[_T]]: - ... - -@overload -def DEFINE_multi_string( - name: Text, - default: Union[None, Iterable[Text], Text], - help: Text, - flag_values: _flagvalues.FlagValues = ..., - *, - required: Literal[True], - **args: Any) -> _flagvalues.FlagHolder[List[Text]]: - ... - -@overload -def DEFINE_multi_string( - name: Text, - default: None, - help: Text, - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]: - ... - -@overload -def DEFINE_multi_string( - name: Text, - default: Union[Iterable[Text], Text], - help: Text, - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[List[Text]]: - ... - -@overload -def DEFINE_multi_integer( - name: Text, - default: Union[None, Iterable[int], int, Text], - help: Text, - lower_bound: Optional[int] = ..., - upper_bound: Optional[int] = ..., - flag_values: _flagvalues.FlagValues = ..., - *, - required: Literal[True], - **args: Any) -> _flagvalues.FlagHolder[List[int]]: - ... - -@overload -def DEFINE_multi_integer( - name: Text, - default: None, - help: Text, - lower_bound: Optional[int] = ..., - upper_bound: Optional[int] = ..., - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Optional[List[int]]]: - ... - -@overload -def DEFINE_multi_integer( - name: Text, - default: Union[Iterable[int], int, Text], - help: Text, - lower_bound: Optional[int] = ..., - upper_bound: Optional[int] = ..., - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[List[int]]: - ... - -@overload -def DEFINE_multi_float( - name: Text, - default: Union[None, Iterable[float], float, Text], - help: Text, - lower_bound: Optional[float] = ..., - upper_bound: Optional[float] = ..., - flag_values: _flagvalues.FlagValues = ..., - *, - required: Literal[True], - **args: Any) -> _flagvalues.FlagHolder[List[float]]: - ... - -@overload -def DEFINE_multi_float( - name: Text, - default: None, - help: Text, - lower_bound: Optional[float] = ..., - upper_bound: Optional[float] = ..., - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Optional[List[float]]]: - ... - -@overload -def DEFINE_multi_float( - name: Text, - default: Union[Iterable[float], float, Text], - help: Text, - lower_bound: Optional[float] = ..., - upper_bound: Optional[float] = ..., - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[List[float]]: - ... - - -@overload -def DEFINE_multi_enum( - name: Text, - default: Union[None, Iterable[Text], Text], - enum_values: Iterable[Text], - help: Text, - flag_values: _flagvalues.FlagValues = ..., - *, - required: Literal[True], - **args: Any) -> _flagvalues.FlagHolder[List[Text]]: - ... - -@overload -def DEFINE_multi_enum( - name: Text, - default: None, - enum_values: Iterable[Text], - help: Text, - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Optional[List[Text]]]: - ... - -@overload -def DEFINE_multi_enum( - name: Text, - default: Union[Iterable[Text], Text], - enum_values: Iterable[Text], - help: Text, - flag_values: _flagvalues.FlagValues = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[List[Text]]: - ... - -@overload -def DEFINE_multi_enum_class( - name: Text, - # This is separate from `Union[None, _ET, Text]` to avoid a Pytype issue - # inferring the return value to FlagHolder[List[Union[_ET, enum.Enum]]] - # when an iterable of concrete enum subclasses are used. - default: Iterable[_ET], - enum_class: Type[_ET], - help: Text, - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - *, - required: Literal[True], - **args: Any) -> _flagvalues.FlagHolder[List[_ET]]: - ... - -@overload -def DEFINE_multi_enum_class( - name: Text, - default: Union[None, _ET, Text], - enum_class: Type[_ET], - help: Text, - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - *, - required: Literal[True], - **args: Any) -> _flagvalues.FlagHolder[List[_ET]]: - ... - -@overload -def DEFINE_multi_enum_class( - name: Text, - default: None, - enum_class: Type[_ET], - help: Text, - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[Optional[List[_ET]]]: - ... - -@overload -def DEFINE_multi_enum_class( - name: Text, - # This is separate from `Union[None, _ET, Text]` to avoid a Pytype issue - # inferring the return value to FlagHolder[List[Union[_ET, enum.Enum]]] - # when an iterable of concrete enum subclasses are used. - default: Iterable[_ET], - enum_class: Type[_ET], - help: Text, - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[List[_ET]]: - ... - -@overload -def DEFINE_multi_enum_class( - name: Text, - default: Union[_ET, Text], - enum_class: Type[_ET], - help: Text, - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ..., - required: bool = ..., - **args: Any) -> _flagvalues.FlagHolder[List[_ET]]: - ... - - -def DEFINE_alias( - name: Text, - original_name: Text, - flag_values: _flagvalues.FlagValues = ..., - module_name: Optional[Text] = ...) -> _flagvalues.FlagHolder[Any]: - ... - - -def set_default(flag_holder: _flagvalues.FlagHolder[_T], value: _T) -> None: - ... - - -def declare_key_flag(flag_name: Union[Text, _flagvalues.FlagHolder], - flag_values: _flagvalues.FlagValues = ...) -> None: - ... - - - -def adopt_module_key_flags(module: Any, - flag_values: _flagvalues.FlagValues = ...) -> None: - ... - - - -def disclaim_key_flags() -> None: - ... diff --git a/absl/flags/_flag.py b/absl/flags/_flag.py index 124f137..6711788 100644 --- a/absl/flags/_flag.py +++ b/absl/flags/_flag.py @@ -20,15 +20,21 @@ aliases defined at the package level instead. from collections import abc import copy +import enum import functools +from typing import Any, Dict, Generic, Iterable, List, Optional, Text, Type, TypeVar, Union +from xml.dom import minidom from absl.flags import _argument_parser from absl.flags import _exceptions from absl.flags import _helpers +_T = TypeVar('_T') +_ET = TypeVar('_ET', bound=enum.Enum) + @functools.total_ordering -class Flag(object): +class Flag(Generic[_T]): """Information about a command-line flag. Attributes: @@ -76,10 +82,26 @@ class Flag(object): string, so it is important that it be a legal value for this flag. """ - def __init__(self, parser, serializer, name, default, help_string, - short_name=None, boolean=False, allow_override=False, - allow_override_cpp=False, allow_hide_cpp=False, - allow_overwrite=True, allow_using_method_names=False): + # NOTE: pytype doesn't find defaults without this. + default: Optional[_T] + default_as_str: Optional[Text] + default_unparsed: Union[Optional[_T], Text] + + def __init__( + self, + parser: _argument_parser.ArgumentParser[_T], + serializer: Optional[_argument_parser.ArgumentSerializer[_T]], + name: Text, + default: Union[Optional[_T], Text], + help_string: Optional[Text], + short_name: Optional[Text] = None, + boolean: bool = False, + allow_override: bool = False, + allow_override_cpp: bool = False, + allow_hide_cpp: bool = False, + allow_overwrite: bool = True, + allow_using_method_names: bool = False, + ) -> None: self.name = name if not help_string: @@ -108,11 +130,11 @@ class Flag(object): self._set_default(default) @property - def value(self): + def value(self) -> Optional[_T]: return self._value @value.setter - def value(self, value): + def value(self, value: Optional[_T]): self._value = value def __hash__(self): @@ -137,12 +159,12 @@ class Flag(object): raise TypeError('%s does not support shallow copies. ' 'Use copy.deepcopy instead.' % type(self).__name__) - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: Dict[int, Any]) -> 'Flag[_T]': result = object.__new__(type(self)) result.__dict__ = copy.deepcopy(self.__dict__, memo) return result - def _get_parsed_value_as_string(self, value): + def _get_parsed_value_as_string(self, value: Optional[_T]) -> Optional[Text]: """Returns parsed flag value as string.""" if value is None: return None @@ -155,7 +177,7 @@ class Flag(object): return repr('false') return repr(str(value)) - def parse(self, argument): + def parse(self, argument: Union[Text, Optional[_T]]) -> None: """Parses string and sets flag value. Args: @@ -168,7 +190,7 @@ class Flag(object): self.value = self._parse(argument) self.present += 1 - def _parse(self, argument): + def _parse(self, argument: Union[Text, _T]) -> Optional[_T]: """Internal parse function. It returns the parsed value, and does not modify class states. @@ -185,16 +207,16 @@ class Flag(object): raise _exceptions.IllegalFlagValueError( 'flag --%s=%s: %s' % (self.name, argument, e)) - def unparse(self): + def unparse(self) -> None: self.value = self.default self.using_default_value = True self.present = 0 - def serialize(self): + def serialize(self) -> Text: """Serializes the flag.""" return self._serialize(self.value) - def _serialize(self, value): + def _serialize(self, value: Optional[_T]) -> Text: """Internal serialize function.""" if value is None: return '' @@ -209,7 +231,7 @@ class Flag(object): 'Serializer not present for flag %s' % self.name) return '--%s=%s' % (self.name, self.serializer.serialize(value)) - def _set_default(self, value): + def _set_default(self, value: Union[Optional[_T], Text]) -> None: """Changes the default value (and current value too) for this Flag.""" self.default_unparsed = value if value is None: @@ -222,10 +244,10 @@ class Flag(object): # This is split out so that aliases can skip regular parsing of the default # value. - def _parse_from_default(self, value): + def _parse_from_default(self, value: Union[Text, _T]) -> Optional[_T]: return self._parse(value) - def flag_type(self): + def flag_type(self) -> Text: """Returns a str that describes the type of the flag. NOTE: we use strings, and not the types.*Type constants because @@ -234,7 +256,9 @@ class Flag(object): """ return self.parser.flag_type() - def _create_xml_dom_element(self, doc, module_name, is_key=False): + def _create_xml_dom_element( + self, doc: minidom.Document, module_name: str, is_key: bool = False + ) -> minidom.Element: """Returns an XML element that contains this flag's information. This is information that is relevant to all flags (e.g., name, @@ -286,11 +310,13 @@ class Flag(object): element.appendChild(e) return element - def _serialize_value_for_xml(self, value): + def _serialize_value_for_xml(self, value: Optional[_T]) -> Any: """Returns the serialized value, for use in an XML help text.""" return value - def _extra_xml_dom_elements(self, doc): + def _extra_xml_dom_elements( + self, doc: minidom.Document + ) -> List[minidom.Element]: """Returns extra info about this flag in XML. "Extra" means "not already included by _create_xml_dom_element above." @@ -306,7 +332,7 @@ class Flag(object): return self.parser._custom_xml_dom_elements(doc) # pylint: disable=protected-access -class BooleanFlag(Flag): +class BooleanFlag(Flag[bool]): """Basic boolean flag. Boolean flags do not take any arguments, and their value is either @@ -319,24 +345,45 @@ class BooleanFlag(Flag): explicitly unset through either ``--noupdate`` or ``--nox``. """ - def __init__(self, name, default, help, short_name=None, **args): # pylint: disable=redefined-builtin + def __init__( + self, + name: Text, + default: Union[Optional[bool], Text], + help: Optional[Text], # pylint: disable=redefined-builtin + short_name: Optional[Text] = None, + **args + ) -> None: p = _argument_parser.BooleanParser() super(BooleanFlag, self).__init__( - p, None, name, default, help, short_name, 1, **args) + p, None, name, default, help, short_name, True, **args + ) -class EnumFlag(Flag): +class EnumFlag(Flag[Text]): """Basic enum flag; its value can be any string from list of enum_values.""" - def __init__(self, name, default, help, enum_values, # pylint: disable=redefined-builtin - short_name=None, case_sensitive=True, **args): + def __init__( + self, + name: Text, + default: Optional[Text], + help: Optional[Text], # pylint: disable=redefined-builtin + enum_values: Iterable[Text], + short_name: Optional[Text] = None, + case_sensitive: bool = True, + **args + ): p = _argument_parser.EnumParser(enum_values, case_sensitive) g = _argument_parser.ArgumentSerializer() super(EnumFlag, self).__init__( p, g, name, default, help, short_name, **args) - self.help = '<%s>: %s' % ('|'.join(enum_values), self.help) - - def _extra_xml_dom_elements(self, doc): + # NOTE: parser should be typed EnumParser but the constructor + # restricts the available interface to ArgumentParser[str]. + self.parser = p + self.help = '<%s>: %s' % ('|'.join(p.enum_values), self.help) + + def _extra_xml_dom_elements( + self, doc: minidom.Document + ) -> List[minidom.Element]: elements = [] for enum_value in self.parser.enum_values: elements.append(_helpers.create_xml_dom_element( @@ -344,26 +391,32 @@ class EnumFlag(Flag): return elements -class EnumClassFlag(Flag): +class EnumClassFlag(Flag[_ET]): """Basic enum flag; its value is an enum class's member.""" def __init__( self, - name, - default, - help, # pylint: disable=redefined-builtin - enum_class, - short_name=None, - case_sensitive=False, - **args): + name: Text, + default: Union[Optional[_ET], Text], + help: Optional[Text], # pylint: disable=redefined-builtin + enum_class: Type[_ET], + short_name: Optional[Text] = None, + case_sensitive: bool = False, + **args + ): p = _argument_parser.EnumClassParser( enum_class, case_sensitive=case_sensitive) g = _argument_parser.EnumClassSerializer(lowercase=not case_sensitive) super(EnumClassFlag, self).__init__( p, g, name, default, help, short_name, **args) + # NOTE: parser should be typed EnumClassParser[_ET] but the constructor + # restricts the available interface to ArgumentParser[_ET]. + self.parser = p self.help = '<%s>: %s' % ('|'.join(p.member_names), self.help) - def _extra_xml_dom_elements(self, doc): + def _extra_xml_dom_elements( + self, doc: minidom.Document + ) -> List[minidom.Element]: elements = [] for enum_value in self.parser.enum_class.__members__.keys(): elements.append(_helpers.create_xml_dom_element( @@ -371,7 +424,7 @@ class EnumClassFlag(Flag): return elements -class MultiFlag(Flag): +class MultiFlag(Generic[_T], Flag[List[_T]]): """A flag that can appear multiple time on the command-line. The value of such a flag is a list that contains the individual values @@ -392,7 +445,7 @@ class MultiFlag(Flag): super(MultiFlag, self).__init__(*args, **kwargs) self.help += ';\n repeat this option to specify a list of values' - def parse(self, arguments): + def parse(self, arguments: Union[Text, _T, Iterable[_T]]): # pylint: disable=arguments-renamed """Parses one or more arguments with the installed parser. Args: @@ -407,7 +460,7 @@ class MultiFlag(Flag): self.value = new_values self.present += len(new_values) - def _parse(self, arguments): + def _parse(self, arguments: Union[Text, Optional[Iterable[_T]]]) -> List[_T]: # pylint: disable=arguments-renamed if (isinstance(arguments, abc.Iterable) and not isinstance(arguments, str)): arguments = list(arguments) @@ -420,7 +473,7 @@ class MultiFlag(Flag): return [super(MultiFlag, self)._parse(item) for item in arguments] - def _serialize(self, value): + def _serialize(self, value: Optional[List[_T]]) -> Text: """See base class.""" if not self.serializer: raise _exceptions.Error( @@ -438,16 +491,18 @@ class MultiFlag(Flag): """See base class.""" return 'multi ' + self.parser.flag_type() - def _extra_xml_dom_elements(self, doc): + def _extra_xml_dom_elements( + self, doc: minidom.Document + ) -> List[minidom.Element]: elements = [] if hasattr(self.parser, 'enum_values'): - for enum_value in self.parser.enum_values: + for enum_value in self.parser.enum_values: # pytype: disable=attribute-error elements.append(_helpers.create_xml_dom_element( doc, 'enum_value', enum_value)) return elements -class MultiEnumClassFlag(MultiFlag): +class MultiEnumClassFlag(MultiFlag[_ET]): # pytype: disable=not-indexable """A multi_enum_class flag. See the __doc__ for MultiFlag for most behaviors of this class. In addition, @@ -455,26 +510,35 @@ class MultiEnumClassFlag(MultiFlag): type. """ - def __init__(self, - name, - default, - help_string, - enum_class, - case_sensitive=False, - **args): + def __init__( + self, + name: str, + default: Union[None, Iterable[_ET], _ET, Iterable[Text], Text], + help_string: str, + enum_class: Type[_ET], + case_sensitive: bool = False, + **args + ): p = _argument_parser.EnumClassParser( enum_class, case_sensitive=case_sensitive) g = _argument_parser.EnumClassListSerializer( list_sep=',', lowercase=not case_sensitive) super(MultiEnumClassFlag, self).__init__( p, g, name, default, help_string, **args) + # NOTE: parser should be typed EnumClassParser[_ET] but the constructor + # restricts the available interface to ArgumentParser[str]. + self.parser = p + # NOTE: serializer should be non-Optional but this isn't inferred. + self.serializer = g self.help = ( '<%s>: %s;\n repeat this option to specify a list of values' % ('|'.join(p.member_names), help_string or '(no help available)')) - def _extra_xml_dom_elements(self, doc): + def _extra_xml_dom_elements( + self, doc: minidom.Document + ) -> List[minidom.Element]: elements = [] - for enum_value in self.parser.enum_class.__members__.keys(): + for enum_value in self.parser.enum_class.__members__.keys(): # pytype: disable=attribute-error elements.append(_helpers.create_xml_dom_element( doc, 'enum_value', enum_value)) return elements @@ -482,6 +546,10 @@ class MultiEnumClassFlag(MultiFlag): def _serialize_value_for_xml(self, value): """See base class.""" if value is not None: + if not self.serializer: + raise _exceptions.Error( + 'Serializer not present for flag %s' % self.name + ) value_serialized = self.serializer.serialize(value) else: value_serialized = '' diff --git a/absl/flags/_flag.pyi b/absl/flags/_flag.pyi deleted file mode 100644 index 8f840be..0000000 --- a/absl/flags/_flag.pyi +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright 2020 The Abseil Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Contains type annotations for Flag class.""" - -import copy -import functools - -from absl.flags import _argument_parser -from absl.flags import _validators_classes -import enum - -from typing import Callable, Text, TypeVar, Generic, Iterable, Type, List, Optional, Any, Union, Sequence - -_T = TypeVar('_T') -_ET = TypeVar('_ET', bound=enum.Enum) - - -class Flag(Generic[_T]): - - name: Text - default: Optional[_T] - default_unparsed: Union[Optional[_T], Text] - default_as_str: Optional[Text] - help: Text - short_name: Text - boolean: bool - present: bool - parser: _argument_parser.ArgumentParser[_T] - serializer: _argument_parser.ArgumentSerializer[_T] - allow_override: bool - allow_override_cpp: bool - allow_hide_cpp: bool - using_default_value: bool - allow_overwrite: bool - allow_using_method_names: bool - validators: List[_validators_classes.Validator] - - def __init__(self, - parser: _argument_parser.ArgumentParser[_T], - serializer: Optional[_argument_parser.ArgumentSerializer[_T]], - name: Text, - default: Union[Optional[_T], Text], - help_string: Optional[Text], - short_name: Optional[Text] = ..., - boolean: bool = ..., - allow_override: bool = ..., - allow_override_cpp: bool = ..., - allow_hide_cpp: bool = ..., - allow_overwrite: bool = ..., - allow_using_method_names: bool = ...) -> None: - ... - - - @property - def value(self) -> Optional[_T]: ... - - def parse(self, argument: Union[_T, Text, None]) -> None: ... - - def unparse(self) -> None: ... - - def _parse(self, argument: Any) -> Any: ... - - def _serialize(self, value: Any) -> Text: ... - - def __deepcopy__(self, memo: dict) -> Flag: ... - - def _get_parsed_value_as_string(self, value: Optional[_T]) -> Optional[Text]: - ... - - def serialize(self) -> Text: ... - - def flag_type(self) -> Text: ... - - -class BooleanFlag(Flag[bool]): - def __init__(self, - name: Text, - default: Union[Optional[bool], Text], - help: Optional[Text], - short_name: Optional[Text]=None, - **args: Any) -> None: - ... - - - -class EnumFlag(Flag[Text]): - def __init__(self, - name: Text, - default: Union[Optional[Text], Text], - help: Optional[Text], - enum_values: Sequence[Text], - short_name: Optional[Text] = ..., - case_sensitive: bool = ..., - **args: Any): - ... - - - -class EnumClassFlag(Flag[_ET]): - - def __init__(self, - name: Text, - default: Union[Optional[_ET], Text], - help: Optional[Text], - enum_class: Type[_ET], - short_name: Optional[Text]=None, - **args: Any): - ... - - - -class MultiFlag(Flag[List[_T]]): - ... - - -class MultiEnumClassFlag(MultiFlag[_ET]): - def __init__(self, - name: Text, - default: Union[Optional[List[_ET]], Text], - help_string: Optional[Text], - enum_class: Type[_ET], - **args: Any): - ... - - diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py index fd0e631..56921ce 100644 --- a/absl/flags/_flagvalues.py +++ b/absl/flags/_flagvalues.py @@ -22,13 +22,14 @@ import itertools import logging import os import sys -from typing import Generic, TypeVar +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Text, TextIO, Generic, TypeVar, Union, Tuple from xml.dom import minidom from absl.flags import _exceptions from absl.flags import _flag from absl.flags import _helpers from absl.flags import _validators_classes +from absl.flags._flag import Flag # Add flagvalues module to disclaimed module ids. _helpers.disclaim_module_ids.add(id(sys.modules[__name__])) @@ -74,12 +75,16 @@ class FlagValues: help for all of the registered :class:`~absl.flags.Flag` objects. """ + _HAS_DYNAMIC_ATTRIBUTES = True + # A note on collections.abc.Mapping: # FlagValues defines __getitem__, __iter__, and __len__. It makes perfect # sense to let it be a collections.abc.Mapping class. However, we are not # able to do so. The mixin methods, e.g. keys, values, are not uncommon flag # names. Those flag values would not be accessible via the FLAGS.xxx form. + __dict__: Dict[str, Any] + def __init__(self): # Since everything in this class is so heavily overloaded, the only # way of defining and using fields is to access __dict__ directly. @@ -126,7 +131,7 @@ class FlagValues: # (is_retired, type_is_bool). self.__dict__['__is_retired_flag_func'] = None - def set_gnu_getopt(self, gnu_getopt=True): + def set_gnu_getopt(self, gnu_getopt: bool = True) -> None: """Sets whether or not to use GNU style scanning. GNU style allows mixing of flag and non-flag arguments. See @@ -138,13 +143,13 @@ class FlagValues: self.__dict__['__use_gnu_getopt'] = gnu_getopt self.__dict__['__use_gnu_getopt_explicitly_set'] = True - def is_gnu_getopt(self): + def is_gnu_getopt(self) -> bool: return self.__dict__['__use_gnu_getopt'] - def _flags(self): + def _flags(self) -> Dict[Text, Flag]: return self.__dict__['__flags'] - def flags_by_module_dict(self): + def flags_by_module_dict(self) -> Dict[Text, List[Flag]]: """Returns the dictionary of module_name -> list of defined flags. Returns: @@ -153,7 +158,7 @@ class FlagValues: """ return self.__dict__['__flags_by_module'] - def flags_by_module_id_dict(self): + def flags_by_module_id_dict(self) -> Dict[int, List[Flag]]: """Returns the dictionary of module_id -> list of defined flags. Returns: @@ -162,7 +167,7 @@ class FlagValues: """ return self.__dict__['__flags_by_module_id'] - def key_flags_by_module_dict(self): + def key_flags_by_module_dict(self) -> Dict[Text, List[Flag]]: """Returns the dictionary of module_name -> list of key flags. Returns: @@ -171,7 +176,7 @@ class FlagValues: """ return self.__dict__['__key_flags_by_module'] - def register_flag_by_module(self, module_name, flag): + def register_flag_by_module(self, module_name: Text, flag: Flag) -> None: """Records the module that defines a specific flag. We keep track of which flag is defined by which module so that we @@ -184,7 +189,7 @@ class FlagValues: flags_by_module = self.flags_by_module_dict() flags_by_module.setdefault(module_name, []).append(flag) - def register_flag_by_module_id(self, module_id, flag): + def register_flag_by_module_id(self, module_id: int, flag: Flag) -> None: """Records the module that defines a specific flag. Args: @@ -194,7 +199,7 @@ class FlagValues: flags_by_module_id = self.flags_by_module_id_dict() flags_by_module_id.setdefault(module_id, []).append(flag) - def register_key_flag_for_module(self, module_name, flag): + def register_key_flag_for_module(self, module_name: Text, flag: Flag) -> None: """Specifies that a flag is a key flag for a module. Args: @@ -208,7 +213,7 @@ class FlagValues: if flag not in key_flags: key_flags.append(flag) - def _flag_is_registered(self, flag_obj): + def _flag_is_registered(self, flag_obj: Flag) -> bool: """Checks whether a Flag object is registered under long name or short name. Args: @@ -228,7 +233,9 @@ class FlagValues: return True return False - def _cleanup_unregistered_flag_from_module_dicts(self, flag_obj): + def _cleanup_unregistered_flag_from_module_dicts( + self, flag_obj: Flag + ) -> None: """Cleans up unregistered flags from all module -> [flags] dictionaries. If flag_obj is registered under either its long name or short name, it @@ -248,7 +255,7 @@ class FlagValues: while flag_obj in flags_in_module: flags_in_module.remove(flag_obj) - def get_flags_for_module(self, module): + def get_flags_for_module(self, module: Union[Text, Any]) -> List[Flag]: """Returns the list of flags defined by a module. Args: @@ -266,7 +273,7 @@ class FlagValues: return list(self.flags_by_module_dict().get(module, [])) - def get_key_flags_for_module(self, module): + def get_key_flags_for_module(self, module: Union[Text, Any]) -> List[Flag]: """Returns the list of key flags for a module. Args: @@ -293,7 +300,10 @@ class FlagValues: key_flags.append(flag) return key_flags - def find_module_defining_flag(self, flagname, default=None): + # TODO(yileiyang): Restrict default to Optional[Text]. + def find_module_defining_flag( + self, flagname: Text, default: Optional[_T] = None + ) -> Union[str, Optional[_T]]: """Return the name of the module defining this flag, or default. Args: @@ -318,7 +328,10 @@ class FlagValues: return module return default - def find_module_id_defining_flag(self, flagname, default=None): + # TODO(yileiyang): Restrict default to Optional[Text]. + def find_module_id_defining_flag( + self, flagname: Text, default: Optional[_T] = None + ) -> Union[int, Optional[_T]]: """Return the ID of the module defining this flag, or default. Args: @@ -343,7 +356,9 @@ class FlagValues: return module_id return default - def _register_unknown_flag_setter(self, setter): + def _register_unknown_flag_setter( + self, setter: Callable[[str, Any], None] + ) -> None: """Allow set default values for undefined flags. Args: @@ -352,7 +367,7 @@ class FlagValues: """ self.__dict__['__set_unknown'] = setter - def _set_unknown_flag(self, name, value): + def _set_unknown_flag(self, name: str, value: _T) -> _T: """Returns value if setting flag |name| to |value| returned True. Args: @@ -378,7 +393,7 @@ class FlagValues: pass raise _exceptions.UnrecognizedFlagError(name, value) - def append_flag_values(self, flag_values): + def append_flag_values(self, flag_values: 'FlagValues') -> None: """Appends flags registered in another FlagValues instance. Args: @@ -397,7 +412,9 @@ class FlagValues: raise _exceptions.DuplicateFlagError.from_flag( flag_name, self, other_flag_values=flag_values) - def remove_flag_values(self, flag_values): + def remove_flag_values( + self, flag_values: 'Union[FlagValues, Iterable[Text]]' + ) -> None: """Remove flags that were previously appended from another FlagValues. Args: @@ -407,7 +424,7 @@ class FlagValues: for flag_name in flag_values: self.__delattr__(flag_name) - def __setitem__(self, name, flag): + def __setitem__(self, name: Text, flag: Flag) -> None: """Registers a new flag variable.""" fl = self._flags() if not isinstance(flag, _flag.Flag): @@ -430,10 +447,10 @@ class FlagValues: # module is simply being imported a subsequent time. return raise _exceptions.DuplicateFlagError.from_flag(name, self) - short_name = flag.short_name # If a new flag overrides an old one, we need to cleanup the old flag's # modules if it's not registered. flags_to_cleanup = set() + short_name: str = flag.short_name # pytype: disable=annotation-type-mismatch if short_name is not None: if (short_name in fl and not flag.allow_override and not fl[short_name].allow_override): @@ -449,7 +466,7 @@ class FlagValues: for f in flags_to_cleanup: self._cleanup_unregistered_flag_from_module_dicts(f) - def __dir__(self): + def __dir__(self) -> List[Text]: """Returns list of names of all defined flags. Useful for TAB-completion in ipython. @@ -459,7 +476,7 @@ class FlagValues: """ return sorted(self.__dict__['__flags']) - def __getitem__(self, name): + def __getitem__(self, name: Text) -> Flag: """Returns the Flag object for the flag --name.""" return self._flags()[name] @@ -467,7 +484,7 @@ class FlagValues: """Marks the flag --name as hidden.""" self.__dict__['__hiddenflags'].add(name) - def __getattr__(self, name): + def __getattr__(self, name: Text) -> Any: """Retrieves the 'value' attribute of the flag --name.""" fl = self._flags() if name not in fl: @@ -481,12 +498,12 @@ class FlagValues: raise _exceptions.UnparsedFlagAccessError( 'Trying to access flag --%s before flags were parsed.' % name) - def __setattr__(self, name, value): + def __setattr__(self, name: Text, value: _T) -> _T: """Sets the 'value' attribute of the flag --name.""" self._set_attributes(**{name: value}) return value - def _set_attributes(self, **attributes): + def _set_attributes(self, **attributes: Any) -> None: """Sets multiple flag values together, triggers validators afterwards.""" fl = self._flags() known_flags = set() @@ -502,7 +519,7 @@ class FlagValues: self._assert_validators(fl[name].validators) fl[name].using_default_value = False - def validate_all_flags(self): + def validate_all_flags(self) -> None: """Verifies whether all flags pass validation. Raises: @@ -515,7 +532,9 @@ class FlagValues: all_validators.update(flag.validators) self._assert_validators(all_validators) - def _assert_validators(self, validators): + def _assert_validators( + self, validators: Iterable[_validators_classes.Validator] + ) -> None: """Asserts if all validators in the list are satisfied. It asserts validators in the order they were created. @@ -550,7 +569,7 @@ class FlagValues: if messages: raise _exceptions.IllegalFlagValueError('\n'.join(messages)) - def __delattr__(self, flag_name): + def __delattr__(self, flag_name: Text) -> None: """Deletes a previously-defined flag from a flag object. This method makes sure we can delete a flag by using @@ -580,7 +599,7 @@ class FlagValues: self._cleanup_unregistered_flag_from_module_dicts(flag_obj) - def set_default(self, name, value): + def set_default(self, name: Text, value: Any) -> None: """Changes the default value of the named flag object. The flag's current value is also updated if the flag is currently using @@ -602,17 +621,19 @@ class FlagValues: fl[name]._set_default(value) # pylint: disable=protected-access self._assert_validators(fl[name].validators) - def __contains__(self, name): + def __contains__(self, name: Text) -> bool: """Returns True if name is a value (flag) in the dict.""" return name in self._flags() - def __len__(self): + def __len__(self) -> int: return len(self.__dict__['__flags']) - def __iter__(self): + def __iter__(self) -> Iterator[Text]: return iter(self._flags()) - def __call__(self, argv, known_only=False): + def __call__( + self, argv: Sequence[Text], known_only: bool = False + ) -> List[Text]: """Parses flags from argv; stores parsed flags into this FlagValues object. All unparsed arguments are returned. @@ -656,14 +677,14 @@ class FlagValues: self.validate_all_flags() return [program_name] + unparsed_args - def __getstate__(self): + def __getstate__(self) -> Any: raise TypeError("can't pickle FlagValues") - def __copy__(self): + def __copy__(self) -> Any: raise TypeError('FlagValues does not support shallow copies. ' 'Use absl.testing.flagsaver or copy.deepcopy instead.') - def __deepcopy__(self, memo): + def __deepcopy__(self, memo) -> Any: result = object.__new__(type(self)) result.__dict__.update(copy.deepcopy(self.__dict__, memo)) return result @@ -680,7 +701,9 @@ class FlagValues: """ self.__dict__['__is_retired_flag_func'] = is_retired_flag_func - def _parse_args(self, args, known_only): + def _parse_args( + self, args: List[str], known_only: bool + ) -> Tuple[List[Tuple[Optional[str], Any]], List[str]]: """Helper function to do the main argument parsing. This function goes through args and does the bulk of the flag parsing. @@ -818,11 +841,11 @@ class FlagValues: unparsed_args.extend(list(args)) return unknown_flags, unparsed_args - def is_parsed(self): + def is_parsed(self) -> bool: """Returns whether flags were parsed.""" return self.__dict__['__flags_parsed'] - def mark_as_parsed(self): + def mark_as_parsed(self) -> None: """Explicitly marks flags as parsed. Use this when the caller knows that this FlagValues has been parsed as if @@ -831,7 +854,7 @@ class FlagValues: """ self.__dict__['__flags_parsed'] = True - def unparse_flags(self): + def unparse_flags(self) -> None: """Unparses all flags to the point before any FLAGS(argv) was called.""" for f in self._flags().values(): f.unparse() @@ -841,7 +864,7 @@ class FlagValues: self.__dict__['__flags_parsed'] = False self.__dict__['__unparse_flags_called'] = True - def flag_values_dict(self): + def flag_values_dict(self) -> Dict[Text, Any]: """Returns a dictionary that maps flag names to flag values.""" return {name: flag.value for name, flag in self._flags().items()} @@ -849,7 +872,9 @@ class FlagValues: """Returns a help string for all known flags.""" return self.get_help() - def get_help(self, prefix='', include_special_flags=True): + def get_help( + self, prefix: Text = '', include_special_flags: bool = True + ) -> Text: """Returns a help string for all known flags. Args: @@ -875,7 +900,8 @@ class FlagValues: values = self._flags().values() if include_special_flags: values = itertools.chain( - values, _helpers.SPECIAL_FLAGS._flags().values()) # pylint: disable=protected-access + values, _helpers.SPECIAL_FLAGS._flags().values() # pylint: disable=protected-access # pytype: disable=attribute-error + ) self._render_flag_list(values, output_lines, prefix) return '\n'.join(output_lines) @@ -896,9 +922,10 @@ class FlagValues: if include_special_flags: self._render_module_flags( 'absl.flags', - _helpers.SPECIAL_FLAGS._flags().values(), # pylint: disable=protected-access + _helpers.SPECIAL_FLAGS._flags().values(), # pylint: disable=protected-access # pytype: disable=attribute-error output_lines, - prefix) + prefix, + ) return '\n'.join(output_lines) def _render_module_flags(self, module, flags, output_lines, prefix=''): @@ -927,7 +954,7 @@ class FlagValues: if key_flags: self._render_module_flags(module, key_flags, output_lines, prefix) - def module_help(self, module): + def module_help(self, module: Any) -> Text: """Describes the key flags of a module. Args: @@ -940,7 +967,7 @@ class FlagValues: self._render_our_module_key_flags(module, helplist) return '\n'.join(helplist) - def main_module_help(self): + def main_module_help(self) -> Text: """Describes the key flags of the main module. Returns: @@ -950,7 +977,7 @@ class FlagValues: def _render_flag_list(self, flaglist, output_lines, prefix=' '): fl = self._flags() - special_fl = _helpers.SPECIAL_FLAGS._flags() # pylint: disable=protected-access + special_fl = _helpers.SPECIAL_FLAGS._flags() # pylint: disable=protected-access # pytype: disable=attribute-error flaglist = [(flag.name, flag) for flag in flaglist] flaglist.sort() flagset = {} @@ -987,7 +1014,7 @@ class FlagValues: '(%s)' % flag.parser.syntactic_help, indent=prefix + ' ') output_lines.append(flaghelp) - def get_flag_value(self, name, default): # pylint: disable=invalid-name + def get_flag_value(self, name: Text, default: Any) -> Any: # pylint: disable=invalid-name """Returns the value of a flag (if not None) or a default value. Args: @@ -1109,7 +1136,9 @@ class FlagValues: parsed_file_stack.pop() return flag_line_list - def read_flags_from_files(self, argv, force_gnu=True): + def read_flags_from_files( + self, argv: Sequence[Text], force_gnu: bool = True + ) -> List[Text]: """Processes command line args, but also allow args to be read from file. Args: @@ -1192,7 +1221,7 @@ class FlagValues: return new_argv - def flags_into_string(self): + def flags_into_string(self) -> Text: """Returns a string with the flags assignments from this FlagValues object. This function ignores flags whose value is None. Each flag @@ -1214,7 +1243,7 @@ class FlagValues: s += flag.serialize() + '\n' return s - def append_flags_into_file(self, filename): + def append_flags_into_file(self, filename: Text) -> None: """Appends all flags assignments from this FlagInfo object to a file. Output will be in the format of a flagfile. @@ -1228,7 +1257,7 @@ class FlagValues: with open(filename, 'a') as out_file: out_file.write(self.flags_into_string()) - def write_help_in_xml_format(self, outfile=None): + def write_help_in_xml_format(self, outfile: Optional[TextIO] = None) -> None: """Outputs flag documentation in XML format. NOTE: We use element names that are consistent with those used by @@ -1280,7 +1309,7 @@ class FlagValues: doc.toprettyxml(indent=' ', encoding='utf-8').decode('utf-8')) outfile.flush() - def _check_method_name_conflicts(self, name, flag): + def _check_method_name_conflicts(self, name: str, flag: Flag): if flag.allow_using_method_names: return short_name = flag.short_name @@ -1325,7 +1354,14 @@ class FlagHolder(Generic[_T]): since the name of the flag appears only once in the source code. """ - def __init__(self, flag_values, flag, ensure_non_none_value=False): + value: _T + + def __init__( + self, + flag_values: FlagValues, + flag: Flag[_T], + ensure_non_none_value: bool = False, + ): """Constructs a FlagHolder instance providing typesafe access to flag. Args: @@ -1359,11 +1395,11 @@ class FlagHolder(Generic[_T]): __nonzero__ = __bool__ @property - def name(self): + def name(self) -> Text: return self._name @property - def value(self): + def value(self) -> _T: """Returns the value of the flag. If ``_ensure_non_none_value`` is ``True``, then return value is not @@ -1380,17 +1416,19 @@ class FlagHolder(Generic[_T]): return val @property - def default(self): + def default(self) -> _T: """Returns the default value of the flag.""" return self._flagvalues[self._name].default @property - def present(self): + def present(self) -> bool: """Returns True if the flag was parsed from command-line flags.""" return bool(self._flagvalues[self._name].present) -def resolve_flag_ref(flag_ref, flag_values): +def resolve_flag_ref( + flag_ref: Union[str, FlagHolder], flag_values: FlagValues +) -> Tuple[str, FlagValues]: """Helper to validate and resolve a flag reference argument.""" if isinstance(flag_ref, FlagHolder): new_flag_values = flag_ref._flagvalues # pylint: disable=protected-access @@ -1401,7 +1439,9 @@ def resolve_flag_ref(flag_ref, flag_values): return flag_ref, flag_values -def resolve_flag_refs(flag_refs, flag_values): +def resolve_flag_refs( + flag_refs: Sequence[Union[str, FlagHolder]], flag_values: FlagValues +) -> Tuple[List[str], FlagValues]: """Helper to validate and resolve flag reference list arguments.""" fv = None names = [] diff --git a/absl/flags/_flagvalues.pyi b/absl/flags/_flagvalues.pyi deleted file mode 100644 index d8e3935..0000000 --- a/absl/flags/_flagvalues.pyi +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2020 The Abseil Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Defines type annotations for _flagvalues.""" - - -from absl.flags import _flag - -from typing import Any, Dict, Generic, Iterable, Iterator, List, Optional, Sequence, Text, Type, TypeVar - - -class FlagValues: - - def __getitem__(self, name: Text) -> _flag.Flag: ... - - def __setitem__(self, name: Text, flag: _flag.Flag) -> None: ... - - def __getattr__(self, name: Text) -> Any: ... - - def __setattr__(self, name: Text, value: Any) -> Any: ... - - def __call__( - self, - argv: Sequence[Text], - known_only: bool = ..., - ) -> List[Text]: ... - - def __contains__(self, name: Text) -> bool: ... - - def __copy__(self) -> Any: ... - - def __deepcopy__(self, memo) -> Any: ... - - def __delattr__(self, flag_name: Text) -> None: ... - - def __dir__(self) -> List[Text]: ... - - def __getstate__(self) -> Any: ... - - def __iter__(self) -> Iterator[Text]: ... - - def __len__(self) -> int: ... - - def get_help(self, - prefix: Text = ..., - include_special_flags: bool = ...) -> Text: - ... - - - def set_gnu_getopt(self, gnu_getopt: bool = ...) -> None: ... - - def is_gnu_getopt(self) -> bool: ... - - def flags_by_module_dict(self) -> Dict[Text, List[_flag.Flag]]: ... - - def flags_by_module_id_dict(self) -> Dict[Text, List[_flag.Flag]]: ... - - def key_flags_by_module_dict(self) -> Dict[Text, List[_flag.Flag]]: ... - - def register_flag_by_module( - self, module_name: Text, flag: _flag.Flag) -> None: ... - - def register_flag_by_module_id( - self, module_id: int, flag: _flag.Flag) -> None: ... - - def register_key_flag_for_module( - self, module_name: Text, flag: _flag.Flag) -> None: ... - - def get_key_flags_for_module(self, module: Any) -> List[_flag.Flag]: ... - - def find_module_defining_flag( - self, flagname: Text, default: Any = ...) -> Any: - ... - - def find_module_id_defining_flag( - self, flagname: Text, default: Any = ...) -> Any: - ... - - def append_flag_values(self, flag_values: Any) -> None: ... - - def remove_flag_values(self, flag_values: Any) -> None: ... - - def validate_all_flags(self) -> None: ... - - def set_default(self, name: Text, value: Any) -> None: ... - - def is_parsed(self) -> bool: ... - - def mark_as_parsed(self) -> None: ... - - def unparse_flags(self) -> None: ... - - def flag_values_dict(self) -> Dict[Text, Any]: ... - - def module_help(self, module: Any) -> Text: ... - - def main_module_help(self) -> Text: ... - - def get_flag_value(self, name: Text, default: Any) -> Any: ... - - def read_flags_from_files( - self, argv: List[Text], force_gnu: bool = ...) -> List[Text]: ... - - def flags_into_string(self) -> Text: ... - - def append_flags_into_file(self, filename: Text) -> None:... - - # outfile is Optional[fileobject] - def write_help_in_xml_format(self, outfile: Any = ...) -> None: ... - - -FLAGS = ... # type: FlagValues - - -_T = TypeVar('_T') # The type of parsed default value of the flag. - -# We assume that default and value are guaranteed to have the same type. -class FlagHolder(Generic[_T]): - def __init__( - self, - flag_values: FlagValues, - # NOTE: Use Flag instead of Flag[T] is used to work around some superficial - # differences between Flag and FlagHolder typing. - flag: _flag.Flag, - ensure_non_none_value: bool=False) -> None: ... - - @property - def name(self) -> Text: ... - - @property - def value(self) -> _T: ... - - @property - def default(self) -> _T: ... - - @property - def present(self) -> bool: ... diff --git a/absl/flags/_helpers.py b/absl/flags/_helpers.py index cbb98a7..1ad559c 100644 --- a/absl/flags/_helpers.py +++ b/absl/flags/_helpers.py @@ -14,12 +14,15 @@ """Internal helper functions for Abseil Python flags library.""" -import collections import os import re import struct import sys import textwrap +import types +from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence, Set +from xml.dom import minidom +# pylint: disable=g-import-not-at-top try: import fcntl except ImportError: @@ -29,6 +32,7 @@ try: import termios except ImportError: termios = None +# pylint: enable=g-import-not-at-top _DEFAULT_HELP_WIDTH = 80 # Default width of help output. @@ -56,32 +60,37 @@ _ILLEGAL_XML_CHARS_REGEX = re.compile( # This is a set of module ids for the modules that disclaim key flags. # This module is explicitly added to this set so that we never consider it to # define key flag. -disclaim_module_ids = set([id(sys.modules[__name__])]) +disclaim_module_ids: Set[int] = set([id(sys.modules[__name__])]) # Define special flags here so that help may be generated for them. # NOTE: Please do NOT use SPECIAL_FLAGS from outside flags module. # Initialized inside flagvalues.py. -SPECIAL_FLAGS = None +# NOTE: This cannot be annotated as its actual FlagValues type since this would +# create a circular dependency. +SPECIAL_FLAGS: Any = None # This points to the flags module, initialized in flags/__init__.py. # This should only be used in adopt_module_key_flags to take SPECIAL_FLAGS into # account. -FLAGS_MODULE = None +FLAGS_MODULE: types.ModuleType = None -class _ModuleObjectAndName( - collections.namedtuple('_ModuleObjectAndName', 'module module_name')): +class _ModuleObjectAndName(NamedTuple): """Module object and name. Fields: - module: object, module object. - module_name: str, module name. """ + module: types.ModuleType + module_name: str -def get_module_object_and_name(globals_dict): +def get_module_object_and_name( + globals_dict: Dict[str, Any] +) -> _ModuleObjectAndName: """Returns the module that defines a global environment, and its name. Args: @@ -99,7 +108,7 @@ def get_module_object_and_name(globals_dict): (sys.argv[0] if name == '__main__' else name)) -def get_calling_module_object_and_name(): +def get_calling_module_object_and_name() -> _ModuleObjectAndName: """Returns the module that's calling into this module. We generally use this function to get the name of the module calling a @@ -121,12 +130,14 @@ def get_calling_module_object_and_name(): raise AssertionError('No module was found') -def get_calling_module(): +def get_calling_module() -> str: """Returns the name of the module that's calling into this module.""" return get_calling_module_object_and_name().module_name -def create_xml_dom_element(doc, name, value): +def create_xml_dom_element( + doc: minidom.Document, name: str, value: Any +) -> minidom.Element: """Returns an XML DOM element with name and text value. Args: @@ -151,7 +162,7 @@ def create_xml_dom_element(doc, name, value): return e -def get_help_width(): +def get_help_width() -> int: """Returns the integer width of help lines that is used in TextWrap.""" if not sys.stdout.isatty() or termios is None or fcntl is None: return _DEFAULT_HELP_WIDTH @@ -169,7 +180,9 @@ def get_help_width(): return _DEFAULT_HELP_WIDTH -def get_flag_suggestions(attempt, longopt_list): +def get_flag_suggestions( + attempt: Optional[str], longopt_list: Sequence[str] +) -> List[str]: """Returns helpful similar matches for an invalid flag.""" # Don't suggest on very short strings, or if no longopts are specified. if len(attempt) <= 2 or not longopt_list: @@ -226,7 +239,12 @@ def _damerau_levenshtein(a, b): return distance(a, b) -def text_wrap(text, length=None, indent='', firstline_indent=None): +def text_wrap( + text: str, + length: Optional[int] = None, + indent: str = '', + firstline_indent: Optional[str] = None, +) -> str: """Wraps a given text to a maximum line length and returns it. It turns lines that only contain whitespace into empty lines, keeps new lines, @@ -283,7 +301,9 @@ def text_wrap(text, length=None, indent='', firstline_indent=None): return '\n'.join(result) -def flag_dict_to_args(flag_map, multi_flags=None): +def flag_dict_to_args( + flag_map: Dict[str, Any], multi_flags: Optional[Set[str]] = None +) -> Iterable[str]: """Convert a dict of values into process call parameters. This method is used to convert a dictionary into a sequence of parameters @@ -333,7 +353,7 @@ def flag_dict_to_args(flag_map, multi_flags=None): yield '--%s=%s' % (key, value) -def trim_docstring(docstring): +def trim_docstring(docstring: str) -> str: """Removes indentation from triple-quoted strings. This is the function specified in PEP 257 to handle docstrings: @@ -375,7 +395,7 @@ def trim_docstring(docstring): return '\n'.join(trimmed) -def doc_to_help(doc): +def doc_to_help(doc: str) -> str: """Takes a __doc__ string and reformats it as help.""" # Get rid of starting and ending white space. Using lstrip() or even diff --git a/absl/flags/_helpers.pyi b/absl/flags/_helpers.pyi deleted file mode 100644 index fe3b9f5..0000000 --- a/absl/flags/_helpers.pyi +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2017 The Abseil Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from xml.dom import minidom -import types -from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Set - -disclaim_module_ids: Set[int] -FLAGS_MODULE: types.ModuleType -# NOTE: This cannot be annotated as its actual FlagValues type since this would -# create a circular dependency. -SPECIAL_FLAGS: Any - - -class _ModuleObjectAndName(NamedTuple): - module: types.ModuleType - module_name: str - - -def get_module_object_and_name( - globals_dict: Dict[str, Any] -) -> _ModuleObjectAndName: - ... - - -def get_calling_module_object_and_name() -> _ModuleObjectAndName: - ... - - -def get_calling_module() -> str: - ... - - -def create_xml_dom_element( - doc: minidom.Document, name: str, value: Any -) -> minidom.Element: - ... - - -def get_help_width() -> int: - ... - - -def get_flag_suggestions( - attempt: Optional[str], longopt_list: List[str] -) -> List[str]: - ... - - -def text_wrap( - text: str, - length: Optional[int] = ..., - indent: str = ..., - firstline_indent: Optional[str] = ..., -) -> str: - ... - - -def flag_dict_to_args( - flag_map: Dict[str, str], multi_flags: Optional[Set[str]] = ... -) -> Iterable[str]: - ... - - -def trim_docstring(docstring: str) -> str: - ... - - -def doc_to_help(doc: str) -> str: - ... diff --git a/absl/flags/tests/_argument_parser_test.py b/absl/flags/tests/_argument_parser_test.py index 4281c3f..6f7d191 100644 --- a/absl/flags/tests/_argument_parser_test.py +++ b/absl/flags/tests/_argument_parser_test.py @@ -33,12 +33,12 @@ class ArgumentParserTest(absltest.TestCase): def test_parse_wrong_type(self): parser = _argument_parser.ArgumentParser() with self.assertRaises(TypeError): - parser.parse(0) + parser.parse(0) # type: ignore if bytes is not str: # In PY3, it does not accept bytes. with self.assertRaises(TypeError): - parser.parse(b'') + parser.parse(b'') # type: ignore class BooleanParserTest(absltest.TestCase): @@ -49,7 +49,7 @@ class BooleanParserTest(absltest.TestCase): def test_parse_bytes(self): with self.assertRaises(TypeError): - self.parser.parse(b'true') + self.parser.parse(b'true') # type: ignore def test_parse_str(self): self.assertTrue(self.parser.parse('true')) @@ -59,7 +59,7 @@ class BooleanParserTest(absltest.TestCase): def test_parse_wrong_type(self): with self.assertRaises(TypeError): - self.parser.parse(1.234) + self.parser.parse(1.234) # type: ignore def test_parse_str_false(self): self.assertFalse(self.parser.parse('false')) @@ -86,7 +86,7 @@ class FloatParserTest(absltest.TestCase): def test_parse_wrong_type(self): with self.assertRaises(TypeError): - self.parser.parse(False) + self.parser.parse(False) # type: ignore class IntegerParserTest(absltest.TestCase): @@ -99,9 +99,9 @@ class IntegerParserTest(absltest.TestCase): def test_parse_wrong_type(self): with self.assertRaises(TypeError): - self.parser.parse(1e2) + self.parser.parse(1e2) # type: ignore with self.assertRaises(TypeError): - self.parser.parse(False) + self.parser.parse(False) # type: ignore class EnumParserTest(absltest.TestCase): @@ -139,7 +139,7 @@ class EnumClassParserTest(parameterized.TestCase): def test_requires_enum(self): with self.assertRaises(TypeError): - _argument_parser.EnumClassParser(['apple', 'banana']) + _argument_parser.EnumClassParser(['apple', 'banana']) # type: ignore def test_requires_non_empty_enum_class(self): with self.assertRaises(ValueError): diff --git a/absl/flags/tests/_flag_test.py b/absl/flags/tests/_flag_test.py index 1625289..92de6c0 100644 --- a/absl/flags/tests/_flag_test.py +++ b/absl/flags/tests/_flag_test.py @@ -142,7 +142,7 @@ class EnumClassFlagTest(parameterized.TestCase): def test_requires_enum(self): with self.assertRaises(TypeError): - _flag.EnumClassFlag('fruit', None, 'help', ['apple', 'orange']) + _flag.EnumClassFlag('fruit', None, 'help', ['apple', 'orange']) # type: ignore def test_requires_non_empty_enum_class(self): with self.assertRaises(ValueError): @@ -186,7 +186,7 @@ class MultiEnumClassFlagTest(parameterized.TestCase): def test_requires_enum(self): with self.assertRaises(TypeError): - _flag.MultiEnumClassFlag('fruit', None, 'help', ['apple', 'orange']) + _flag.MultiEnumClassFlag('fruit', None, 'help', ['apple', 'orange']) # type: ignore def test_requires_non_empty_enum_class(self): with self.assertRaises(ValueError): diff --git a/absl/flags/tests/_flagvalues_test.py b/absl/flags/tests/_flagvalues_test.py index 46639f2..e37004d 100644 --- a/absl/flags/tests/_flagvalues_test.py +++ b/absl/flags/tests/_flagvalues_test.py @@ -323,7 +323,7 @@ class FlagValuesTest(absltest.TestCase): _defines.DEFINE_boolean('', 0, '') with self.assertRaises(_exceptions.Error): - _defines.DEFINE_boolean(1, 0, '') + _defines.DEFINE_boolean(1, 0, '') # type: ignore def test_len(self): fv = _flagvalues.FlagValues() @@ -511,11 +511,9 @@ absl.flags.tests.module_foo: def test_invalid_argv(self): fv = _flagvalues.FlagValues() with self.assertRaises(TypeError): - fv('./program') + fv('./program') # type: ignore with self.assertRaises(TypeError): - fv(b'./program') - with self.assertRaises(TypeError): - fv(u'./program') + fv(b'./program') # type: ignore def test_flags_dir(self): flag_values = _flagvalues.FlagValues() diff --git a/absl/flags/tests/_helpers_test.py b/absl/flags/tests/_helpers_test.py index 78b9051..daaf98c 100644 --- a/absl/flags/tests/_helpers_test.py +++ b/absl/flags/tests/_helpers_test.py @@ -73,8 +73,9 @@ class FlagSuggestionTest(absltest.TestCase): def test_suggestions_are_sorted(self): sorted_flags = sorted(['aab', 'aac', 'aad']) misspelt_flag = 'aaa' - suggestions = _helpers.get_flag_suggestions(misspelt_flag, - reversed(sorted_flags)) + suggestions = _helpers.get_flag_suggestions( + misspelt_flag, list(reversed(sorted_flags)) + ) self.assertEqual(sorted_flags, suggestions) diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py index 9aa328e..cf64cbe 100644 --- a/absl/flags/tests/_validators_test.py +++ b/absl/flags/tests/_validators_test.py @@ -20,7 +20,6 @@ failed validator will throw an exception, etc. import warnings - from absl.flags import _defines from absl.flags import _exceptions from absl.flags import _flagvalues diff --git a/absl/flags/tests/argparse_flags_test.py b/absl/flags/tests/argparse_flags_test.py index 5e6f49a..5469c4e 100644 --- a/absl/flags/tests/argparse_flags_test.py +++ b/absl/flags/tests/argparse_flags_test.py @@ -179,8 +179,11 @@ class ArgparseFlagsTest(parameterized.TestCase): parser.add_argument('--header', help='Header message to print.') subparsers = parser.add_subparsers(help='The command to execute.') - sub_parser = subparsers.add_parser( - 'sub_cmd', help='Sub command.', inherited_absl_flags=self._absl_flags) + # NOTE: The sub parsers don't work well with typing hence `type: ignore`. + # See https://github.com/python/typeshed/issues/10082. + sub_parser = subparsers.add_parser( # type: ignore + 'sub_cmd', help='Sub command.', inherited_absl_flags=self._absl_flags + ) sub_parser.add_argument('--sub_flag', help='Sub command flag.') def sub_command_func(): @@ -203,11 +206,15 @@ class ArgparseFlagsTest(parameterized.TestCase): inherited_absl_flags=self._absl_flags) subparsers = parser.add_subparsers(help='The command to execute.') - subparsers.add_parser( - 'sub_cmd', help='Sub command.', + # NOTE: The sub parsers don't work well with typing hence `type: ignore`. + # See https://github.com/python/typeshed/issues/10082. + subparsers.add_parser( # type: ignore + 'sub_cmd', + help='Sub command.', # Do not inherit absl flags in the subparser. # This is the behavior that this test exercises. - inherited_absl_flags=None) + inherited_absl_flags=None, + ) with self.assertRaises(SystemExit): parser.parse_args(['sub_cmd', '--absl_string=new_value']) @@ -270,10 +277,10 @@ class ArgparseFlagsTest(parameterized.TestCase): def test_no_help_flags(self, args): parser = argparse_flags.ArgumentParser( inherited_absl_flags=self._absl_flags, add_help=False) - with mock.patch.object(parser, 'print_help'): + with mock.patch.object(parser, 'print_help') as print_help_mock: with self.assertRaises(SystemExit): parser.parse_args(args) - parser.print_help.assert_not_called() + print_help_mock.assert_not_called() def test_helpfull_message(self): flags.DEFINE_string( diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py index 7cacbc8..0a7eaf6 100644 --- a/absl/flags/tests/flags_test.py +++ b/absl/flags/tests/flags_test.py @@ -302,7 +302,11 @@ class FlagsUnitTest(absltest.TestCase): flags.DEFINE_integer('l', 0x7fffffff00000000, 'how long to be') flags.DEFINE_list('args', 'v=1,"vmodule=a=0,b=2"', 'a list of arguments') flags.DEFINE_list('letters', 'a,b,c', 'a list of letters') - flags.DEFINE_list('numbers', [1, 2, 3], 'a list of numbers') + flags.DEFINE_list( + 'list_default_list', + ['a', 'b', 'c'], + 'with default being a list of strings', + ) flags.DEFINE_enum('kwery', None, ['who', 'what', 'Why', 'where', 'when'], '?') flags.DEFINE_enum( @@ -346,7 +350,7 @@ class FlagsUnitTest(absltest.TestCase): self.assertEqual(FLAGS.l, 0x7fffffff00000000) self.assertEqual(FLAGS.args, ['v=1', 'vmodule=a=0,b=2']) self.assertEqual(FLAGS.letters, ['a', 'b', 'c']) - self.assertEqual(FLAGS.numbers, [1, 2, 3]) + self.assertEqual(FLAGS.list_default_list, ['a', 'b', 'c']) self.assertIsNone(FLAGS.kwery) self.assertIsNone(FLAGS.sense) self.assertIsNone(FLAGS.cases) @@ -364,7 +368,7 @@ class FlagsUnitTest(absltest.TestCase): self.assertEqual(flag_values['l'], 0x7fffffff00000000) self.assertEqual(flag_values['args'], ['v=1', 'vmodule=a=0,b=2']) self.assertEqual(flag_values['letters'], ['a', 'b', 'c']) - self.assertEqual(flag_values['numbers'], [1, 2, 3]) + self.assertEqual(flag_values['list_default_list'], ['a', 'b', 'c']) self.assertIsNone(flag_values['kwery']) self.assertIsNone(flag_values['sense']) self.assertIsNone(flag_values['cases']) @@ -382,7 +386,7 @@ class FlagsUnitTest(absltest.TestCase): self.assertEqual(FLAGS['l'].default_as_str, "'9223372032559808512'") self.assertEqual(FLAGS['args'].default_as_str, '\'v=1,"vmodule=a=0,b=2"\'') self.assertEqual(FLAGS['letters'].default_as_str, "'a,b,c'") - self.assertEqual(FLAGS['numbers'].default_as_str, "'1,2,3'") + self.assertEqual(FLAGS['list_default_list'].default_as_str, "'a,b,c'") # Verify that the iterator for flags yields all the keys keys = list(FLAGS) @@ -424,7 +428,7 @@ class FlagsUnitTest(absltest.TestCase): self.assertIn('l', FLAGS) self.assertIn('args', FLAGS) self.assertIn('letters', FLAGS) - self.assertIn('numbers', FLAGS) + self.assertIn('list_default_list', FLAGS) # __contains__ self.assertIn('name', FLAGS) @@ -787,67 +791,70 @@ class FlagsUnitTest(absltest.TestCase): self.assertEqual(FLAGS.get_flag_value('repeat', None), 3) self.assertEqual(FLAGS.get_flag_value('name', None), 'giants') self.assertEqual(FLAGS.get_flag_value('debug', None), 0) - self.assertListEqual([ - '--alsologtostderr', - "--args ['v=1', 'vmodule=a=0,b=2']", - '--blah None', - '--cases None', - '--decimal 666', - '--float 3.14', - '--funny None', - '--hexadecimal 1638', - '--kwery None', - '--l 9223372032559808512', - "--letters ['a', 'b', 'c']", - '--logger_levels {}', - "--m ['str1', 'str2']", - "--m_str ['str1', 'str2']", - '--name giants', - '--no?', - '--nodebug', - '--noexec', - '--nohelp', - '--nohelpfull', - '--nohelpshort', - '--nohelpxml', - '--nologtostderr', - '--noonly_check_args', - '--nopdb_post_mortem', - '--noq', - '--norun_with_pdb', - '--norun_with_profiling', - '--notest0', - '--notestget2', - '--notestget3', - '--notestnone', - '--numbers [1, 2, 3]', - '--octal 438', - '--only_once singlevalue', - '--pdb False', - '--profile_file None', - '--quack', - '--repeat 3', - "--s ['sing1']", - "--s_str ['sing1']", - '--sense None', - '--showprefixforinfo', - '--stderrthreshold fatal', - '--test1', - '--test_random_seed 301', - '--test_randomize_ordering_seed ', - '--testcomma_list []', - '--testget1', - '--testget4 None', - '--testspace_list []', - '--testspace_or_comma_list []', - '--tmod_baz_x', - '--universe ptolemaic', - '--use_cprofile_for_profiling', - '--v -1', - '--verbosity -1', - '--x 10', - '--xml_output_file ', - ], args_list()) + self.assertListEqual( + [ + '--alsologtostderr', + "--args ['v=1', 'vmodule=a=0,b=2']", + '--blah None', + '--cases None', + '--decimal 666', + '--float 3.14', + '--funny None', + '--hexadecimal 1638', + '--kwery None', + '--l 9223372032559808512', + "--letters ['a', 'b', 'c']", + "--list_default_list ['a', 'b', 'c']", + '--logger_levels {}', + "--m ['str1', 'str2']", + "--m_str ['str1', 'str2']", + '--name giants', + '--no?', + '--nodebug', + '--noexec', + '--nohelp', + '--nohelpfull', + '--nohelpshort', + '--nohelpxml', + '--nologtostderr', + '--noonly_check_args', + '--nopdb_post_mortem', + '--noq', + '--norun_with_pdb', + '--norun_with_profiling', + '--notest0', + '--notestget2', + '--notestget3', + '--notestnone', + '--octal 438', + '--only_once singlevalue', + '--pdb False', + '--profile_file None', + '--quack', + '--repeat 3', + "--s ['sing1']", + "--s_str ['sing1']", + '--sense None', + '--showprefixforinfo', + '--stderrthreshold fatal', + '--test1', + '--test_random_seed 301', + '--test_randomize_ordering_seed ', + '--testcomma_list []', + '--testget1', + '--testget4 None', + '--testspace_list []', + '--testspace_or_comma_list []', + '--tmod_baz_x', + '--universe ptolemaic', + '--use_cprofile_for_profiling', + '--v -1', + '--verbosity -1', + '--x 10', + '--xml_output_file ', + ], + args_list(), + ) argv = ('./program', '--debug', '--m_str=upd1', '-s', 'upd2') FLAGS(argv) @@ -857,67 +864,70 @@ class FlagsUnitTest(absltest.TestCase): # items appended to existing non-default value lists for --m/--m_str # new value overwrites default value (not appended to it) for --s/--s_str - self.assertListEqual([ - '--alsologtostderr', - "--args ['v=1', 'vmodule=a=0,b=2']", - '--blah None', - '--cases None', - '--debug', - '--decimal 666', - '--float 3.14', - '--funny None', - '--hexadecimal 1638', - '--kwery None', - '--l 9223372032559808512', - "--letters ['a', 'b', 'c']", - '--logger_levels {}', - "--m ['str1', 'str2', 'upd1']", - "--m_str ['str1', 'str2', 'upd1']", - '--name giants', - '--no?', - '--noexec', - '--nohelp', - '--nohelpfull', - '--nohelpshort', - '--nohelpxml', - '--nologtostderr', - '--noonly_check_args', - '--nopdb_post_mortem', - '--noq', - '--norun_with_pdb', - '--norun_with_profiling', - '--notest0', - '--notestget2', - '--notestget3', - '--notestnone', - '--numbers [1, 2, 3]', - '--octal 438', - '--only_once singlevalue', - '--pdb False', - '--profile_file None', - '--quack', - '--repeat 3', - "--s ['sing1', 'upd2']", - "--s_str ['sing1', 'upd2']", - '--sense None', - '--showprefixforinfo', - '--stderrthreshold fatal', - '--test1', - '--test_random_seed 301', - '--test_randomize_ordering_seed ', - '--testcomma_list []', - '--testget1', - '--testget4 None', - '--testspace_list []', - '--testspace_or_comma_list []', - '--tmod_baz_x', - '--universe ptolemaic', - '--use_cprofile_for_profiling', - '--v -1', - '--verbosity -1', - '--x 10', - '--xml_output_file ', - ], args_list()) + self.assertListEqual( + [ + '--alsologtostderr', + "--args ['v=1', 'vmodule=a=0,b=2']", + '--blah None', + '--cases None', + '--debug', + '--decimal 666', + '--float 3.14', + '--funny None', + '--hexadecimal 1638', + '--kwery None', + '--l 9223372032559808512', + "--letters ['a', 'b', 'c']", + "--list_default_list ['a', 'b', 'c']", + '--logger_levels {}', + "--m ['str1', 'str2', 'upd1']", + "--m_str ['str1', 'str2', 'upd1']", + '--name giants', + '--no?', + '--noexec', + '--nohelp', + '--nohelpfull', + '--nohelpshort', + '--nohelpxml', + '--nologtostderr', + '--noonly_check_args', + '--nopdb_post_mortem', + '--noq', + '--norun_with_pdb', + '--norun_with_profiling', + '--notest0', + '--notestget2', + '--notestget3', + '--notestnone', + '--octal 438', + '--only_once singlevalue', + '--pdb False', + '--profile_file None', + '--quack', + '--repeat 3', + "--s ['sing1', 'upd2']", + "--s_str ['sing1', 'upd2']", + '--sense None', + '--showprefixforinfo', + '--stderrthreshold fatal', + '--test1', + '--test_random_seed 301', + '--test_randomize_ordering_seed ', + '--testcomma_list []', + '--testget1', + '--testget4 None', + '--testspace_list []', + '--testspace_or_comma_list []', + '--tmod_baz_x', + '--universe ptolemaic', + '--use_cprofile_for_profiling', + '--v -1', + '--verbosity -1', + '--x 10', + '--xml_output_file ', + ], + args_list(), + ) #################################### # Test all kind of error conditions. @@ -993,7 +1003,7 @@ class FlagsUnitTest(absltest.TestCase): # to be raised. try: sys.modules.pop('absl.flags.tests.module_baz') - import absl.flags.tests.module_baz + import absl.flags.tests.module_baz # pylint: disable=g-import-not-at-top del absl except flags.DuplicateFlagError: raise AssertionError('Module reimport caused flag duplication error') @@ -1236,6 +1246,9 @@ class FlagsUnitTest(absltest.TestCase): --letters: a list of letters (default: 'a,b,c') (a comma separated list) + --list_default_list: with default being a list of strings + (default: 'a,b,c') + (a comma separated list) -m,--m_str: string option that can occur multiple times; repeat this option to specify a list of values (default: "['def1', 'def2']") @@ -1243,9 +1256,6 @@ class FlagsUnitTest(absltest.TestCase): (default: 'Bob') --[no]noexec: boolean flag with no as prefix (default: 'true') - --numbers: a list of numbers - (default: '1,2,3') - (a comma separated list) --octal: using octals (default: '438') (an integer) @@ -1290,16 +1300,16 @@ class FlagsUnitTest(absltest.TestCase): def test_string_flag_with_wrong_type(self): fv = flags.FlagValues() with self.assertRaises(flags.IllegalFlagValueError): - flags.DEFINE_string('name', False, 'help', flag_values=fv) + flags.DEFINE_string('name', False, 'help', flag_values=fv) # type: ignore with self.assertRaises(flags.IllegalFlagValueError): - flags.DEFINE_string('name2', 0, 'help', flag_values=fv) + flags.DEFINE_string('name2', 0, 'help', flag_values=fv) # type: ignore def test_integer_flag_with_wrong_type(self): fv = flags.FlagValues() with self.assertRaises(flags.IllegalFlagValueError): - flags.DEFINE_integer('name', 1e2, 'help', flag_values=fv) + flags.DEFINE_integer('name', 1e2, 'help', flag_values=fv) # type: ignore with self.assertRaises(flags.IllegalFlagValueError): - flags.DEFINE_integer('name', [], 'help', flag_values=fv) + flags.DEFINE_integer('name', [], 'help', flag_values=fv) # type: ignore with self.assertRaises(flags.IllegalFlagValueError): flags.DEFINE_integer('name', False, 'help', flag_values=fv) @@ -1313,6 +1323,16 @@ class FlagsUnitTest(absltest.TestCase): with self.assertRaises(ValueError): flags.DEFINE_enum('fruit', None, [], 'help', flag_values=fv) + def test_enum_flag_with_str_values(self): + fv = flags.FlagValues() + with self.assertRaises(ValueError): + flags.DEFINE_enum('fruit', None, 'option', 'help', flag_values=fv) # type: ignore + + def test_multi_enum_flag_with_str_values(self): + fv = flags.FlagValues() + with self.assertRaises(ValueError): + flags.DEFINE_multi_enum('fruit', None, 'option', 'help', flag_values=fv) # type: ignore + def test_define_enum_class_flag(self): fv = flags.FlagValues() flags.DEFINE_enum_class('fruit', None, Fruit, '?', flag_values=fv) @@ -1351,13 +1371,14 @@ class FlagsUnitTest(absltest.TestCase): def test_enum_class_flag_with_wrong_default_value_type(self): fv = flags.FlagValues() with self.assertRaises(_exceptions.IllegalFlagValueError): - flags.DEFINE_enum_class('fruit', 1, Fruit, 'help', flag_values=fv) + flags.DEFINE_enum_class('fruit', 1, Fruit, 'help', flag_values=fv) # type: ignore def test_enum_class_flag_requires_enum_class(self): fv = flags.FlagValues() with self.assertRaises(TypeError): - flags.DEFINE_enum_class( - 'fruit', None, ['apple', 'orange'], 'help', flag_values=fv) + flags.DEFINE_enum_class( # type: ignore + 'fruit', None, ['apple', 'orange'], 'help', flag_values=fv + ) def test_enum_class_flag_requires_non_empty_enum_class(self): fv = flags.FlagValues() @@ -2491,7 +2512,7 @@ class NonGlobalFlagsTest(absltest.TestCase): def test_flag_definition_via_setitem(self): with self.assertRaises(flags.IllegalFlagValueError): flag_values = flags.FlagValues() - flag_values['flag_name'] = 'flag_value' + flag_values['flag_name'] = 'flag_value' # type: ignore class SetDefaultTest(absltest.TestCase): @@ -2545,7 +2566,7 @@ class SetDefaultTest(absltest.TestCase): self.flag_values.mark_as_parsed() with self.assertRaises(flags.IllegalFlagValueError): - flags.set_default(int_holder, 'a') + flags.set_default(int_holder, 'a') # type: ignore def test_failure_on_type_protected_none_default(self): int_holder = flags.DEFINE_integer( @@ -2553,7 +2574,7 @@ class SetDefaultTest(absltest.TestCase): self.flag_values.mark_as_parsed() - flags.set_default(int_holder, None) # NOTE: should be a type failure + flags.set_default(int_holder, None) # type: ignore with self.assertRaises(flags.IllegalFlagValueError): _ = int_holder.value # Will also fail on later access. diff --git a/absl/logging/__init__.py b/absl/logging/__init__.py index 494d782..e8177d3 100644 --- a/absl/logging/__init__.py +++ b/absl/logging/__init__.py @@ -98,11 +98,13 @@ import warnings from absl import flags from absl.logging import converter +# pylint: disable=g-import-not-at-top try: from typing import NoReturn except ImportError: pass +# pylint: enable=g-import-not-at-top FLAGS = flags.FLAGS @@ -295,44 +297,76 @@ class _StderrthresholdFlag(flags.Flag): self._value = v -flags.DEFINE_boolean('logtostderr', - False, - 'Should only log to stderr?', allow_override_cpp=True) -flags.DEFINE_boolean('alsologtostderr', - False, - 'also log to stderr?', allow_override_cpp=True) -flags.DEFINE_string('log_dir', - os.getenv('TEST_TMPDIR', ''), - 'directory to write logfiles into', - allow_override_cpp=True) -flags.DEFINE_flag(_VerbosityFlag( - 'verbosity', -1, - 'Logging verbosity level. Messages logged at this level or lower will ' - 'be included. Set to 1 for debug logging. If the flag was not set or ' - 'supplied, the value will be changed from the default of -1 (warning) to ' - '0 (info) after flags are parsed.', - short_name='v', allow_hide_cpp=True)) -flags.DEFINE_flag( +LOGTOSTDERR = flags.DEFINE_boolean( + 'logtostderr', + False, + 'Should only log to stderr?', + allow_override_cpp=True, +) +ALSOLOGTOSTDERR = flags.DEFINE_boolean( + 'alsologtostderr', + False, + 'also log to stderr?', + allow_override_cpp=True, +) +LOG_DIR = flags.DEFINE_string( + 'log_dir', + os.getenv('TEST_TMPDIR', ''), + 'directory to write logfiles into', + allow_override_cpp=True, +) +VERBOSITY = flags.DEFINE_flag( + _VerbosityFlag( + 'verbosity', + -1, + ( + 'Logging verbosity level. Messages logged at this level or lower' + ' will be included. Set to 1 for debug logging. If the flag was not' + ' set or supplied, the value will be changed from the default of -1' + ' (warning) to 0 (info) after flags are parsed.' + ), + short_name='v', + allow_hide_cpp=True, + ) +) +LOGGER_LEVELS = flags.DEFINE_flag( _LoggerLevelsFlag( - 'logger_levels', {}, - 'Specify log level of loggers. The format is a CSV list of ' - '`name:level`. Where `name` is the logger name used with ' - '`logging.getLogger()`, and `level` is a level name (INFO, DEBUG, ' - 'etc). e.g. `myapp.foo:INFO,other.logger:DEBUG`')) -flags.DEFINE_flag(_StderrthresholdFlag( - 'stderrthreshold', 'fatal', - 'log messages at this level, or more severe, to stderr in ' - 'addition to the logfile. Possible values are ' - "'debug', 'info', 'warning', 'error', and 'fatal'. " - 'Obsoletes --alsologtostderr. Using --alsologtostderr ' - 'cancels the effect of this flag. Please also note that ' - 'this flag is subject to --verbosity and requires logfile ' - 'not be stderr.', allow_hide_cpp=True)) -flags.DEFINE_boolean('showprefixforinfo', True, - 'If False, do not prepend prefix to info messages ' - 'when it\'s logged to stderr, ' - '--verbosity is set to INFO level, ' - 'and python logging is used.') + 'logger_levels', + {}, + ( + 'Specify log level of loggers. The format is a CSV list of ' + '`name:level`. Where `name` is the logger name used with ' + '`logging.getLogger()`, and `level` is a level name (INFO, DEBUG, ' + 'etc). e.g. `myapp.foo:INFO,other.logger:DEBUG`' + ), + ) +) +STDERRTHRESHOLD = flags.DEFINE_flag( + _StderrthresholdFlag( + 'stderrthreshold', + 'fatal', + ( + 'log messages at this level, or more severe, to stderr in ' + 'addition to the logfile. Possible values are ' + "'debug', 'info', 'warning', 'error', and 'fatal'. " + 'Obsoletes --alsologtostderr. Using --alsologtostderr ' + 'cancels the effect of this flag. Please also note that ' + 'this flag is subject to --verbosity and requires logfile ' + 'not be stderr.' + ), + allow_hide_cpp=True, + ) +) +SHOWPREFIXFORINFO = flags.DEFINE_boolean( + 'showprefixforinfo', + True, + ( + 'If False, do not prepend prefix to info messages ' + "when it's logged to stderr, " + '--verbosity is set to INFO level, ' + 'and python logging is used.' + ), +) def get_verbosity(): diff --git a/absl/logging/__init__.pyi b/absl/logging/__init__.pyi new file mode 100644 index 0000000..5d5bb69 --- /dev/null +++ b/absl/logging/__init__.pyi @@ -0,0 +1,290 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Callable, Dict, NoReturn, Optional, Tuple, TypeVar, Union + +from absl import flags + +# Logging levels. +FATAL: int +ERROR: int +WARNING: int +WARN: int # Deprecated name. +INFO: int +DEBUG: int + +ABSL_LOGGING_PREFIX_REGEX: str + +LOGTOSTDERR: flags.FlagHolder[bool] +ALSOLOGTOSTDERR: flags.FlagHolder[bool] +LOG_DIR: flags.FlagHolder[str] +VERBOSITY: flags.FlagHolder[int] +LOGGER_LEVELS: flags.FlagHolder[Dict[str, str]] +STDERRTHRESHOLD: flags.FlagHolder[str] +SHOWPREFIXFORINFO: flags.FlagHolder[bool] + + +def get_verbosity() -> int: + ... + + +def set_verbosity(v: Union[int, str]) -> None: + ... + + +def set_stderrthreshold(s: Union[int, str]) -> None: + ... + + +# TODO(b/277607978): Provide actual args+kwargs shadowing stdlib's logging functions. +def fatal(msg: Any, *args: Any, **kwargs: Any) -> NoReturn: + ... + + +def error(msg: Any, *args: Any, **kwargs: Any) -> None: + ... + + +def warning(msg: Any, *args: Any, **kwargs: Any) -> None: + ... + + +def warn(msg: Any, *args: Any, **kwargs: Any) -> None: + ... + + +def info(msg: Any, *args: Any, **kwargs: Any) -> None: + ... + + +def debug(msg: Any, *args: Any, **kwargs: Any) -> None: + ... + + +def exception(msg: Any, *args: Any, **kwargs: Any) -> None: + ... + + +def log_every_n(level: int, msg: Any, n: int, *args: Any) -> None: + ... + + +def log_every_n_seconds( + level: int, msg: Any, n_seconds: float, *args: Any +) -> None: + ... + + +def log_first_n(level: int, msg: Any, n: int, *args: Any) -> None: + ... + + +def log_if(level: int, msg: Any, condition: Any, *args: Any) -> None: + ... + + +def log(level: int, msg: Any, *args: Any, **kwargs: Any) -> None: + ... + + +def vlog(level: int, msg: Any, *args: Any, **kwargs: Any) -> None: + ... + + +def vlog_is_on(level: int) -> bool: + ... + + +def flush() -> None: + ... + + +def level_debug() -> bool: + ... + + +def level_info() -> bool: + ... + + +def level_warning() -> bool: + ... + + +level_warn = level_warning # Deprecated function. + + +def level_error() -> bool: + ... + + +def get_log_file_name(level: int = ...) -> str: + ... + + +def find_log_dir_and_names( + program_name: Optional[str] = ..., log_dir: Optional[str] = ... +) -> Tuple[str, str, str]: + ... + + +def find_log_dir(log_dir: Optional[str] = ...) -> str: + ... + + +def get_absl_log_prefix(record: logging.LogRecord) -> str: + ... + + +_SkipLogT = TypeVar('_SkipLogT', str, Callable[..., Any]) + +def skip_log_prefix(func: _SkipLogT) -> _SkipLogT: + ... + + +_StreamT = TypeVar("_StreamT") + + +class PythonHandler(logging.StreamHandler[_StreamT]): + + def __init__( + self, + stream: Optional[_StreamT] = ..., + formatter: Optional[logging.Formatter] = ..., + ) -> None: + ... + + def start_logging_to_file( + self, program_name: Optional[str] = ..., log_dir: Optional[str] = ... + ) -> None: + ... + + def use_absl_log_file( + self, program_name: Optional[str] = ..., log_dir: Optional[str] = ... + ) -> None: + ... + + def flush(self) -> None: + ... + + def emit(self, record: logging.LogRecord) -> None: + ... + + def close(self) -> None: + ... + + +class ABSLHandler(logging.Handler): + + def __init__(self, python_logging_formatter: PythonFormatter) -> None: + ... + + def format(self, record: logging.LogRecord) -> str: + ... + + def setFormatter(self, fmt) -> None: + ... + + def emit(self, record: logging.LogRecord) -> None: + ... + + def flush(self) -> None: + ... + + def close(self) -> None: + ... + + def handle(self, record: logging.LogRecord) -> bool: + ... + + @property + def python_handler(self) -> PythonHandler: + ... + + def activate_python_handler(self) -> None: + ... + + def use_absl_log_file( + self, program_name: Optional[str] = ..., log_dir: Optional[str] = ... + ) -> None: + ... + + def start_logging_to_file(self, program_name=None, log_dir=None) -> None: + ... + + +class PythonFormatter(logging.Formatter): + + def format(self, record: logging.LogRecord) -> str: + ... + + +class ABSLLogger(logging.Logger): + + def findCaller( + self, stack_info: bool = ..., stacklevel: int = ... + ) -> Tuple[str, int, str, Optional[str]]: + ... + + def critical(self, msg: Any, *args: Any, **kwargs: Any) -> None: + ... + + def fatal(self, msg: Any, *args: Any, **kwargs: Any) -> NoReturn: + ... + + def error(self, msg: Any, *args: Any, **kwargs: Any) -> None: + ... + + def warn(self, msg: Any, *args: Any, **kwargs: Any) -> None: + ... + + def warning(self, msg: Any, *args: Any, **kwargs: Any) -> None: + ... + + def info(self, msg: Any, *args: Any, **kwargs: Any) -> None: + ... + + def debug(self, msg: Any, *args: Any, **kwargs: Any) -> None: + ... + + def log(self, level: int, msg: Any, *args: Any, **kwargs: Any) -> None: + ... + + def handle(self, record: logging.LogRecord) -> None: + ... + + @classmethod + def register_frame_to_skip( + cls, file_name: str, function_name: str, line_number: Optional[int] = ... + ) -> None: + ... + + +# NOTE: Returns None before _initialize called but shouldn't occur after import. +def get_absl_logger() -> ABSLLogger: + ... + + +# NOTE: Returns None before _initialize called but shouldn't occur after import. +def get_absl_handler() -> ABSLHandler: + ... + + +def use_python_logging(quiet: bool = ...) -> None: + ... + + +def use_absl_handler() -> None: + ... diff --git a/absl/logging/tests/verbosity_flag_test.py b/absl/logging/tests/verbosity_flag_test.py index ea9944d..44a6034 100644 --- a/absl/logging/tests/verbosity_flag_test.py +++ b/absl/logging/tests/verbosity_flag_test.py @@ -27,9 +27,11 @@ assert logging.root.getEffectiveLevel() == logging.ERROR, ( 'logging.root level should be changed to ERROR, but found {}'.format( logging.root.getEffectiveLevel())) +# pylint: disable=g-import-not-at-top from absl import flags from absl import logging as _ # pylint: disable=unused-import from absl.testing import absltest +# pylint: enable=g-import-not-at-top FLAGS = flags.FLAGS diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py index b977109..7d2b930 100644 --- a/absl/testing/absltest.py +++ b/absl/testing/absltest.py @@ -23,6 +23,7 @@ import contextlib import difflib import enum import errno +import faulthandler import getpass import inspect import io @@ -43,22 +44,13 @@ import unittest from unittest import mock # pylint: disable=unused-import Allow absltest.mock. from urllib import parse -try: - # The faulthandler module isn't always available, and pytype doesn't - # understand that we're catching ImportError, so suppress the error. - # pytype: disable=import-error - import faulthandler - # pytype: enable=import-error -except ImportError: - # We use faulthandler if it is available. - faulthandler = None - -from absl import app +from absl import app # pylint: disable=g-import-not-at-top from absl import flags from absl import logging from absl.testing import _pretty_print_reporter from absl.testing import xml_reporter +# pylint: disable=g-import-not-at-top # Make typing an optional import to avoid it being a required dependency # in Python 2. Type checkers will still understand the imports. try: @@ -79,6 +71,7 @@ else: _OutcomeType = unittest.case._Outcome # pytype: disable=module-attr +# pylint: enable=g-import-not-at-top # Re-export a bunch of unittest functions we support so that people don't # have to import unittest to get them @@ -2093,7 +2086,7 @@ def _is_in_app_main(): def _register_sigterm_with_faulthandler(): # type: () -> None """Have faulthandler dump stacks on SIGTERM. Useful to diagnose timeouts.""" - if faulthandler and getattr(faulthandler, 'register', None): + if getattr(faulthandler, 'register', None): # faulthandler.register is not available on Windows. # faulthandler.enable() is already called by app.run. try: diff --git a/absl/testing/tests/absltest_py3_test.py b/absl/testing/tests/absltest_py3_test.py deleted file mode 100644 index 7c5f500..0000000 --- a/absl/testing/tests/absltest_py3_test.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2020 The Abseil Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Python3-only Tests for absltest.""" - -from absl.testing import absltest - - -class GetTestCaseNamesPEP3102Test(absltest.TestCase): - """This test verifies absltest.TestLoader.GetTestCasesNames PEP3102 support. - - The test is Python3 only, as keyword only arguments are considered - syntax error in Python2. - - The rest of getTestCaseNames functionality is covered - by absltest_test.TestLoaderTest. - """ - - class Valid(absltest.TestCase): - - def testKeywordOnly(self, *, arg): - pass - - def setUp(self): - self.loader = absltest.TestLoader() - super(GetTestCaseNamesPEP3102Test, self).setUp() - - def test_PEP3102_get_test_case_names(self): - self.assertCountEqual( - self.loader.getTestCaseNames(GetTestCaseNamesPEP3102Test.Valid), - ["testKeywordOnly"]) - -if __name__ == "__main__": - absltest.main() diff --git a/absl/testing/tests/absltest_test.py b/absl/testing/tests/absltest_test.py index 531da92..00bd2bf 100644 --- a/absl/testing/tests/absltest_test.py +++ b/absl/testing/tests/absltest_test.py @@ -1468,14 +1468,6 @@ test case class GetCommandStderrTestCase(absltest.TestCase): - def setUp(self): - super(GetCommandStderrTestCase, self).setUp() - self.original_environ = os.environ.copy() - - def tearDown(self): - super(GetCommandStderrTestCase, self).tearDown() - os.environ = self.original_environ - def test_return_status(self): tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value) returncode = ( @@ -1961,6 +1953,9 @@ class TestLoaderTest(absltest.TestCase): def TestHelperWithDefaults(self, a=5): pass + def TestHelperWithKeywordOnly(self, *, arg): + pass + class Invalid(absltest.TestCase): """Test case containing a suspicious method.""" diff --git a/absl/tests/app_test_helper.py b/absl/tests/app_test_helper.py index f9fbdec..92f7be3 100644 --- a/absl/tests/app_test_helper.py +++ b/absl/tests/app_test_helper.py @@ -18,11 +18,11 @@ import os import sys try: - import faulthandler + import faulthandler # pylint: disable=g-import-not-at-top except ImportError: faulthandler = None -from absl import app +from absl import app # pylint: disable=g-import-not-at-top from absl import flags FLAGS = flags.FLAGS diff --git a/ci/run_tests.sh b/ci/run_tests.sh new file mode 100755 index 0000000..99de7cd --- /dev/null +++ b/ci/run_tests.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +# Fail on any error. Treat unset variables an error. Print commands as executed. +set -eux + +# Log environment variables. +env + +# Let the script continue even if "bazel test" fails, so that all tests are +# always executed. +exit_code=0 + +# Log the bazel version for easier debugging. +bazel version +bazel test --test_output=errors absl/... || exit_code=$? +if [[ ! -z "${ABSL_EXPECTED_PYTHON_VERSION}" ]]; then + bazel test \ + --test_output=errors absl:tests/python_version_test \ + --test_arg=--expected_version="${ABSL_EXPECTED_PYTHON_VERSION}" || exit_code=$? +fi + +if [[ ! -z "${ABSL_COPY_TESTLOGS_TO}" ]]; then + mkdir -p "${ABSL_COPY_TESTLOGS_TO}" + readonly testlogs_dir=$(bazel info bazel-testlogs) + echo "Copying bazel test logs from ${testlogs_dir} to ${ABSL_COPY_TESTLOGS_TO}..." + cp -r "${testlogs_dir}" "${ABSL_COPY_TESTLOGS_TO}" || exit_code=$? +fi + +# TODO(yileiyang): Update and run smoke_test.sh. + +exit $exit_code @@ -17,12 +17,14 @@ import os import sys +# pylint: disable=g-import-not-at-top try: import setuptools except ImportError: from ez_setup import use_setuptools use_setuptools() import setuptools +# pylint: enable=g-import-not-at-top if sys.version_info < (3, 7): raise RuntimeError('Python version 3.7+ is required.') |