Skip to content

Commit f6d9353

Browse files
oprypincopybara-github
authored andcommitted
Add mypy-compliant type annotations to absl-py (part 2)
Closes abseil#133 PiperOrigin-RevId: 646946172
1 parent 555b4f2 commit f6d9353

File tree

11 files changed

+89
-60
lines changed

11 files changed

+89
-60
lines changed

absl/app.pyi

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,19 @@ def usage(shorthelp: Union[bool, int] = ...,
4040
exitcode: None = ...) -> None:
4141
...
4242

43+
@overload
44+
def usage(shorthelp: Union[bool, int],
45+
writeto_stdout: Union[bool, int],
46+
detailed_error: Optional[Any],
47+
exitcode: int) -> NoReturn:
48+
...
49+
4350
@overload
4451
def usage(shorthelp: Union[bool, int] = ...,
4552
writeto_stdout: Union[bool, int] = ...,
4653
detailed_error: Optional[Any] = ...,
47-
exitcode: int = ...) -> NoReturn:
54+
*,
55+
exitcode: int) -> NoReturn:
4856
...
4957

5058
def install_exception_handler(handler: ExceptionHandler) -> None:

absl/logging/__init__.pyi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import Any, Callable, Dict, NoReturn, Optional, Tuple, TypeVar, Union
16+
from typing import Any, Callable, Dict, IO, NoReturn, Optional, Tuple, TypeVar, Union
1717

1818
from absl import flags
1919

@@ -154,10 +154,10 @@ def skip_log_prefix(func: _SkipLogT) -> _SkipLogT:
154154
...
155155

156156

157-
_StreamT = TypeVar("_StreamT")
157+
_StreamT = TypeVar('_StreamT')
158158

159159

160-
class PythonHandler(logging.StreamHandler[_StreamT]):
160+
class PythonHandler(logging.StreamHandler[_StreamT]): # type: ignore[type-var]
161161

162162
def __init__(
163163
self,
@@ -241,7 +241,7 @@ class ABSLLogger(logging.Logger):
241241
def critical(self, msg: Any, *args: Any, **kwargs: Any) -> None:
242242
...
243243

244-
def fatal(self, msg: Any, *args: Any, **kwargs: Any) -> NoReturn:
244+
def fatal(self, msg: Any, *args: Any, **kwargs: Any) -> NoReturn: # type: ignore[override]
245245
...
246246

247247
def error(self, msg: Any, *args: Any, **kwargs: Any) -> None:

absl/logging/tests/logging_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ def test_set_google_log_file_with_log_to_stderr(self):
153153
self.assertEqual(sys.stderr, self.python_handler.stream)
154154

155155
@mock.patch.object(logging, 'find_log_dir_and_names')
156-
@mock.patch.object(logging.time, 'localtime')
157-
@mock.patch.object(logging.time, 'time')
156+
@mock.patch.object(time, 'localtime')
157+
@mock.patch.object(time, 'time')
158158
@mock.patch.object(os.path, 'islink')
159159
@mock.patch.object(os, 'unlink')
160160
@mock.patch.object(os, 'getpid')

absl/testing/_pretty_print_reporter.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,13 @@ class TextTestRunner(unittest.TextTestRunner):
7474
# Usually this is set using --pdb_post_mortem.
7575
run_for_debugging = False
7676

77-
def run(self, test):
78-
# type: (TestCase) -> TestResult
77+
def run(self, test) -> unittest.TestResult:
7978
if self.run_for_debugging:
8079
return self._run_debug(test)
8180
else:
8281
return super(TextTestRunner, self).run(test)
8382

84-
def _run_debug(self, test):
85-
# type: (TestCase) -> TestResult
83+
def _run_debug(self, test) -> unittest.TestResult:
8684
test.debug()
8785
# Return an empty result to indicate success.
8886
return self._makeResult()

absl/testing/absltest.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@
7171

7272
# Private typing symbols.
7373
_T = typing.TypeVar('_T') # Unbounded TypeVar for general usage
74-
_OutcomeType = unittest.case._Outcome # pytype: disable=module-attr
7574
_TEXT_OR_BINARY_TYPES = (str, bytes)
7675

7776
# Suppress surplus entries in AssertionError stack traces.
@@ -244,7 +243,7 @@ def wasSuccessful(self) -> bool:
244243
test_result = unittest.TestResult()
245244
test_result.addUnexpectedSuccess(unittest.FunctionTestCase(lambda: None))
246245
if test_result.wasSuccessful(): # The bug is present.
247-
unittest.TestResult.wasSuccessful = wasSuccessful
246+
unittest.TestResult.wasSuccessful = wasSuccessful # type: ignore[method-assign]
248247
if test_result.wasSuccessful(): # Warn the user if our hot-fix failed.
249248
sys.stderr.write('unittest.result.TestResult monkey patch to report'
250249
' unexpected passes as failures did not work.\n')
@@ -373,7 +372,7 @@ def _create(
373372
cls,
374373
base_path: str,
375374
file_path: Optional[str],
376-
content: AnyStr,
375+
content: Optional[AnyStr],
377376
mode: str,
378377
encoding: str,
379378
errors: str,
@@ -531,26 +530,32 @@ class _method(object):
531530
(e.g. Cls.method(self, ...)) but is still situationally useful.
532531
"""
533532

533+
_finstancemethod: Any
534+
_fclassmethod: Optional[Any]
535+
534536
def __init__(self, finstancemethod: Callable[..., Any]) -> None:
535537
self._finstancemethod = finstancemethod
536538
self._fclassmethod = None
537539

538540
def classmethod(self, fclassmethod: Callable[..., Any]) -> '_method':
539-
self._fclassmethod = classmethod(fclassmethod)
541+
if isinstance(fclassmethod, classmethod):
542+
self._fclassmethod = fclassmethod
543+
else:
544+
self._fclassmethod = classmethod(fclassmethod)
540545
return self
541546

542-
def __doc__(self) -> str:
543-
if getattr(self._finstancemethod, '__doc__'):
544-
return self._finstancemethod.__doc__
545-
elif getattr(self._fclassmethod, '__doc__'):
546-
return self._fclassmethod.__doc__
547-
return ''
547+
def __doc__(self) -> str: # type: ignore[override]
548+
return (
549+
getattr(self._finstancemethod, '__doc__')
550+
or getattr(self._fclassmethod, '__doc__')
551+
or ''
552+
)
548553

549554
def __get__(
550555
self, obj: Optional[Any], type_: Optional[Type[Any]]
551556
) -> Callable[..., Any]:
552557
func = self._fclassmethod if obj is None else self._finstancemethod
553-
return func.__get__(obj, type_) # pytype: disable=attribute-error
558+
return func.__get__(obj, type_) # type: ignore[attribute-error]
554559

555560

556561
class TestCase(unittest.TestCase):
@@ -572,10 +577,10 @@ class TestCase(unittest.TestCase):
572577
_exit_stack = None
573578
_cls_exit_stack = None
574579

575-
def __init__(self, *args, **kwargs):
580+
def __init__(self, *args, **kwargs) -> None:
576581
super(TestCase, self).__init__(*args, **kwargs)
577582
# This is to work around missing type stubs in unittest.pyi
578-
self._outcome: Optional[_OutcomeType] = getattr(self, '_outcome')
583+
self._outcome: Optional[Any] = getattr(self, '_outcome')
579584

580585
def setUp(self):
581586
super(TestCase, self).setUp()
@@ -749,7 +754,8 @@ def enter_context(self, manager: ContextManager[_T]) -> _T:
749754
return self._exit_stack.enter_context(manager)
750755

751756
@enter_context.classmethod
752-
def enter_context(cls, manager: ContextManager[_T]) -> _T: # pylint: disable=no-self-argument
757+
@classmethod
758+
def _enter_context_cls(cls, manager: ContextManager[_T]) -> _T:
753759
if sys.version_info >= (3, 11):
754760
return cls.enterClassContext(manager)
755761

@@ -1355,7 +1361,7 @@ def assertRaisesWithPredicateMatch(
13551361
*args, **kwargs) -> None:
13561362
# The purpose of this return statement is to work around
13571363
# https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
1358-
return self._AssertRaisesContext(None, None, None)
1364+
return self._AssertRaisesContext(None, None, None) # type: ignore[return-value]
13591365

13601366
def assertRaisesWithPredicateMatch(self, expected_exception, predicate,
13611367
callable_obj=None, *args, **kwargs):
@@ -1399,7 +1405,7 @@ def assertRaisesWithLiteralMatch(
13991405
callable_obj: Callable[..., Any], *args, **kwargs) -> None:
14001406
# The purpose of this return statement is to work around
14011407
# https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
1402-
return self._AssertRaisesContext(None, None, None)
1408+
return self._AssertRaisesContext(None, None, None) # type: ignore[return-value]
14031409

14041410
def assertRaisesWithLiteralMatch(self, expected_exception,
14051411
expected_exception_message,
@@ -1782,7 +1788,7 @@ def assertDataclassEqual(self, first, second, msg=None):
17821788
'Cannot detect difference by examining the fields of the dataclass.'
17831789
)
17841790

1785-
raise self.fail('\n'.join(message), msg)
1791+
self.fail('\n'.join(message), msg)
17861792

17871793
def assertUrlEqual(self, a, b, msg=None):
17881794
"""Asserts that urls are equal, ignoring ordering of query params."""
@@ -1881,7 +1887,7 @@ def _getAssertEqualityFunc(
18811887

18821888
def fail(self, msg=None, user_msg=None) -> NoReturn:
18831889
"""Fail immediately with the given standard message and user message."""
1884-
return super(TestCase, self).fail(self._formatMessage(user_msg, msg))
1890+
super(TestCase, self).fail(self._formatMessage(user_msg, msg))
18851891

18861892

18871893
def _sorted_list_difference(
@@ -1909,12 +1915,12 @@ def _sorted_list_difference(
19091915
try:
19101916
e = expected[i]
19111917
a = actual[j]
1912-
if e < a:
1918+
if e < a: # type: ignore[operator]
19131919
missing.append(e)
19141920
i += 1
19151921
while expected[i] == e:
19161922
i += 1
1917-
elif e > a:
1923+
elif e > a: # type: ignore[operator]
19181924
unexpected.append(a)
19191925
j += 1
19201926
while actual[j] == a:
@@ -2204,12 +2210,12 @@ def _run_in_app(
22042210
# This must be a separate loop since multiple flag names (short_name=) can
22052211
# point to the same flag object.
22062212
for name in FLAGS:
2207-
FLAGS[name].parse = noop_parse
2213+
FLAGS[name].parse = noop_parse # type: ignore[method-assign]
22082214
try:
22092215
argv = FLAGS(sys.argv)
22102216
finally:
22112217
for name in FLAGS:
2212-
FLAGS[name].parse = stored_parse_methods[name]
2218+
FLAGS[name].parse = stored_parse_methods[name] # type: ignore[method-assign]
22132219
sys.stdout.flush()
22142220

22152221
function(argv, args, kwargs)
@@ -2364,6 +2370,7 @@ def get_default_xml_output_filename() -> Optional[str]:
23642370
return os.path.join(
23652371
os.environ['TEST_XMLOUTPUTDIR'],
23662372
os.path.splitext(os.path.basename(sys.argv[0]))[0] + '.xml')
2373+
return None
23672374

23682375

23692376
def _setup_filtering(argv: MutableSequence[str]) -> bool:
@@ -2490,7 +2497,7 @@ def getShardedTestCaseNames(testCaseClass):
24902497
filtered_names.append(testcase)
24912498
return [x for x in ordered_names if x in filtered_names]
24922499

2493-
base_loader.getTestCaseNames = getShardedTestCaseNames
2500+
base_loader.getTestCaseNames = getShardedTestCaseNames # type: ignore[method-assign]
24942501
return base_loader, shard_index
24952502

24962503

@@ -2559,9 +2566,12 @@ def _run_and_get_tests_result(
25592566
# XML file name is based upon (sorted by priority):
25602567
# --xml_output_file flag, XML_OUTPUT_FILE variable,
25612568
# TEST_XMLOUTPUTDIR variable or RUNNING_UNDER_TEST_DAEMON variable.
2562-
if not FLAGS.xml_output_file:
2563-
FLAGS.xml_output_file = get_default_xml_output_filename()
2564-
xml_output_file = FLAGS.xml_output_file
2569+
if FLAGS.xml_output_file:
2570+
xml_output_file = FLAGS.xml_output_file
2571+
else:
2572+
xml_output_file = get_default_xml_output_filename()
2573+
if xml_output_file:
2574+
FLAGS.xml_output_file = xml_output_file
25652575

25662576
xml_buffer = None
25672577
if xml_output_file:

absl/testing/flagsaver.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def some_func():
9292
import collections
9393
import functools
9494
import inspect
95-
from typing import overload, Any, Callable, Mapping, Tuple, TypeVar, Type, Sequence, Union
95+
from typing import Any, Callable, Dict, Mapping, Sequence, Tuple, Type, TypeVar, Union, overload
9696

9797
from absl import flags
9898

@@ -104,19 +104,20 @@ def some_func():
104104

105105

106106
@overload
107-
def flagsaver(*args: Tuple[flags.FlagHolder, Any],
108-
**kwargs: Any) -> '_FlagOverrider':
107+
def flagsaver(func: _CallableT) -> _CallableT:
109108
...
110109

111110

112111
@overload
113-
def flagsaver(func: _CallableT) -> _CallableT:
112+
def flagsaver(
113+
*args: Tuple[flags.FlagHolder, Any], **kwargs: Any
114+
) -> '_FlagOverrider':
114115
...
115116

116117

117118
def flagsaver(*args, **kwargs):
118119
"""The main flagsaver interface. See module doc for usage."""
119-
return _construct_overrider(_FlagOverrider, *args, **kwargs)
120+
return _construct_overrider(_FlagOverrider, *args, **kwargs) # type: ignore[bad-return-type]
120121

121122

122123
@overload
@@ -167,15 +168,18 @@ def _construct_overrider(
167168

168169

169170
@overload
170-
def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'],
171-
*args: Tuple[flags.FlagHolder, Any],
172-
**kwargs: Any) -> '_FlagOverrider':
171+
def _construct_overrider(
172+
flag_overrider_cls: Type['_FlagOverrider'], func: _CallableT
173+
) -> _CallableT:
173174
...
174175

175176

176177
@overload
177-
def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'],
178-
func: _CallableT) -> _CallableT:
178+
def _construct_overrider(
179+
flag_overrider_cls: Type['_FlagOverrider'],
180+
*args: Tuple[flags.FlagHolder, Any],
181+
**kwargs: Any,
182+
) -> '_FlagOverrider':
179183
...
180184

181185

@@ -220,7 +224,8 @@ def _construct_overrider(flag_overrider_cls, *args, **kwargs):
220224

221225

222226
def save_flag_values(
223-
flag_values: flags.FlagValues = FLAGS) -> Mapping[str, Mapping[str, Any]]:
227+
flag_values: flags.FlagValues = FLAGS,
228+
) -> Dict[str, Dict[str, Any]]:
224229
"""Returns copy of flag values as a dict.
225230
226231
Args:
@@ -234,8 +239,10 @@ def save_flag_values(
234239
return {name: _copy_flag_dict(flag_values[name]) for name in flag_values}
235240

236241

237-
def restore_flag_values(saved_flag_values: Mapping[str, Mapping[str, Any]],
238-
flag_values: flags.FlagValues = FLAGS):
242+
def restore_flag_values(
243+
saved_flag_values: Mapping[str, Dict[str, Any]],
244+
flag_values: flags.FlagValues = FLAGS,
245+
) -> None:
239246
"""Restores flag values based on the dictionary of flag values.
240247
241248
Args:
@@ -368,7 +375,7 @@ def __enter__(self):
368375
raise
369376

370377

371-
def _copy_flag_dict(flag: flags.Flag) -> Mapping[str, Any]:
378+
def _copy_flag_dict(flag: flags.Flag) -> Dict[str, Any]:
372379
"""Returns a copy of the flag object's ``__dict__``.
373380
374381
It's mostly a shallow copy of the ``__dict__``, except it also does a shallow

absl/testing/parameterized.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,7 @@ def id(self):
676676

677677

678678
# This function is kept CamelCase because it's used as a class's base class.
679-
def CoopTestCase(other_base_class): # pylint: disable=invalid-name
679+
def CoopTestCase(other_base_class) -> type: # pylint: disable=invalid-name, g-bare-generic
680680
"""Returns a new base class with a cooperative metaclass base.
681681
682682
This enables the TestCase to be used in combination
@@ -715,10 +715,10 @@ class CoopTestCaseBase(other_base_class, TestCase):
715715
return CoopTestCaseBase
716716
else:
717717

718-
class CoopMetaclass(type(other_base_class), TestGeneratorMetaclass): # pylint: disable=unused-variable
718+
class CoopMetaclass(type(other_base_class), TestGeneratorMetaclass): # type: ignore # pylint: disable=unused-variable
719719
pass
720720

721-
class CoopTestCaseBase(other_base_class, TestCase, metaclass=CoopMetaclass):
721+
class CoopTestCaseBase(other_base_class, TestCase, metaclass=CoopMetaclass): # type: ignore
722722
pass
723723

724724
return CoopTestCaseBase

0 commit comments

Comments
 (0)