diff options
author | Abseil Team <absl-team@google.com> | 2022-12-02 12:02:17 -0800 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2022-12-02 12:03:00 -0800 |
commit | b568abfd3be3480f451aec4923ada17d3acc6734 (patch) | |
tree | a5a205fb4ff11b1c232a980909d8c889c53b5cf2 | |
parent | 916113a2ec897568959dbad330ec77939059fe27 (diff) | |
download | absl-py-b568abfd3be3480f451aec4923ada17d3acc6734.tar.gz |
Add type hinting to the flagsaver module.
PiperOrigin-RevId: 492525537
Change-Id: Id30f2c466c3ea798f9207346d38bdfd754521c08
-rw-r--r-- | absl/flags/_flag.pyi | 3 | ||||
-rw-r--r-- | absl/testing/flagsaver.py | 50 |
2 files changed, 36 insertions, 17 deletions
diff --git a/absl/flags/_flag.pyi b/absl/flags/_flag.pyi index 9b4a3d3..3506644 100644 --- a/absl/flags/_flag.pyi +++ b/absl/flags/_flag.pyi @@ -20,7 +20,7 @@ import functools from absl.flags import _argument_parser import enum -from typing import Text, TypeVar, Generic, Iterable, Type, List, Optional, Any, Union, Sequence +from typing import Callable, Text, TypeVar, Generic, Iterable, Type, List, Optional, Any, Union, Sequence _T = TypeVar('_T') _ET = TypeVar('_ET', bound=enum.Enum) @@ -44,6 +44,7 @@ class Flag(Generic[_T]): using_default_value = ... # type: bool allow_overwrite = ... # type: bool allow_using_method_names = ... # type: bool + validators = ... # type: List[Callable[[Any], bool]] def __init__(self, parser: _argument_parser.ArgumentParser[_T], diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py index 37926d7..774c698 100644 --- a/absl/testing/flagsaver.py +++ b/absl/testing/flagsaver.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Decorator and context manager for saving and restoring flag values. There are many ways to save and restore. Always use the most convenient method @@ -61,11 +60,26 @@ and then restore flag values, the added flag will be deleted with no errors. import functools import inspect +from typing import overload, Any, Callable, Mapping, Tuple, TypeVar from absl import flags FLAGS = flags.FLAGS +# The type of pre/post wrapped functions. +_CallableT = TypeVar('_CallableT', bound=Callable) + + +@overload +def flagsaver(*args: Tuple[flags.FlagHolder, Any], + **kwargs: Any) -> '_FlagOverrider': + ... + + +@overload +def flagsaver(func: _CallableT) -> _CallableT: + ... + def flagsaver(*args, **kwargs): """The main flagsaver interface. See module doc for usage.""" @@ -94,12 +108,14 @@ def flagsaver(*args, **kwargs): return _FlagOverrider(**kwargs) -def save_flag_values(flag_values=FLAGS): +def save_flag_values( + flag_values: flags.FlagValues = FLAGS) -> Mapping[str, Mapping[str, Any]]: """Returns copy of flag values as a dict. Args: - flag_values: FlagValues, the FlagValues instance with which the flag will - be saved. This should almost never need to be overridden. + flag_values: FlagValues, the FlagValues instance with which the flag will be + saved. This should almost never need to be overridden. + Returns: Dictionary mapping keys to values. Keys are flag names, values are corresponding ``__dict__`` members. E.g. ``{'key': value_dict, ...}``. @@ -107,13 +123,14 @@ def save_flag_values(flag_values=FLAGS): return {name: _copy_flag_dict(flag_values[name]) for name in flag_values} -def restore_flag_values(saved_flag_values, flag_values=FLAGS): +def restore_flag_values(saved_flag_values: Mapping[str, Mapping[str, Any]], + flag_values: flags.FlagValues = FLAGS): """Restores flag values based on the dictionary of flag values. Args: saved_flag_values: {'flag_name': value_dict, ...} - flag_values: FlagValues, the FlagValues instance from which the flag will - be restored. This should almost never need to be overridden. + flag_values: FlagValues, the FlagValues instance from which the flag will be + restored. This should almost never need to be overridden. """ new_flag_names = list(flag_values) for name in new_flag_names: @@ -127,23 +144,24 @@ def restore_flag_values(saved_flag_values, flag_values=FLAGS): flag_values[name].__dict__ = saved -def _wrap(func, overrides): +def _wrap(func: _CallableT, overrides: Mapping[str, Any]) -> _CallableT: """Creates a wrapper function that saves/restores flag values. Args: - func: function object - This will be called between saving flags and - restoring flags. - overrides: {str: object} - Flag names mapped to their values. These flags - will be set after saving the original flag state. + func: This will be called between saving flags and restoring flags. + overrides: Flag names mapped to their values. These flags will be set after + saving the original flag state. Returns: - return value from func() + A wrapped version of func. """ + @functools.wraps(func) def _flagsaver_wrapper(*args, **kwargs): """Wrapper function that saves and restores flags.""" with _FlagOverrider(**overrides): return func(*args, **kwargs) + return _flagsaver_wrapper @@ -154,11 +172,11 @@ class _FlagOverrider(object): completes. """ - def __init__(self, **overrides): + def __init__(self, **overrides: Any): self._overrides = overrides self._saved_flag_values = None - def __call__(self, func): + def __call__(self, func: _CallableT) -> _CallableT: if inspect.isclass(func): raise TypeError('flagsaver cannot be applied to a class.') return _wrap(func, self._overrides) @@ -176,7 +194,7 @@ class _FlagOverrider(object): restore_flag_values(self._saved_flag_values, FLAGS) -def _copy_flag_dict(flag): +def _copy_flag_dict(flag: flags.Flag) -> Mapping[str, Any]: """Returns a copy of the flag object's ``__dict__``. It's mostly a shallow copy of the ``__dict__``, except it also does a shallow |