aboutsummaryrefslogtreecommitdiff
path: root/absl/testing/flagsaver.py
diff options
context:
space:
mode:
Diffstat (limited to 'absl/testing/flagsaver.py')
-rw-r--r--absl/testing/flagsaver.py232
1 files changed, 212 insertions, 20 deletions
diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py
index 37926d7..e96c8c5 100644
--- a/absl/testing/flagsaver.py
+++ b/absl/testing/flagsaver.py
@@ -50,6 +50,36 @@ Here are examples of each method. They all call ``do_stuff()`` while
finally:
flagsaver.restore_flag_values(saved_flag_values)
+ # Use the parsing version to emulate users providing the flags.
+ # Note that all flags must be provided as strings (unparsed).
+ @flagsaver.as_parsed(some_int_flag='123')
+ def some_func():
+ # Because the flag was parsed it is considered "present".
+ assert FLAGS.some_int_flag.present
+ do_stuff()
+
+ # flagsaver.as_parsed() can also be used as a context manager just like
+ # flagsaver.flagsaver()
+ with flagsaver.as_parsed(some_int_flag='123'):
+ do_stuff()
+
+ # The flagsaver.as_parsed() interface also supports FlagHolder objects.
+ @flagsaver.as_parsed((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, '23'))
+ def some_func():
+ do_stuff()
+
+ # Using as_parsed with a multi_X flag requires a sequence of strings.
+ @flagsaver.as_parsed(some_multi_int_flag=['123', '456'])
+ def some_func():
+ assert FLAGS.some_multi_int_flag.present
+ do_stuff()
+
+ # If a flag name includes non-identifier characters it can be specified like
+ # so:
+ @flagsaver.as_parsed(**{'i-like-dashes': 'true'})
+ def some_func():
+ do_stuff()
+
We save and restore a shallow copy of each Flag object's ``__dict__`` attribute.
This preserves all attributes of the flag, such as whether or not it was
overridden from its default value.
@@ -59,18 +89,113 @@ exception will be raised. However if you *add* a flag after saving flag values,
and then restore flag values, the added flag will be deleted with no errors.
"""
+import collections
import functools
import inspect
+from typing import overload, Any, Callable, Mapping, Tuple, TypeVar, Type, Sequence, Union
from absl import flags
FLAGS = flags.FLAGS
+# The type of pre/post wrapped functions.
+_CallableT = TypeVar('_CallableT', bound=Callable)
+
+
+@overload
+def flagsaver(*args: Tuple[flags.FlagHolder, Any],
+ **kwargs: Any) -> '_FlagOverrider':
+ ...
+
+
+@overload
+def flagsaver(func: _CallableT) -> _CallableT:
+ ...
+
+
def flagsaver(*args, **kwargs):
"""The main flagsaver interface. See module doc for usage."""
+ return _construct_overrider(_FlagOverrider, *args, **kwargs)
+
+
+@overload
+def as_parsed(*args: Tuple[flags.FlagHolder, Union[str, Sequence[str]]],
+ **kwargs: Union[str, Sequence[str]]) -> '_ParsingFlagOverrider':
+ ...
+
+
+@overload
+def as_parsed(func: _CallableT) -> _CallableT:
+ ...
+
+
+def as_parsed(*args, **kwargs):
+ """Overrides flags by parsing strings, saves flag state similar to flagsaver.
+
+ This function can be used as either a decorator or context manager similar to
+ flagsaver.flagsaver(). However, where flagsaver.flagsaver() directly sets the
+ flags to new values, this function will parse the provided arguments as if
+ they were provided on the command line. Among other things, this will cause
+ `FLAGS['flag_name'].parsed == True`.
+
+ A note on unparsed input: For many flag types, the unparsed version will be
+ a single string. However for multi_x (multi_string, multi_integer, multi_enum)
+ the unparsed version will be a Sequence of strings.
+
+ Args:
+ *args: Tuples of FlagHolders and their unparsed value.
+ **kwargs: The keyword args are flag names, and the values are unparsed
+ values.
+
+ Returns:
+ _ParsingFlagOverrider that serves as a context manager or decorator. Will
+ save previous flag state and parse new flags, then on cleanup it will
+ restore the previous flag state.
+ """
+ return _construct_overrider(_ParsingFlagOverrider, *args, **kwargs)
+
+
+# NOTE: the order of these overload declarations matters. The type checker will
+# pick the first match which could be incorrect.
+@overload
+def _construct_overrider(
+ flag_overrider_cls: Type['_ParsingFlagOverrider'],
+ *args: Tuple[flags.FlagHolder, Union[str, Sequence[str]]],
+ **kwargs: Union[str, Sequence[str]]) -> '_ParsingFlagOverrider':
+ ...
+
+
+@overload
+def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'],
+ *args: Tuple[flags.FlagHolder, Any],
+ **kwargs: Any) -> '_FlagOverrider':
+ ...
+
+
+@overload
+def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'],
+ func: _CallableT) -> _CallableT:
+ ...
+
+
+def _construct_overrider(flag_overrider_cls, *args, **kwargs):
+ """Handles the args/kwargs returning an instance of flag_overrider_cls.
+
+ If flag_overrider_cls is _FlagOverrider then values should be native python
+ types matching the python types. Otherwise if flag_overrider_cls is
+ _ParsingFlagOverrider the values should be strings or sequences of strings.
+
+ Args:
+ flag_overrider_cls: The class that will do the overriding.
+ *args: Tuples of FlagHolder and the new flag value.
+ **kwargs: Keword args mapping flag name to new flag value.
+
+ Returns:
+ A _FlagOverrider to be used as a decorator or context manager.
+ """
if not args:
- return _FlagOverrider(**kwargs)
+ return flag_overrider_cls(**kwargs)
# args can be [func] if used as `@flagsaver` instead of `@flagsaver(...)`
if len(args) == 1 and callable(args[0]):
if kwargs:
@@ -79,7 +204,7 @@ def flagsaver(*args, **kwargs):
func = args[0]
if inspect.isclass(func):
raise TypeError('@flagsaver.flagsaver cannot be applied to a class.')
- return _wrap(func, {})
+ return _wrap(flag_overrider_cls, func, {})
# args can be a list of (FlagHolder, value) pairs.
# In which case they augment any specified kwargs.
for arg in args:
@@ -91,15 +216,17 @@ def flagsaver(*args, **kwargs):
if holder.name in kwargs:
raise ValueError('Cannot set --%s multiple times' % holder.name)
kwargs[holder.name] = value
- return _FlagOverrider(**kwargs)
+ return flag_overrider_cls(**kwargs)
-def save_flag_values(flag_values=FLAGS):
+def save_flag_values(
+ flag_values: flags.FlagValues = FLAGS) -> Mapping[str, Mapping[str, Any]]:
"""Returns copy of flag values as a dict.
Args:
- flag_values: FlagValues, the FlagValues instance with which the flag will
- be saved. This should almost never need to be overridden.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ saved. This should almost never need to be overridden.
+
Returns:
Dictionary mapping keys to values. Keys are flag names, values are
corresponding ``__dict__`` members. E.g. ``{'key': value_dict, ...}``.
@@ -107,13 +234,14 @@ def save_flag_values(flag_values=FLAGS):
return {name: _copy_flag_dict(flag_values[name]) for name in flag_values}
-def restore_flag_values(saved_flag_values, flag_values=FLAGS):
+def restore_flag_values(saved_flag_values: Mapping[str, Mapping[str, Any]],
+ flag_values: flags.FlagValues = FLAGS):
"""Restores flag values based on the dictionary of flag values.
Args:
saved_flag_values: {'flag_name': value_dict, ...}
- flag_values: FlagValues, the FlagValues instance from which the flag will
- be restored. This should almost never need to be overridden.
+ flag_values: FlagValues, the FlagValues instance from which the flag will be
+ restored. This should almost never need to be overridden.
"""
new_flag_names = list(flag_values)
for name in new_flag_names:
@@ -127,23 +255,38 @@ def restore_flag_values(saved_flag_values, flag_values=FLAGS):
flag_values[name].__dict__ = saved
-def _wrap(func, overrides):
+@overload
+def _wrap(flag_overrider_cls: Type['_FlagOverrider'], func: _CallableT,
+ overrides: Mapping[str, Any]) -> _CallableT:
+ ...
+
+
+@overload
+def _wrap(flag_overrider_cls: Type['_ParsingFlagOverrider'], func: _CallableT,
+ overrides: Mapping[str, Union[str, Sequence[str]]]) -> _CallableT:
+ ...
+
+
+def _wrap(flag_overrider_cls, func, overrides):
"""Creates a wrapper function that saves/restores flag values.
Args:
- func: function object - This will be called between saving flags and
- restoring flags.
- overrides: {str: object} - Flag names mapped to their values. These flags
- will be set after saving the original flag state.
+ flag_overrider_cls: The class that will be used as a context manager.
+ func: This will be called between saving flags and restoring flags.
+ overrides: Flag names mapped to their values. These flags will be set after
+ saving the original flag state. The type of the values depends on if
+ _FlagOverrider or _ParsingFlagOverrider was specified.
Returns:
- return value from func()
+ A wrapped version of func.
"""
+
@functools.wraps(func)
def _flagsaver_wrapper(*args, **kwargs):
"""Wrapper function that saves and restores flags."""
- with _FlagOverrider(**overrides):
+ with flag_overrider_cls(**overrides):
return func(*args, **kwargs)
+
return _flagsaver_wrapper
@@ -154,14 +297,14 @@ class _FlagOverrider(object):
completes.
"""
- def __init__(self, **overrides):
+ def __init__(self, **overrides: Any):
self._overrides = overrides
self._saved_flag_values = None
- def __call__(self, func):
+ def __call__(self, func: _CallableT) -> _CallableT:
if inspect.isclass(func):
raise TypeError('flagsaver cannot be applied to a class.')
- return _wrap(func, self._overrides)
+ return _wrap(self.__class__, func, self._overrides)
def __enter__(self):
self._saved_flag_values = save_flag_values(FLAGS)
@@ -176,7 +319,56 @@ class _FlagOverrider(object):
restore_flag_values(self._saved_flag_values, FLAGS)
-def _copy_flag_dict(flag):
+class _ParsingFlagOverrider(_FlagOverrider):
+ """Context manager for overriding flags.
+
+ Simulates command line parsing.
+
+ This is simlar to _FlagOverrider except that all **overrides should be
+ strings or sequences of strings, and when context is entered this class calls
+ .parse(value)
+
+ This results in the flags having .present set properly.
+ """
+
+ def __init__(self, **overrides: Union[str, Sequence[str]]):
+ for flag_name, new_value in overrides.items():
+ if isinstance(new_value, str):
+ continue
+ if (isinstance(new_value, collections.abc.Sequence) and
+ all(isinstance(single_value, str) for single_value in new_value)):
+ continue
+ raise TypeError(
+ f'flagsaver.as_parsed() cannot parse {flag_name}. Expected a single '
+ f'string or sequence of strings but {type(new_value)} was provided.')
+ super().__init__(**overrides)
+
+ def __enter__(self):
+ self._saved_flag_values = save_flag_values(FLAGS)
+ try:
+ for flag_name, unparsed_value in self._overrides.items():
+ # LINT.IfChange(flag_override_parsing)
+ FLAGS[flag_name].parse(unparsed_value)
+ FLAGS[flag_name].using_default_value = False
+ # LINT.ThenChange()
+
+ # Perform the validation on all modified flags. This is something that
+ # FLAGS._set_attributes() does for you in _FlagOverrider.
+ for flag_name in self._overrides:
+ FLAGS._assert_validators(FLAGS[flag_name].validators)
+
+ except KeyError as e:
+ # If a flag doesn't exist, an UnrecognizedFlagError is more specific.
+ restore_flag_values(self._saved_flag_values, FLAGS)
+ raise flags.UnrecognizedFlagError('Unknown command line flag.') from e
+
+ except:
+ # It may fail because of flag validators or general parsing issues.
+ restore_flag_values(self._saved_flag_values, FLAGS)
+ raise
+
+
+def _copy_flag_dict(flag: flags.Flag) -> Mapping[str, Any]:
"""Returns a copy of the flag object's ``__dict__``.
It's mostly a shallow copy of the ``__dict__``, except it also does a shallow