aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAbseil Team <absl-team@google.com>2022-12-02 12:02:17 -0800
committerCopybara-Service <copybara-worker@google.com>2022-12-02 12:03:00 -0800
commitb568abfd3be3480f451aec4923ada17d3acc6734 (patch)
treea5a205fb4ff11b1c232a980909d8c889c53b5cf2
parent916113a2ec897568959dbad330ec77939059fe27 (diff)
downloadabsl-py-b568abfd3be3480f451aec4923ada17d3acc6734.tar.gz
Add type hinting to the flagsaver module.
PiperOrigin-RevId: 492525537 Change-Id: Id30f2c466c3ea798f9207346d38bdfd754521c08
-rw-r--r--absl/flags/_flag.pyi3
-rw-r--r--absl/testing/flagsaver.py50
2 files changed, 36 insertions, 17 deletions
diff --git a/absl/flags/_flag.pyi b/absl/flags/_flag.pyi
index 9b4a3d3..3506644 100644
--- a/absl/flags/_flag.pyi
+++ b/absl/flags/_flag.pyi
@@ -20,7 +20,7 @@ import functools
from absl.flags import _argument_parser
import enum
-from typing import Text, TypeVar, Generic, Iterable, Type, List, Optional, Any, Union, Sequence
+from typing import Callable, Text, TypeVar, Generic, Iterable, Type, List, Optional, Any, Union, Sequence
_T = TypeVar('_T')
_ET = TypeVar('_ET', bound=enum.Enum)
@@ -44,6 +44,7 @@ class Flag(Generic[_T]):
using_default_value = ... # type: bool
allow_overwrite = ... # type: bool
allow_using_method_names = ... # type: bool
+ validators = ... # type: List[Callable[[Any], bool]]
def __init__(self,
parser: _argument_parser.ArgumentParser[_T],
diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py
index 37926d7..774c698 100644
--- a/absl/testing/flagsaver.py
+++ b/absl/testing/flagsaver.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Decorator and context manager for saving and restoring flag values.
There are many ways to save and restore. Always use the most convenient method
@@ -61,11 +60,26 @@ and then restore flag values, the added flag will be deleted with no errors.
import functools
import inspect
+from typing import overload, Any, Callable, Mapping, Tuple, TypeVar
from absl import flags
FLAGS = flags.FLAGS
+# The type of pre/post wrapped functions.
+_CallableT = TypeVar('_CallableT', bound=Callable)
+
+
+@overload
+def flagsaver(*args: Tuple[flags.FlagHolder, Any],
+ **kwargs: Any) -> '_FlagOverrider':
+ ...
+
+
+@overload
+def flagsaver(func: _CallableT) -> _CallableT:
+ ...
+
def flagsaver(*args, **kwargs):
"""The main flagsaver interface. See module doc for usage."""
@@ -94,12 +108,14 @@ def flagsaver(*args, **kwargs):
return _FlagOverrider(**kwargs)
-def save_flag_values(flag_values=FLAGS):
+def save_flag_values(
+ flag_values: flags.FlagValues = FLAGS) -> Mapping[str, Mapping[str, Any]]:
"""Returns copy of flag values as a dict.
Args:
- flag_values: FlagValues, the FlagValues instance with which the flag will
- be saved. This should almost never need to be overridden.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ saved. This should almost never need to be overridden.
+
Returns:
Dictionary mapping keys to values. Keys are flag names, values are
corresponding ``__dict__`` members. E.g. ``{'key': value_dict, ...}``.
@@ -107,13 +123,14 @@ def save_flag_values(flag_values=FLAGS):
return {name: _copy_flag_dict(flag_values[name]) for name in flag_values}
-def restore_flag_values(saved_flag_values, flag_values=FLAGS):
+def restore_flag_values(saved_flag_values: Mapping[str, Mapping[str, Any]],
+ flag_values: flags.FlagValues = FLAGS):
"""Restores flag values based on the dictionary of flag values.
Args:
saved_flag_values: {'flag_name': value_dict, ...}
- flag_values: FlagValues, the FlagValues instance from which the flag will
- be restored. This should almost never need to be overridden.
+ flag_values: FlagValues, the FlagValues instance from which the flag will be
+ restored. This should almost never need to be overridden.
"""
new_flag_names = list(flag_values)
for name in new_flag_names:
@@ -127,23 +144,24 @@ def restore_flag_values(saved_flag_values, flag_values=FLAGS):
flag_values[name].__dict__ = saved
-def _wrap(func, overrides):
+def _wrap(func: _CallableT, overrides: Mapping[str, Any]) -> _CallableT:
"""Creates a wrapper function that saves/restores flag values.
Args:
- func: function object - This will be called between saving flags and
- restoring flags.
- overrides: {str: object} - Flag names mapped to their values. These flags
- will be set after saving the original flag state.
+ func: This will be called between saving flags and restoring flags.
+ overrides: Flag names mapped to their values. These flags will be set after
+ saving the original flag state.
Returns:
- return value from func()
+ A wrapped version of func.
"""
+
@functools.wraps(func)
def _flagsaver_wrapper(*args, **kwargs):
"""Wrapper function that saves and restores flags."""
with _FlagOverrider(**overrides):
return func(*args, **kwargs)
+
return _flagsaver_wrapper
@@ -154,11 +172,11 @@ class _FlagOverrider(object):
completes.
"""
- def __init__(self, **overrides):
+ def __init__(self, **overrides: Any):
self._overrides = overrides
self._saved_flag_values = None
- def __call__(self, func):
+ def __call__(self, func: _CallableT) -> _CallableT:
if inspect.isclass(func):
raise TypeError('flagsaver cannot be applied to a class.')
return _wrap(func, self._overrides)
@@ -176,7 +194,7 @@ class _FlagOverrider(object):
restore_flag_values(self._saved_flag_values, FLAGS)
-def _copy_flag_dict(flag):
+def _copy_flag_dict(flag: flags.Flag) -> Mapping[str, Any]:
"""Returns a copy of the flag object's ``__dict__``.
It's mostly a shallow copy of the ``__dict__``, except it also does a shallow