aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYilei "Dolee" Yang <yileiyang@google.com>2022-10-13 13:30:23 -0700
committerGitHub <noreply@github.com>2022-10-13 13:30:23 -0700
commit9ac99c1b5699c11e4759a9955b74f3d07bcf3a34 (patch)
tree0403d229272bf94e8091b57413ff9c72b00617f9
parenta0ae31683e6cf3667886c500327f292c893a1740 (diff)
parent042ca2096daea1dc09308a3e3e29e128ddd6e80c (diff)
downloadabsl-py-9ac99c1b5699c11e4759a9955b74f3d07bcf3a34.tar.gz
Merge pull request #201 from yilei/push_up_to_480399279upstream/v1.3.0
Push up to 480399279
-rw-r--r--CHANGELOG.md34
-rw-r--r--absl/flags/__init__.py5
-rw-r--r--absl/flags/_argument_parser.py8
-rw-r--r--absl/flags/_defines.py31
-rw-r--r--absl/flags/_defines.pyi5
-rw-r--r--absl/flags/_flag.py2
-rw-r--r--absl/flags/_flagvalues.py42
-rw-r--r--absl/flags/_helpers.py40
-rw-r--r--absl/flags/_validators.py51
-rw-r--r--absl/flags/_validators_classes.py2
-rw-r--r--absl/flags/tests/_helpers_test.py15
-rw-r--r--absl/flags/tests/_validators_test.py245
-rw-r--r--absl/flags/tests/flags_test.py99
-rw-r--r--absl/testing/absltest.py2
-rw-r--r--absl/testing/tests/xml_reporter_test.py10
-rw-r--r--absl/testing/xml_reporter.py6
-rw-r--r--setup.py2
17 files changed, 501 insertions, 98 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5464857..ae82a55 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,11 +6,45 @@ The format is based on [Keep a Changelog](https://keepachangelog.com).
## Unreleased
+Nothing notable unreleased.
+
+## 1.3.0 (2022-10-11)
+
+### Added
+
+* (flags) Added a new `absl.flags.set_default` function that updates the flag
+ default for a provided `FlagHolder`. This parallels the
+ `absl.flags.FlagValues.set_default` interface which takes a flag name.
+* (flags) The following functions now also accept `FlagHolder` instance(s) in
+ addition to flag name(s) as their first positional argument:
+ - `flags.register_validator`
+ - `flags.validator`
+ - `flags.register_multi_flags_validator`
+ - `flags.multi_flags_validator`
+ - `flags.mark_flag_as_required`
+ - `flags.mark_flags_as_required`
+ - `flags.mark_flags_as_mutual_exclusive`
+ - `flags.mark_bool_flags_as_mutual_exclusive`
+ - `flags.declare_key_flag`
+
### Changed
* (testing) Assertions `assertRaisesWithPredicateMatch` and
`assertRaisesWithLiteralMatch` now capture the raised `Exception` for
further analysis when used as a context manager.
+* (testing) TextAndXMLTestRunner now produces time duration values with
+ millisecond precision in XML test result output.
+* (flags) Keyword access to `flag_name` arguments in the following functions
+ is deprecated. This parameter will be renamed in a future 2.0.0 release.
+ - `flags.register_validator`
+ - `flags.validator`
+ - `flags.register_multi_flags_validator`
+ - `flags.multi_flags_validator`
+ - `flags.mark_flag_as_required`
+ - `flags.mark_flags_as_required`
+ - `flags.mark_flags_as_mutual_exclusive`
+ - `flags.mark_bool_flags_as_mutual_exclusive`
+ - `flags.declare_key_flag`
## 1.2.0 (2022-07-18)
diff --git a/absl/flags/__init__.py b/absl/flags/__init__.py
index 45e64f3..6d8ba03 100644
--- a/absl/flags/__init__.py
+++ b/absl/flags/__init__.py
@@ -68,6 +68,8 @@ __all__ = (
'mark_flags_as_required',
'mark_flags_as_mutual_exclusive',
'mark_bool_flags_as_mutual_exclusive',
+ # Flag modifiers.
+ 'set_default',
# Key flag related functions.
'declare_key_flag',
'adopt_module_key_flags',
@@ -152,6 +154,9 @@ mark_flags_as_required = _validators.mark_flags_as_required
mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive
mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive
+# Flag modifiers.
+set_default = _defines.set_default
+
# Key flag related functions.
declare_key_flag = _defines.declare_key_flag
adopt_module_key_flags = _defines.adopt_module_key_flags
diff --git a/absl/flags/_argument_parser.py b/absl/flags/_argument_parser.py
index 7a94c69..2c4de9b 100644
--- a/absl/flags/_argument_parser.py
+++ b/absl/flags/_argument_parser.py
@@ -147,7 +147,7 @@ class ArgumentSerializer(object):
def serialize(self, value):
"""Returns a serialized string of the value."""
- return _helpers.str_or_unicode(value)
+ return str(value)
class NumericParser(ArgumentParser):
@@ -454,7 +454,7 @@ class ListSerializer(ArgumentSerializer):
def serialize(self, value):
"""See base class."""
- return self.list_sep.join([_helpers.str_or_unicode(x) for x in value])
+ return self.list_sep.join([str(x) for x in value])
class EnumClassListSerializer(ListSerializer):
@@ -498,7 +498,7 @@ class CsvListSerializer(ArgumentSerializer):
# We need the returned value to be pure ascii or Unicodes so that
# when the xml help is generated they are usefully encodable.
- return _helpers.str_or_unicode(serialized_value)
+ return str(serialized_value)
class EnumClassSerializer(ArgumentSerializer):
@@ -514,7 +514,7 @@ class EnumClassSerializer(ArgumentSerializer):
def serialize(self, value):
"""Returns a serialized string of the Enum class value."""
- as_string = _helpers.str_or_unicode(value.name)
+ as_string = str(value.name)
return as_string.lower() if self._lowercase else as_string
diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py
index 12335e5..dce53ea 100644
--- a/absl/flags/_defines.py
+++ b/absl/flags/_defines.py
@@ -148,6 +148,23 @@ def DEFINE_flag( # pylint: disable=invalid-name
fv, flag, ensure_non_none_value=ensure_non_none_value)
+def set_default(flag_holder, value):
+ """Changes the default value of the provided flag object.
+
+ The flag's current value is also updated if the flag is currently using
+ the default value, i.e. not specified in the command line, and not set
+ by FLAGS.name = value.
+
+ Args:
+ flag_holder: FlagHolder, the flag to modify.
+ value: The new default value.
+
+ Raises:
+ IllegalFlagValueError: Raised when value is not valid.
+ """
+ flag_holder._flagvalues.set_default(flag_holder.name, value) # pylint: disable=protected-access
+
+
def _internal_declare_key_flags(flag_names,
flag_values=_flagvalues.FLAGS,
key_flag_values=None):
@@ -157,8 +174,7 @@ def _internal_declare_key_flags(flag_names,
adopt_module_key_flags instead.
Args:
- flag_names: [str], a list of strings that are names of already-registered
- Flag objects.
+ flag_names: [str], a list of names of already-registered Flag objects.
flag_values: :class:`FlagValues`, the FlagValues instance with which the
flags listed in flag_names have registered (the value of the flag_values
argument from the ``DEFINE_*`` calls that defined those flags). This
@@ -176,8 +192,7 @@ def _internal_declare_key_flags(flag_names,
module = _helpers.get_calling_module()
for flag_name in flag_names:
- flag = flag_values[flag_name]
- key_flag_values.register_key_flag_for_module(module, flag)
+ key_flag_values.register_key_flag_for_module(module, flag_values[flag_name])
def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS):
@@ -194,9 +209,10 @@ def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS):
flags.declare_key_flag('flag_1')
Args:
- flag_name: str, the name of an already declared flag. (Redeclaring flags as
- key, including flags implicitly key because they were declared in this
- module, is a no-op.)
+ flag_name: str | :class:`FlagHolder`, the name or holder of an already
+ declared flag. (Redeclaring flags as key, including flags implicitly key
+ because they were declared in this module, is a no-op.)
+ Positional-only parameter.
flag_values: :class:`FlagValues`, the FlagValues instance in which the
flag will be declared as a key flag. This should almost never need to be
overridden.
@@ -204,6 +220,7 @@ def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS):
Raises:
ValueError: Raised if flag_name not defined as a Python flag.
"""
+ flag_name, flag_values = _flagvalues.resolve_flag_ref(flag_name, flag_values)
if flag_name in _helpers.SPECIAL_FLAGS:
# Take care of the special flags, e.g., --flagfile, --undefok.
# These flags are defined in SPECIAL_FLAGS, and are treated
diff --git a/absl/flags/_defines.pyi b/absl/flags/_defines.pyi
index 0fbe921..9bc8067 100644
--- a/absl/flags/_defines.pyi
+++ b/absl/flags/_defines.pyi
@@ -650,8 +650,11 @@ def DEFINE_alias(
...
+def set_default(flag_holder: _flagvalues.FlagHolder[_T], value: _T) -> None:
+ ...
+
-def declare_key_flag(flag_name: Text,
+def declare_key_flag(flag_name: Union[Text, _flagvalues.FlagHolder],
flag_values: _flagvalues.FlagValues = ...) -> None:
...
diff --git a/absl/flags/_flag.py b/absl/flags/_flag.py
index 28d9219..124f137 100644
--- a/absl/flags/_flag.py
+++ b/absl/flags/_flag.py
@@ -153,7 +153,7 @@ class Flag(object):
return repr('true')
else:
return repr('false')
- return repr(_helpers.str_or_unicode(value))
+ return repr(str(value))
def parse(self, argument):
"""Parses string and sets flag value.
diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py
index c52990c..937dc6c 100644
--- a/absl/flags/_flagvalues.py
+++ b/absl/flags/_flagvalues.py
@@ -412,11 +412,7 @@ class FlagValues:
fl = self._flags()
if not isinstance(flag, _flag.Flag):
raise _exceptions.IllegalFlagValueError(flag)
- if str is bytes and isinstance(name, unicode):
- # When using Python 2 with unicode_literals, allow it but encode it
- # into the bytes type we require.
- name = name.encode('utf-8')
- if not isinstance(name, type('')):
+ if not isinstance(name, str):
raise _exceptions.Error('Flag name must be a string')
if not name:
raise _exceptions.Error('Flag name cannot be empty')
@@ -632,7 +628,7 @@ class FlagValues:
TypeError: Raised on passing wrong type of arguments.
ValueError: Raised on flag value parsing error.
"""
- if _helpers.is_bytes_or_string(argv):
+ if isinstance(argv, (str, bytes)):
raise TypeError(
'argv should be a tuple/list of strings, not bytes or string.')
if not argv:
@@ -1006,7 +1002,7 @@ class FlagValues:
def _is_flag_file_directive(self, flag_string):
"""Checks whether flag_string contain a --flagfile=<foo> directive."""
- if isinstance(flag_string, type('')):
+ if isinstance(flag_string, str):
if flag_string.startswith('--flagfile='):
return 1
elif flag_string == '--flagfile':
@@ -1388,3 +1384,35 @@ class FlagHolder(Generic[_T]):
def present(self):
"""Returns True if the flag was parsed from command-line flags."""
return bool(self._flagvalues[self._name].present)
+
+
+def resolve_flag_ref(flag_ref, flag_values):
+ """Helper to validate and resolve a flag reference argument."""
+ if isinstance(flag_ref, FlagHolder):
+ new_flag_values = flag_ref._flagvalues # pylint: disable=protected-access
+ if flag_values != FLAGS and flag_values != new_flag_values:
+ raise ValueError(
+ 'flag_values must not be customized when operating on a FlagHolder')
+ return flag_ref.name, new_flag_values
+ return flag_ref, flag_values
+
+
+def resolve_flag_refs(flag_refs, flag_values):
+ """Helper to validate and resolve flag reference list arguments."""
+ fv = None
+ names = []
+ for ref in flag_refs:
+ if isinstance(ref, FlagHolder):
+ newfv = ref._flagvalues # pylint: disable=protected-access
+ name = ref.name
+ else:
+ newfv = flag_values
+ name = ref
+ if fv and fv != newfv:
+ raise ValueError(
+ 'multiple FlagValues instances used in invocation. '
+ 'FlagHolders must be registered to the same FlagValues instance as '
+ 'do flag names, if provided.')
+ fv = newfv
+ names.append(name)
+ return names, fv
diff --git a/absl/flags/_helpers.py b/absl/flags/_helpers.py
index cb0cfb2..ea02f2d 100644
--- a/absl/flags/_helpers.py
+++ b/absl/flags/_helpers.py
@@ -32,8 +32,9 @@ except ImportError:
_DEFAULT_HELP_WIDTH = 80 # Default width of help output.
-_MIN_HELP_WIDTH = 40 # Minimal "sane" width of help output. We assume that any
- # value below 40 is unreasonable.
+# Minimal "sane" width of help output. We assume that any value below 40 is
+# unreasonable.
+_MIN_HELP_WIDTH = 40
# Define the allowed error rate in an input string to get suggestions.
#
@@ -125,32 +126,6 @@ def get_calling_module():
return get_calling_module_object_and_name().module_name
-def str_or_unicode(value):
- """Converts a value to a python string.
-
- Behavior of this function is intentionally different in Python2/3.
-
- In Python2, the given value is attempted to convert to a str (byte string).
- If it contains non-ASCII characters, it is converted to a unicode instead.
-
- In Python3, the given value is always converted to a str (unicode string).
-
- This behavior reflects the (bad) practice in Python2 to try to represent
- a string as str as long as it contains ASCII characters only.
-
- Args:
- value: An object to be converted to a string.
-
- Returns:
- A string representation of the given value. See the description above
- for its type.
- """
- try:
- return str(value)
- except UnicodeEncodeError:
- return unicode(value) # Python3 should never come here
-
-
def create_xml_dom_element(doc, name, value):
"""Returns an XML DOM element with name and text value.
@@ -164,7 +139,7 @@ def create_xml_dom_element(doc, name, value):
Returns:
An instance of minidom.Element.
"""
- s = str_or_unicode(value)
+ s = str(value)
if isinstance(value, bool):
# Display boolean values as the C++ flag library does: no caps.
s = s.lower()
@@ -424,10 +399,3 @@ def doc_to_help(doc):
doc = re.sub(r'(?<=\S)\n(?=\S)', ' ', doc, flags=re.M)
return doc
-
-
-def is_bytes_or_string(maybe_string):
- if str is bytes:
- return isinstance(maybe_string, basestring)
- else:
- return isinstance(maybe_string, (str, bytes))
diff --git a/absl/flags/_validators.py b/absl/flags/_validators.py
index c4e1139..2161284 100644
--- a/absl/flags/_validators.py
+++ b/absl/flags/_validators.py
@@ -51,7 +51,8 @@ def register_validator(flag_name,
change of the corresponding flag's value.
Args:
- flag_name: str, name of the flag to be checked.
+ flag_name: str | FlagHolder, name or holder of the flag to be checked.
+ Positional-only parameter.
checker: callable, a function to validate the flag.
* input - A single positional argument: The value of the corresponding
@@ -70,7 +71,10 @@ def register_validator(flag_name,
Raises:
AttributeError: Raised when flag_name is not registered as a valid flag
name.
+ ValueError: Raised when flag_values is non-default and does not match the
+ FlagValues of the provided FlagHolder instance.
"""
+ flag_name, flag_values = _flagvalues.resolve_flag_ref(flag_name, flag_values)
v = _validators_classes.SingleFlagValidator(flag_name, checker, message)
_add_validator(flag_values, v)
@@ -88,7 +92,8 @@ def validator(flag_name, message='Flag validation failed',
See :func:`register_validator` for the specification of checker function.
Args:
- flag_name: str, name of the flag to be checked.
+ flag_name: str | FlagHolder, name or holder of the flag to be checked.
+ Positional-only parameter.
message: str, error text to be shown to the user if checker returns False.
If checker raises flags.ValidationError, message from the raised
error will be shown.
@@ -119,7 +124,8 @@ def register_multi_flags_validator(flag_names,
change of the corresponding flag's value.
Args:
- flag_names: [str], a list of the flag names to be checked.
+ flag_names: [str | FlagHolder], a list of the flag names or holders to be
+ checked. Positional-only parameter.
multi_flags_checker: callable, a function to validate the flag.
* input - dict, with keys() being flag_names, and value for each key
@@ -136,7 +142,13 @@ def register_multi_flags_validator(flag_names,
Raises:
AttributeError: Raised when a flag is not registered as a valid flag name.
+ ValueError: Raised when multiple FlagValues are used in the same
+ invocation. This can occur when FlagHolders have different `_flagvalues`
+ or when str-type flag_names entries are present and the `flag_values`
+ argument does not match that of provided FlagHolder(s).
"""
+ flag_names, flag_values = _flagvalues.resolve_flag_refs(
+ flag_names, flag_values)
v = _validators_classes.MultiFlagsValidator(
flag_names, multi_flags_checker, message)
_add_validator(flag_values, v)
@@ -157,7 +169,8 @@ def multi_flags_validator(flag_names,
function.
Args:
- flag_names: [str], a list of the flag names to be checked.
+ flag_names: [str | FlagHolder], a list of the flag names or holders to be
+ checked. Positional-only parameter.
message: str, error text to be shown to the user if checker returns False.
If checker raises flags.ValidationError, message from the raised
error will be shown.
@@ -196,13 +209,17 @@ def mark_flag_as_required(flag_name, flag_values=_flagvalues.FLAGS):
app.run()
Args:
- flag_name: str, name of the flag
+ flag_name: str | FlagHolder, name or holder of the flag.
+ Positional-only parameter.
flag_values: flags.FlagValues, optional :class:`~absl.flags.FlagValues`
instance where the flag is defined.
Raises:
AttributeError: Raised when flag_name is not registered as a valid flag
name.
+ ValueError: Raised when flag_values is non-default and does not match the
+ FlagValues of the provided FlagHolder instance.
"""
+ flag_name, flag_values = _flagvalues.resolve_flag_ref(flag_name, flag_values)
if flag_values[flag_name].default is not None:
warnings.warn(
'Flag --%s has a non-None default value; therefore, '
@@ -227,7 +244,7 @@ def mark_flags_as_required(flag_names, flag_values=_flagvalues.FLAGS):
app.run()
Args:
- flag_names: Sequence[str], names of the flags.
+ flag_names: Sequence[str | FlagHolder], names or holders of the flags.
flag_values: flags.FlagValues, optional FlagValues instance where the flags
are defined.
Raises:
@@ -248,13 +265,22 @@ def mark_flags_as_mutual_exclusive(flag_names, required=False,
includes multi flags with a default value of ``[]`` instead of None.
Args:
- flag_names: [str], names of the flags.
+ flag_names: [str | FlagHolder], names or holders of flags.
+ Positional-only parameter.
required: bool. If true, exactly one of the flags must have a value other
than None. Otherwise, at most one of the flags can have a value other
than None, and it is valid for all of the flags to be None.
flag_values: flags.FlagValues, optional FlagValues instance where the flags
are defined.
+
+ Raises:
+ ValueError: Raised when multiple FlagValues are used in the same
+ invocation. This can occur when FlagHolders have different `_flagvalues`
+ or when str-type flag_names entries are present and the `flag_values`
+ argument does not match that of provided FlagHolder(s).
"""
+ flag_names, flag_values = _flagvalues.resolve_flag_refs(
+ flag_names, flag_values)
for flag_name in flag_names:
if flag_values[flag_name].default is not None:
warnings.warn(
@@ -280,12 +306,21 @@ def mark_bool_flags_as_mutual_exclusive(flag_names, required=False,
"""Ensures that only one flag among flag_names is True.
Args:
- flag_names: [str], names of the flags.
+ flag_names: [str | FlagHolder], names or holders of flags.
+ Positional-only parameter.
required: bool. If true, exactly one flag must be True. Otherwise, at most
one flag can be True, and it is valid for all flags to be False.
flag_values: flags.FlagValues, optional FlagValues instance where the flags
are defined.
+
+ Raises:
+ ValueError: Raised when multiple FlagValues are used in the same
+ invocation. This can occur when FlagHolders have different `_flagvalues`
+ or when str-type flag_names entries are present and the `flag_values`
+ argument does not match that of provided FlagHolder(s).
"""
+ flag_names, flag_values = _flagvalues.resolve_flag_refs(
+ flag_names, flag_values)
for flag_name in flag_names:
if not flag_values[flag_name].boolean:
raise _exceptions.ValidationError(
diff --git a/absl/flags/_validators_classes.py b/absl/flags/_validators_classes.py
index 2881499..59100c8 100644
--- a/absl/flags/_validators_classes.py
+++ b/absl/flags/_validators_classes.py
@@ -156,7 +156,7 @@ class MultiFlagsValidator(Validator):
Args:
flag_values: flags.FlagValues, the FlagValues instance to get flags from.
Returns:
- dict, with keys() being self.lag_names, and value for each key
+ dict, with keys() being self.flag_names, and value for each key
being the value of the corresponding flag (string, boolean, etc).
"""
return dict([key, flag_values[key].value] for key in self.flag_names)
diff --git a/absl/flags/tests/_helpers_test.py b/absl/flags/tests/_helpers_test.py
index 2697d1c..78b9051 100644
--- a/absl/flags/tests/_helpers_test.py
+++ b/absl/flags/tests/_helpers_test.py
@@ -150,20 +150,5 @@ class GetCallingModuleTest(absltest.TestCase):
sys.modules = orig_sys_modules
-class IsBytesOrString(absltest.TestCase):
-
- def test_bytes(self):
- self.assertTrue(_helpers.is_bytes_or_string(b'bytes'))
-
- def test_str(self):
- self.assertTrue(_helpers.is_bytes_or_string('str'))
-
- def test_unicode(self):
- self.assertTrue(_helpers.is_bytes_or_string(u'unicode'))
-
- def test_list(self):
- self.assertFalse(_helpers.is_bytes_or_string(['str']))
-
-
if __name__ == '__main__':
absltest.main()
diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py
index 1cccf53..9aa328e 100644
--- a/absl/flags/tests/_validators_test.py
+++ b/absl/flags/tests/_validators_test.py
@@ -55,6 +55,45 @@ class SingleFlagValidatorTest(absltest.TestCase):
self.assertEqual(2, self.flag_values.test_flag)
self.assertEqual([None, 2], self.call_args)
+ def test_success_holder(self):
+ def checker(x):
+ self.call_args.append(x)
+ return True
+
+ flag_holder = _defines.DEFINE_integer(
+ 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
+ _validators.register_validator(
+ flag_holder,
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
+ argv = ('./program',)
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.test_flag)
+ self.flag_values.test_flag = 2
+ self.assertEqual(2, self.flag_values.test_flag)
+ self.assertEqual([None, 2], self.call_args)
+
+ def test_success_holder_infer_flagvalues(self):
+ def checker(x):
+ self.call_args.append(x)
+ return True
+
+ flag_holder = _defines.DEFINE_integer(
+ 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
+ _validators.register_validator(
+ flag_holder,
+ checker,
+ message='Errors happen')
+
+ argv = ('./program',)
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.test_flag)
+ self.flag_values.test_flag = 2
+ self.assertEqual(2, self.flag_values.test_flag)
+ self.assertEqual([None, 2], self.call_args)
+
def test_default_value_not_used_success(self):
def checker(x):
self.call_args.append(x)
@@ -218,6 +257,26 @@ class SingleFlagValidatorTest(absltest.TestCase):
self.assertTrue(checker(3))
self.assertEqual([None, 2, 3], self.call_args)
+ def test_mismatching_flagvalues(self):
+
+ def checker(x):
+ self.call_args.append(x)
+ return True
+
+ flag_holder = _defines.DEFINE_integer(
+ 'test_flag',
+ None,
+ 'Usual integer flag',
+ flag_values=_flagvalues.FlagValues())
+ expected = (
+ 'flag_values must not be customized when operating on a FlagHolder')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected):
+ _validators.register_validator(
+ flag_holder,
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
class MultiFlagsValidatorTest(absltest.TestCase):
"""Test flags multi-flag validators."""
@@ -226,9 +285,9 @@ class MultiFlagsValidatorTest(absltest.TestCase):
super(MultiFlagsValidatorTest, self).setUp()
self.flag_values = _flagvalues.FlagValues()
self.call_args = []
- _defines.DEFINE_integer(
+ self.foo_holder = _defines.DEFINE_integer(
'foo', 1, 'Usual integer flag', flag_values=self.flag_values)
- _defines.DEFINE_integer(
+ self.bar_holder = _defines.DEFINE_integer(
'bar', 2, 'Usual integer flag', flag_values=self.flag_values)
def test_success(self):
@@ -248,6 +307,55 @@ class MultiFlagsValidatorTest(absltest.TestCase):
self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 2}],
self.call_args)
+ def test_success_holder(self):
+
+ def checker(flags_dict):
+ self.call_args.append(flags_dict)
+ return True
+
+ _validators.register_multi_flags_validator(
+ [self.foo_holder, self.bar_holder],
+ checker,
+ flag_values=self.flag_values)
+
+ argv = ('./program', '--bar=2')
+ self.flag_values(argv)
+ self.assertEqual(1, self.flag_values.foo)
+ self.assertEqual(2, self.flag_values.bar)
+ self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)
+ self.flag_values.foo = 3
+ self.assertEqual(3, self.flag_values.foo)
+ self.assertEqual([{
+ 'foo': 1,
+ 'bar': 2
+ }, {
+ 'foo': 3,
+ 'bar': 2
+ }], self.call_args)
+
+ def test_success_holder_infer_flagvalues(self):
+ def checker(flags_dict):
+ self.call_args.append(flags_dict)
+ return True
+
+ _validators.register_multi_flags_validator(
+ [self.foo_holder, self.bar_holder], checker)
+
+ argv = ('./program', '--bar=2')
+ self.flag_values(argv)
+ self.assertEqual(1, self.flag_values.foo)
+ self.assertEqual(2, self.flag_values.bar)
+ self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)
+ self.flag_values.foo = 3
+ self.assertEqual(3, self.flag_values.foo)
+ self.assertEqual([{
+ 'foo': 1,
+ 'bar': 2
+ }, {
+ 'foo': 3,
+ 'bar': 2
+ }], self.call_args)
+
def test_validator_not_called_when_other_flag_is_changed(self):
def checker(flags_dict):
self.call_args.append(flags_dict)
@@ -322,6 +430,30 @@ class MultiFlagsValidatorTest(absltest.TestCase):
self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
self.call_args)
+ def test_mismatching_flagvalues(self):
+
+ def checker(flags_dict):
+ self.call_args.append(flags_dict)
+ values = flags_dict.values()
+ # Make sure all the flags have different values.
+ return len(set(values)) == len(values)
+
+ other_holder = _defines.DEFINE_integer(
+ 'other_flag',
+ 3,
+ 'Other integer flag',
+ flag_values=_flagvalues.FlagValues())
+ expected = (
+ 'multiple FlagValues instances used in invocation. '
+ 'FlagHolders must be registered to the same FlagValues instance as '
+ 'do flag names, if provided.')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected):
+ _validators.register_multi_flags_validator(
+ [self.foo_holder, self.bar_holder, other_holder],
+ checker,
+ message='Errors happen',
+ flag_values=self.flag_values)
+
class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
@@ -329,9 +461,9 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
super(MarkFlagsAsMutualExclusiveTest, self).setUp()
self.flag_values = _flagvalues.FlagValues()
- _defines.DEFINE_string(
+ self.flag_one_holder = _defines.DEFINE_string(
'flag_one', None, 'flag one', flag_values=self.flag_values)
- _defines.DEFINE_string(
+ self.flag_two_holder = _defines.DEFINE_string(
'flag_two', None, 'flag two', flag_values=self.flag_values)
_defines.DEFINE_string(
'flag_three', None, 'flag three', flag_values=self.flag_values)
@@ -358,6 +490,24 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
self.assertIsNone(self.flag_values.flag_one)
self.assertIsNone(self.flag_values.flag_two)
+ def test_no_flags_present_holder(self):
+ self._mark_flags_as_mutually_exclusive(
+ [self.flag_one_holder, self.flag_two_holder], False)
+ argv = ('./program',)
+
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.flag_one)
+ self.assertIsNone(self.flag_values.flag_two)
+
+ def test_no_flags_present_mixed(self):
+ self._mark_flags_as_mutually_exclusive([self.flag_one_holder, 'flag_two'],
+ False)
+ argv = ('./program',)
+
+ self.flag_values(argv)
+ self.assertIsNone(self.flag_values.flag_one)
+ self.assertIsNone(self.flag_values.flag_two)
+
def test_no_flags_present_required(self):
self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
argv = ('./program',)
@@ -494,6 +644,20 @@ class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
self.assertIn('--flag_not_none has a non-None default value',
str(caught_warnings[0].message))
+ def test_multiple_flagvalues(self):
+ other_holder = _defines.DEFINE_boolean(
+ 'other_flagvalues',
+ False,
+ 'other ',
+ flag_values=_flagvalues.FlagValues())
+ expected = (
+ 'multiple FlagValues instances used in invocation. '
+ 'FlagHolders must be registered to the same FlagValues instance as '
+ 'do flag names, if provided.')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected):
+ self._mark_flags_as_mutually_exclusive(
+ [self.flag_one_holder, other_holder], False)
+
class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
@@ -501,13 +665,13 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
super(MarkBoolFlagsAsMutualExclusiveTest, self).setUp()
self.flag_values = _flagvalues.FlagValues()
- _defines.DEFINE_boolean(
+ self.false_1_holder = _defines.DEFINE_boolean(
'false_1', False, 'default false 1', flag_values=self.flag_values)
- _defines.DEFINE_boolean(
+ self.false_2_holder = _defines.DEFINE_boolean(
'false_2', False, 'default false 2', flag_values=self.flag_values)
- _defines.DEFINE_boolean(
+ self.true_1_holder = _defines.DEFINE_boolean(
'true_1', True, 'default true 1', flag_values=self.flag_values)
- _defines.DEFINE_integer(
+ self.non_bool_holder = _defines.DEFINE_integer(
'non_bool', None, 'non bool', flag_values=self.flag_values)
def _mark_bool_flags_as_mutually_exclusive(self, flag_names, required):
@@ -520,6 +684,20 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
self.assertEqual(False, self.flag_values.false_1)
self.assertEqual(False, self.flag_values.false_2)
+ def test_no_flags_present_holder(self):
+ self._mark_bool_flags_as_mutually_exclusive(
+ [self.false_1_holder, self.false_2_holder], False)
+ self.flag_values(('./program',))
+ self.assertEqual(False, self.flag_values.false_1)
+ self.assertEqual(False, self.flag_values.false_2)
+
+ def test_no_flags_present_mixed(self):
+ self._mark_bool_flags_as_mutually_exclusive(
+ [self.false_1_holder, 'false_2'], False)
+ self.flag_values(('./program',))
+ self.assertEqual(False, self.flag_values.false_1)
+ self.assertEqual(False, self.flag_values.false_2)
+
def test_no_flags_present_required(self):
self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], True)
argv = ('./program',)
@@ -554,6 +732,17 @@ class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
self._mark_bool_flags_as_mutually_exclusive(['false_1', 'non_bool'],
False)
+ def test_multiple_flagvalues(self):
+ other_bool_holder = _defines.DEFINE_boolean(
+ 'other_bool', False, 'other bool', flag_values=_flagvalues.FlagValues())
+ expected = (
+ 'multiple FlagValues instances used in invocation. '
+ 'FlagHolders must be registered to the same FlagValues instance as '
+ 'do flag names, if provided.')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected):
+ self._mark_bool_flags_as_mutually_exclusive(
+ [self.false_1_holder, other_bool_holder], False)
+
class MarkFlagAsRequiredTest(absltest.TestCase):
@@ -570,6 +759,22 @@ class MarkFlagAsRequiredTest(absltest.TestCase):
self.flag_values(argv)
self.assertEqual('value', self.flag_values.string_flag)
+ def test_success_holder(self):
+ holder = _defines.DEFINE_string(
+ 'string_flag', None, 'string flag', flag_values=self.flag_values)
+ _validators.mark_flag_as_required(holder, flag_values=self.flag_values)
+ argv = ('./program', '--string_flag=value')
+ self.flag_values(argv)
+ self.assertEqual('value', self.flag_values.string_flag)
+
+ def test_success_holder_infer_flagvalues(self):
+ holder = _defines.DEFINE_string(
+ 'string_flag', None, 'string flag', flag_values=self.flag_values)
+ _validators.mark_flag_as_required(holder)
+ argv = ('./program', '--string_flag=value')
+ self.flag_values(argv)
+ self.assertEqual('value', self.flag_values.string_flag)
+
def test_catch_none_as_default(self):
_defines.DEFINE_string(
'string_flag', None, 'string flag', flag_values=self.flag_values)
@@ -608,6 +813,18 @@ class MarkFlagAsRequiredTest(absltest.TestCase):
self.assertIn('--flag_not_none has a non-None default value',
str(caught_warnings[0].message))
+ def test_mismatching_flagvalues(self):
+ flag_holder = _defines.DEFINE_string(
+ 'string_flag',
+ 'value',
+ 'string flag',
+ flag_values=_flagvalues.FlagValues())
+ expected = (
+ 'flag_values must not be customized when operating on a FlagHolder')
+ with self.assertRaisesWithLiteralMatch(ValueError, expected):
+ _validators.mark_flag_as_required(
+ flag_holder, flag_values=self.flag_values)
+
class MarkFlagsAsRequiredTest(absltest.TestCase):
@@ -627,6 +844,18 @@ class MarkFlagsAsRequiredTest(absltest.TestCase):
self.assertEqual('value_1', self.flag_values.string_flag_1)
self.assertEqual('value_2', self.flag_values.string_flag_2)
+ def test_success_holders(self):
+ flag_1_holder = _defines.DEFINE_string(
+ 'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
+ flag_2_holder = _defines.DEFINE_string(
+ 'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
+ _validators.mark_flags_as_required([flag_1_holder, flag_2_holder],
+ flag_values=self.flag_values)
+ argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2')
+ self.flag_values(argv)
+ self.assertEqual('value_1', self.flag_values.string_flag_1)
+ self.assertEqual('value_2', self.flag_values.string_flag_2)
+
def test_catch_none_as_default(self):
_defines.DEFINE_string(
'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py
index 8a42bc9..77ed307 100644
--- a/absl/flags/tests/flags_test.py
+++ b/absl/flags/tests/flags_test.py
@@ -2483,6 +2483,71 @@ class NonGlobalFlagsTest(absltest.TestCase):
flag_values['flag_name'] = 'flag_value'
+class SetDefaultTest(absltest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.flag_values = flags.FlagValues()
+
+ def test_success(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ flags.set_default(int_holder, 2)
+ self.flag_values.mark_as_parsed()
+
+ self.assertEqual(int_holder.value, 2)
+
+ def test_update_after_parse(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+ flags.set_default(int_holder, 2)
+
+ self.assertEqual(int_holder.value, 2)
+
+ def test_overridden_by_explicit_assignment(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+ self.flag_values.an_int = 3
+ flags.set_default(int_holder, 2)
+
+ self.assertEqual(int_holder.value, 3)
+
+ def test_restores_back_to_none(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', None, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+ flags.set_default(int_holder, 3)
+ flags.set_default(int_holder, None)
+
+ self.assertIsNone(int_holder.value)
+
+ def test_failure_on_invalid_type(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+
+ with self.assertRaises(flags.IllegalFlagValueError):
+ flags.set_default(int_holder, 'a')
+
+ def test_failure_on_type_protected_none_default(self):
+ int_holder = flags.DEFINE_integer(
+ 'an_int', 1, 'an int', flag_values=self.flag_values)
+
+ self.flag_values.mark_as_parsed()
+
+ flags.set_default(int_holder, None) # NOTE: should be a type failure
+
+ with self.assertRaises(flags.IllegalFlagValueError):
+ _ = int_holder.value # Will also fail on later access.
+
+
class KeyFlagsTest(absltest.TestCase):
def setUp(self):
@@ -2646,6 +2711,40 @@ class KeyFlagsTest(absltest.TestCase):
self._get_names_of_key_flags(main_module, fv),
names_of_flags_defined_by_bar + ['flagfile', 'undefok'])
+ def test_key_flags_with_flagholders(self):
+ main_module = sys.argv[0]
+
+ self.assertListEqual(
+ self._get_names_of_key_flags(main_module, self.flag_values), [])
+ self.assertListEqual(
+ self._get_names_of_defined_flags(main_module, self.flag_values), [])
+
+ int_holder = flags.DEFINE_integer(
+ 'main_module_int_fg',
+ 1,
+ 'Integer flag in the main module.',
+ flag_values=self.flag_values)
+
+ flags.declare_key_flag(int_holder, self.flag_values)
+
+ self.assertCountEqual(
+ self.flag_values.get_flags_for_module(main_module),
+ self.flag_values.get_key_flags_for_module(main_module))
+
+ bool_holder = flags.DEFINE_boolean(
+ 'main_module_bool_fg',
+ False,
+ 'Boolean flag in the main module.',
+ flag_values=self.flag_values)
+
+ flags.declare_key_flag(bool_holder) # omitted flag_values
+
+ self.assertCountEqual(
+ self.flag_values.get_flags_for_module(main_module),
+ self.flag_values.get_key_flags_for_module(main_module))
+
+ self.assertLen(self.flag_values.get_flags_for_module(main_module), 2)
+
def test_main_module_help_with_key_flags(self):
# Similar to test_main_module_help, but this time we make sure to
# declare some key flags.
diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py
index d1b0e59..9071f8f 100644
--- a/absl/testing/absltest.py
+++ b/absl/testing/absltest.py
@@ -735,7 +735,7 @@ class TestCase(unittest.TestCase):
(e.g. `TestCase.enter_context`), the context is exited after the test
class's tearDownClass call.
- Contexts are are exited in the reverse order of entering. They will always
+ Contexts are exited in the reverse order of entering. They will always
be exited, regardless of test failure/success.
This is useful to eliminate per-test boilerplate when context managers
diff --git a/absl/testing/tests/xml_reporter_test.py b/absl/testing/tests/xml_reporter_test.py
index 0261f64..c0d43a6 100644
--- a/absl/testing/tests/xml_reporter_test.py
+++ b/absl/testing/tests/xml_reporter_test.py
@@ -64,12 +64,12 @@ def xml_escaped_exception_type(exception_type):
OUTPUT_STRING = '\n'.join([
r'<\?xml version="1.0"\?>',
('<testsuites name="" tests="%(tests)d" failures="%(failures)d"'
- ' errors="%(errors)d" time="%(run_time).1f" timestamp="%(start_time)s">'),
+ ' errors="%(errors)d" time="%(run_time).3f" timestamp="%(start_time)s">'),
('<testsuite name="%(suite_name)s" tests="%(tests)d"'
- ' failures="%(failures)d" errors="%(errors)d" time="%(run_time).1f"'
+ ' failures="%(failures)d" errors="%(errors)d" time="%(run_time).3f"'
' timestamp="%(start_time)s">'),
(' <testcase name="%(test_name)s" status="%(status)s" result="%(result)s"'
- ' time="%(run_time).1f" classname="%(classname)s"'
+ ' time="%(run_time).3f" classname="%(classname)s"'
' timestamp="%(start_time)s">%(message)s'),
' </testcase>', '</testsuite>',
'</testsuites>',
@@ -696,8 +696,8 @@ class TextAndXMLTestResultTest(absltest.TestCase):
run_time = max(end_time1, end_time2) - min(start_time1, start_time2)
timestamp = self._iso_timestamp(start_time1)
expected_prefix = """<?xml version="1.0"?>
-<testsuites name="" tests="2" failures="0" errors="0" time="%.1f" timestamp="%s">
-<testsuite name="MockTest" tests="2" failures="0" errors="0" time="%.1f" timestamp="%s">
+<testsuites name="" tests="2" failures="0" errors="0" time="%.3f" timestamp="%s">
+<testsuite name="MockTest" tests="2" failures="0" errors="0" time="%.3f" timestamp="%s">
""" % (run_time, timestamp, run_time, timestamp)
xml_output = self.xml_stream.getvalue()
self.assertTrue(
diff --git a/absl/testing/xml_reporter.py b/absl/testing/xml_reporter.py
index 5996ce2..591eb7e 100644
--- a/absl/testing/xml_reporter.py
+++ b/absl/testing/xml_reporter.py
@@ -202,7 +202,7 @@ class _TestCaseResult(object):
('name', '%s' % self.name),
('status', '%s' % status),
('result', '%s' % result),
- ('time', '%.1f' % self.run_time),
+ ('time', '%.3f' % self.run_time),
('classname', self.full_class_name),
('timestamp', _iso8601_timestamp(self.start_time)),
]
@@ -263,7 +263,7 @@ class _TestSuiteResult(object):
('tests', '%d' % overall_test_count),
('failures', '%d' % overall_failures),
('errors', '%d' % overall_errors),
- ('time', '%.1f' % (self.overall_end_time - self.overall_start_time)),
+ ('time', '%.3f' % (self.overall_end_time - self.overall_start_time)),
('timestamp', _iso8601_timestamp(self.overall_start_time)),
]
_print_xml_element_header('testsuites', overall_attributes, stream)
@@ -285,7 +285,7 @@ class _TestSuiteResult(object):
('tests', '%d' % len(suite)),
('failures', '%d' % failures),
('errors', '%d' % errors),
- ('time', '%.1f' % (suite_end_time - suite_start_time)),
+ ('time', '%.3f' % (suite_end_time - suite_start_time)),
('timestamp', _iso8601_timestamp(suite_start_time)),
]
_print_xml_element_header('testsuite', suite_attributes, stream)
diff --git a/setup.py b/setup.py
index 23fcac2..f947fd7 100644
--- a/setup.py
+++ b/setup.py
@@ -43,7 +43,7 @@ with open(_README_PATH, 'rb') as fp:
setuptools.setup(
name='absl-py',
- version='1.2.0',
+ version='1.3.0',
description=(
'Abseil Python Common Libraries, '
'see https://github.com/abseil/abseil-py.'),