aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYilei Yang <yileiyang@google.com>2022-12-12 22:08:08 -0800
committerYilei Yang <yileiyang@google.com>2022-12-12 22:08:08 -0800
commit8da67e52a482c58296a89ad1960b7a22941eeff8 (patch)
treeb7885bd93479413f154d972f43b134b47983132d
parent83adb264544454b5ab6ade0c58d0d701761d0eb9 (diff)
parent814e1f373cd83041cf34b6916586a5ed1c1253ce (diff)
downloadabsl-py-8da67e52a482c58296a89ad1960b7a22941eeff8.tar.gz
Merge commit for internal changes.
-rw-r--r--CHANGELOG.md8
-rw-r--r--absl/flags/__init__.pyi3
-rw-r--r--absl/flags/_defines.py10
-rw-r--r--absl/flags/_flag.pyi3
-rw-r--r--absl/flags/_flagvalues.py4
-rw-r--r--absl/flags/tests/flags_test.py11
-rw-r--r--absl/logging/__init__.py21
-rw-r--r--absl/testing/absltest.py7
-rw-r--r--absl/testing/flagsaver.py50
9 files changed, 86 insertions, 31 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1405a14..56f832b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,7 +8,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com).
### Changed
-* If no log dir is specified `logging.find_log_dir()` now falls back to `tempfile.gettempdir()` instead of `/tmp/`.
+* (logging) If no log dir is specified `logging.find_log_dir()` now falls back
+ to `tempfile.gettempdir()` instead of `/tmp/`.
+
+### Fixed
+
+* (flags) Additional kwargs (e.g. `short_name=`) to `DEFINE_multi_enum_class`
+ are now correctly passed to the underlying `Flag` object.
## 1.3.0 (2022-10-11)
diff --git a/absl/flags/__init__.pyi b/absl/flags/__init__.pyi
index 4eee59e..7bf6842 100644
--- a/absl/flags/__init__.pyi
+++ b/absl/flags/__init__.pyi
@@ -52,6 +52,9 @@ mark_flags_as_required = _validators.mark_flags_as_required
mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive
mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive
+# Flag modifiers.
+set_default = _defines.set_default
+
# Key flag related functions.
declare_key_flag = _defines.declare_key_flag
adopt_module_key_flags = _defines.adopt_module_key_flags
diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py
index dce53ea..61354e9 100644
--- a/absl/flags/_defines.py
+++ b/absl/flags/_defines.py
@@ -859,11 +859,17 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin
"""
return DEFINE_flag(
_flag.MultiEnumClassFlag(
- name, default, help, enum_class, case_sensitive=case_sensitive),
+ name,
+ default,
+ help,
+ enum_class,
+ case_sensitive=case_sensitive,
+ **args,
+ ),
flag_values,
module_name,
required=required,
- **args)
+ )
def DEFINE_alias( # pylint: disable=invalid-name
diff --git a/absl/flags/_flag.pyi b/absl/flags/_flag.pyi
index 9b4a3d3..3506644 100644
--- a/absl/flags/_flag.pyi
+++ b/absl/flags/_flag.pyi
@@ -20,7 +20,7 @@ import functools
from absl.flags import _argument_parser
import enum
-from typing import Text, TypeVar, Generic, Iterable, Type, List, Optional, Any, Union, Sequence
+from typing import Callable, Text, TypeVar, Generic, Iterable, Type, List, Optional, Any, Union, Sequence
_T = TypeVar('_T')
_ET = TypeVar('_ET', bound=enum.Enum)
@@ -44,6 +44,7 @@ class Flag(Generic[_T]):
using_default_value = ... # type: bool
allow_overwrite = ... # type: bool
allow_using_method_names = ... # type: bool
+ validators = ... # type: List[Callable[[Any], bool]]
def __init__(self,
parser: _argument_parser.ArgumentParser[_T],
diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py
index 937dc6c..6661b78 100644
--- a/absl/flags/_flagvalues.py
+++ b/absl/flags/_flagvalues.py
@@ -411,7 +411,9 @@ class FlagValues:
"""Registers a new flag variable."""
fl = self._flags()
if not isinstance(flag, _flag.Flag):
- raise _exceptions.IllegalFlagValueError(flag)
+ raise _exceptions.IllegalFlagValueError(
+ f'Expect Flag instances, found type {type(flag)}. '
+ "Maybe you didn't mean to use FlagValue.__setitem__?")
if not isinstance(name, str):
raise _exceptions.Error('Flag name must be a string')
if not name:
diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py
index 77ed307..7cacbc8 100644
--- a/absl/flags/tests/flags_test.py
+++ b/absl/flags/tests/flags_test.py
@@ -1591,6 +1591,17 @@ class MultiEnumFlagsTest(absltest.TestCase):
class MultiEnumClassFlagsTest(absltest.TestCase):
+ def test_short_name(self):
+ fv = flags.FlagValues()
+ flags.DEFINE_multi_enum_class(
+ 'fruit',
+ None,
+ Fruit,
+ 'Enum option that can occur multiple times',
+ flag_values=fv,
+ short_name='me')
+ self.assertEqual(fv['fruit'].short_name, 'me')
+
def test_define_results_in_registered_flag_with_none(self):
fv = flags.FlagValues()
enum_defaults = None
diff --git a/absl/logging/__init__.py b/absl/logging/__init__.py
index 33276cd..f4e7967 100644
--- a/absl/logging/__init__.py
+++ b/absl/logging/__init__.py
@@ -86,6 +86,7 @@ import os
import socket
import struct
import sys
+import tempfile
import threading
import tempfile
import time
@@ -707,22 +708,26 @@ def find_log_dir(log_dir=None):
FileNotFoundError: raised in Python 3 when it cannot find a log directory.
OSError: raised in Python 2 when it cannot find a log directory.
"""
- # Get a possible log dir.
+ # Get a list of possible log dirs (will try to use them in order).
+ # NOTE: Google's internal implementation has a special handling for Google
+ # machines, which uses a list of directories. Hence the following uses `dirs`
+ # instead of a single directory.
if log_dir:
# log_dir was explicitly specified as an arg, so use it and it alone.
- log_dir_candidate = log_dir
+ dirs = [log_dir]
elif FLAGS['log_dir'].value:
# log_dir flag was provided, so use it and it alone (this mimics the
# behavior of the same flag in logging.cc).
- log_dir_candidate = FLAGS['log_dir'].value
+ dirs = [FLAGS['log_dir'].value]
else:
- log_dir_candidate = tempfile.gettempdir()
+ dirs = [tempfile.gettempdir()]
- # Test if log dir candidate is usable.
- if os.path.isdir(log_dir_candidate) and os.access(log_dir_candidate, os.W_OK):
- return log_dir_candidate
+ # Find the first usable log dir.
+ for d in dirs:
+ if os.path.isdir(d) and os.access(d, os.W_OK):
+ return d
raise FileNotFoundError(
- "Can't find a writable directory for logs, tried %s" % log_dir_candidate)
+ "Can't find a writable directory for logs, tried %s" % dirs)
def get_absl_log_prefix(record):
diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py
index 9071f8f..1bbcee7 100644
--- a/absl/testing/absltest.py
+++ b/absl/testing/absltest.py
@@ -533,7 +533,10 @@ class _TempFile(object):
# currently `Any` to avoid [bad-return-type] errors in the open_* methods.
@contextlib.contextmanager
def _open(
- self, mode: str, encoding: str = 'utf8', errors: str = 'strict'
+ self,
+ mode: str,
+ encoding: Optional[str] = 'utf8',
+ errors: Optional[str] = 'strict',
) -> Iterator[Any]:
with io.open(
self.full_path, mode=mode, encoding=encoding, errors=errors) as fp:
@@ -638,7 +641,7 @@ class TestCase(unittest.TestCase):
self.assertTrue(os.path.exists(expected_paths[1]))
self.assertEqual('foo', out_log.read_text())
- See also: :meth:`create_tempdir` for creating temporary files.
+ See also: :meth:`create_tempfile` for creating temporary files.
Args:
name: Optional name of the directory. If not given, a unique
diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py
index 37926d7..774c698 100644
--- a/absl/testing/flagsaver.py
+++ b/absl/testing/flagsaver.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Decorator and context manager for saving and restoring flag values.
There are many ways to save and restore. Always use the most convenient method
@@ -61,11 +60,26 @@ and then restore flag values, the added flag will be deleted with no errors.
import functools
import inspect
+from typing import overload, Any, Callable, Mapping, Tuple, TypeVar
from absl import flags
FLAGS = flags.FLAGS
+# The type of pre/post wrapped functions.
+_CallableT = TypeVar('_CallableT', bound=Callable)
+
+
+@overload
+def flagsaver(*args: Tuple[flags.FlagHolder, Any],
+ **kwargs: Any) -> '_FlagOverrider':
+ ...
+
+
+@overload
+def flagsaver(func: _CallableT) -> _CallableT:
+ ...
+
def flagsaver(*args, **kwargs):
"""The main flagsaver interface. See module doc for usage."""
@@ -94,12 +108,14 @@ def flagsaver(*args, **kwargs):
return _FlagOverrider(**kwargs)
-def save_flag_values(flag_values=FLAGS):
+def save_flag_values(
+ flag_values: flags.FlagValues = FLAGS) -> Mapping[str, Mapping[str, Any]]:
"""Returns copy of flag values as a dict.
Args:
- flag_values: FlagValues, the FlagValues instance with which the flag will
- be saved. This should almost never need to be overridden.
+ flag_values: FlagValues, the FlagValues instance with which the flag will be
+ saved. This should almost never need to be overridden.
+
Returns:
Dictionary mapping keys to values. Keys are flag names, values are
corresponding ``__dict__`` members. E.g. ``{'key': value_dict, ...}``.
@@ -107,13 +123,14 @@ def save_flag_values(flag_values=FLAGS):
return {name: _copy_flag_dict(flag_values[name]) for name in flag_values}
-def restore_flag_values(saved_flag_values, flag_values=FLAGS):
+def restore_flag_values(saved_flag_values: Mapping[str, Mapping[str, Any]],
+ flag_values: flags.FlagValues = FLAGS):
"""Restores flag values based on the dictionary of flag values.
Args:
saved_flag_values: {'flag_name': value_dict, ...}
- flag_values: FlagValues, the FlagValues instance from which the flag will
- be restored. This should almost never need to be overridden.
+ flag_values: FlagValues, the FlagValues instance from which the flag will be
+ restored. This should almost never need to be overridden.
"""
new_flag_names = list(flag_values)
for name in new_flag_names:
@@ -127,23 +144,24 @@ def restore_flag_values(saved_flag_values, flag_values=FLAGS):
flag_values[name].__dict__ = saved
-def _wrap(func, overrides):
+def _wrap(func: _CallableT, overrides: Mapping[str, Any]) -> _CallableT:
"""Creates a wrapper function that saves/restores flag values.
Args:
- func: function object - This will be called between saving flags and
- restoring flags.
- overrides: {str: object} - Flag names mapped to their values. These flags
- will be set after saving the original flag state.
+ func: This will be called between saving flags and restoring flags.
+ overrides: Flag names mapped to their values. These flags will be set after
+ saving the original flag state.
Returns:
- return value from func()
+ A wrapped version of func.
"""
+
@functools.wraps(func)
def _flagsaver_wrapper(*args, **kwargs):
"""Wrapper function that saves and restores flags."""
with _FlagOverrider(**overrides):
return func(*args, **kwargs)
+
return _flagsaver_wrapper
@@ -154,11 +172,11 @@ class _FlagOverrider(object):
completes.
"""
- def __init__(self, **overrides):
+ def __init__(self, **overrides: Any):
self._overrides = overrides
self._saved_flag_values = None
- def __call__(self, func):
+ def __call__(self, func: _CallableT) -> _CallableT:
if inspect.isclass(func):
raise TypeError('flagsaver cannot be applied to a class.')
return _wrap(func, self._overrides)
@@ -176,7 +194,7 @@ class _FlagOverrider(object):
restore_flag_values(self._saved_flag_values, FLAGS)
-def _copy_flag_dict(flag):
+def _copy_flag_dict(flag: flags.Flag) -> Mapping[str, Any]:
"""Returns a copy of the flag object's ``__dict__``.
It's mostly a shallow copy of the ``__dict__``, except it also does a shallow