Skip to content

Commit 555b4f2

Browse files
oprypincopybara-github
authored andcommitted
Add mypy-compliant type annotations to absl-py (part 1)
PiperOrigin-RevId: 645120089
1 parent 2d86d97 commit 555b4f2

File tree

6 files changed

+94
-80
lines changed

6 files changed

+94
-80
lines changed

absl/flags/_argument_parser.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import enum
2424
import io
2525
import string
26-
from typing import Generic, Iterable, List, Optional, Sequence, Type, TypeVar, Union
26+
from typing import Any, Dict, Generic, Iterable, List, Optional, Sequence, Type, TypeVar, Union
2727
from xml.dom import minidom
2828

2929
from absl.flags import _helpers
@@ -33,16 +33,10 @@
3333
_N = TypeVar('_N', int, float)
3434

3535

36-
def _is_integer_type(instance):
37-
"""Returns True if instance is an integer, and not a bool."""
38-
return (isinstance(instance, int) and
39-
not isinstance(instance, bool))
40-
41-
4236
class _ArgumentParserCache(type):
4337
"""Metaclass used to cache and share argument parsers among flags."""
4438

45-
_instances = {}
39+
_instances: Dict[Any, Any] = {}
4640

4741
def __call__(cls, *args, **kwargs):
4842
"""Returns an instance of the argument parser cls.
@@ -115,7 +109,7 @@ def parse(self, argument: str) -> Optional[_T]:
115109
if not isinstance(argument, str):
116110
raise TypeError('flag value must be a string, found "{}"'.format(
117111
type(argument)))
118-
return argument
112+
return argument # type: ignore[return-value]
119113

120114
def flag_type(self) -> str:
121115
"""Returns a string representing the type of the flag."""
@@ -155,7 +149,7 @@ def is_outside_bounds(self, val: _N) -> bool:
155149
return ((self.lower_bound is not None and val < self.lower_bound) or
156150
(self.upper_bound is not None and val > self.upper_bound))
157151

158-
def parse(self, argument: str) -> _N:
152+
def parse(self, argument: Union[str, _N]) -> _N:
159153
"""See base class."""
160154
val = self.convert(argument)
161155
if self.is_outside_bounds(val):
@@ -174,7 +168,7 @@ def _custom_xml_dom_elements(
174168
doc, 'upper_bound', self.upper_bound))
175169
return elements
176170

177-
def convert(self, argument: str) -> _N:
171+
def convert(self, argument: Union[str, _N]) -> _N:
178172
"""Returns the correct numeric value of argument.
179173
180174
Subclass must implement this method, and raise TypeError if argument is not
@@ -222,8 +216,11 @@ def __init__(
222216

223217
def convert(self, argument: Union[int, float, str]) -> float:
224218
"""Returns the float value of argument."""
225-
if (_is_integer_type(argument) or isinstance(argument, float) or
226-
isinstance(argument, str)):
219+
if (
220+
(isinstance(argument, int) and not isinstance(argument, bool))
221+
or isinstance(argument, float)
222+
or isinstance(argument, str)
223+
):
227224
return float(argument)
228225
else:
229226
raise TypeError(
@@ -269,7 +266,7 @@ def __init__(
269266

270267
def convert(self, argument: Union[int, str]) -> int:
271268
"""Returns the int value of argument."""
272-
if _is_integer_type(argument):
269+
if isinstance(argument, int) and not isinstance(argument, bool):
273270
return argument
274271
elif isinstance(argument, str):
275272
base = 10
@@ -470,6 +467,8 @@ class EnumClassListSerializer(ListSerializer[_ET]):
470467
provided separator.
471468
"""
472469

470+
_element_serializer: 'EnumClassSerializer'
471+
473472
def __init__(self, list_sep: str, **kwargs) -> None:
474473
"""Initializes EnumClassListSerializer.
475474

absl/flags/_flag.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ class Flag(Generic[_T]):
8787
default_as_str: Optional[str]
8888
default_unparsed: Union[Optional[_T], str]
8989

90+
parser: _argument_parser.ArgumentParser[_T]
91+
9092
def __init__(
9193
self,
9294
parser: _argument_parser.ArgumentParser[_T],
@@ -111,7 +113,7 @@ def __init__(
111113
self.short_name = short_name
112114
self.boolean = boolean
113115
self.present = 0
114-
self.parser = parser
116+
self.parser = parser # type: ignore[annotation-type-mismatch]
115117
self.serializer = serializer
116118
self.allow_override = allow_override
117119
self.allow_override_cpp = allow_override_cpp
@@ -120,8 +122,8 @@ def __init__(
120122
self.allow_using_method_names = allow_using_method_names
121123

122124
self.using_default_value = True
123-
self._value = None
124-
self.validators = []
125+
self._value: Optional[_T] = None
126+
self.validators: List[Any] = []
125127
if self.allow_hide_cpp and self.allow_override_cpp:
126128
raise _exceptions.Error(
127129
"Can't have both allow_hide_cpp (means use Python flag) and "
@@ -177,7 +179,7 @@ def _get_parsed_value_as_string(self, value: Optional[_T]) -> Optional[str]:
177179
return repr('false')
178180
return repr(str(value))
179181

180-
def parse(self, argument: Union[str, Optional[_T]]) -> None:
182+
def parse(self, argument: Union[str, _T]) -> None:
181183
"""Parses string and sets flag value.
182184
183185
Args:
@@ -202,7 +204,7 @@ def _parse(self, argument: Union[str, _T]) -> Optional[_T]:
202204
The parsed value.
203205
"""
204206
try:
205-
return self.parser.parse(argument)
207+
return self.parser.parse(argument) # type: ignore[arg-type]
206208
except (TypeError, ValueError, OverflowError) as e:
207209
# Recast as IllegalFlagValueError.
208210
raise _exceptions.IllegalFlagValueError(
@@ -298,7 +300,7 @@ def _create_xml_dom_element(
298300
else:
299301
default_serialized = ''
300302
else:
301-
default_serialized = self.default
303+
default_serialized = self.default # type: ignore[assignment]
302304
element.appendChild(_helpers.create_xml_dom_element(
303305
doc, 'default', default_serialized))
304306
value_serialized = self._serialize_value_for_xml(self.value)
@@ -363,6 +365,8 @@ def __init__(
363365
class EnumFlag(Flag[str]):
364366
"""Basic enum flag; its value can be any string from list of enum_values."""
365367

368+
parser: _argument_parser.EnumParser
369+
366370
def __init__(
367371
self,
368372
name: str,
@@ -374,11 +378,11 @@ def __init__(
374378
**args
375379
):
376380
p = _argument_parser.EnumParser(enum_values, case_sensitive)
381+
g: _argument_parser.ArgumentSerializer[str]
377382
g = _argument_parser.ArgumentSerializer()
378383
super(EnumFlag, self).__init__(
379-
p, g, name, default, help, short_name, **args)
380-
# NOTE: parser should be typed EnumParser but the constructor
381-
# restricts the available interface to ArgumentParser[str].
384+
p, g, name, default, help, short_name, **args
385+
)
382386
self.parser = p
383387
self.help = '<%s>: %s' % ('|'.join(p.enum_values), self.help)
384388

@@ -395,6 +399,8 @@ def _extra_xml_dom_elements(
395399
class EnumClassFlag(Flag[_ET]):
396400
"""Basic enum flag; its value is an enum class's member."""
397401

402+
parser: _argument_parser.EnumClassParser
403+
398404
def __init__(
399405
self,
400406
name: str,
@@ -406,12 +412,13 @@ def __init__(
406412
**args
407413
):
408414
p = _argument_parser.EnumClassParser(
409-
enum_class, case_sensitive=case_sensitive)
415+
enum_class, case_sensitive=case_sensitive
416+
)
417+
g: _argument_parser.EnumClassSerializer[_ET]
410418
g = _argument_parser.EnumClassSerializer(lowercase=not case_sensitive)
411419
super(EnumClassFlag, self).__init__(
412-
p, g, name, default, help, short_name, **args)
413-
# NOTE: parser should be typed EnumClassParser[_ET] but the constructor
414-
# restricts the available interface to ArgumentParser[_ET].
420+
p, g, name, default, help, short_name, **args
421+
)
415422
self.parser = p
416423
self.help = '<%s>: %s' % ('|'.join(p.member_names), self.help)
417424

@@ -456,23 +463,28 @@ def parse(self, arguments: Union[str, _T, Iterable[_T]]): # pylint: disable=arg
456463
"""
457464
new_values = self._parse(arguments)
458465
if self.present:
466+
assert self.value is not None
459467
self.value.extend(new_values)
460468
else:
461469
self.value = new_values
462470
self.present += len(new_values)
463471

464-
def _parse(self, arguments: Union[str, Optional[Iterable[_T]]]) -> List[_T]: # pylint: disable=arguments-renamed
465-
if (isinstance(arguments, abc.Iterable) and
466-
not isinstance(arguments, str)):
467-
arguments = list(arguments)
472+
def _parse(self, arguments: Union[str, _T, Iterable[_T]]) -> List[_T]: # pylint: disable=arguments-renamed
473+
arguments_list: List[Union[str, _T]]
474+
475+
if isinstance(arguments, str):
476+
arguments_list = [arguments]
468477

469-
if not isinstance(arguments, list):
478+
elif isinstance(arguments, abc.Iterable):
479+
arguments_list = list(arguments)
480+
481+
else:
470482
# Default value may be a list of values. Most other arguments
471483
# will not be, so convert them into a single-item list to make
472484
# processing simpler below.
473-
arguments = [arguments]
485+
arguments_list = [arguments]
474486

475-
return [super(MultiFlag, self)._parse(item) for item in arguments]
487+
return [super(MultiFlag, self)._parse(item) for item in arguments_list] # type: ignore
476488

477489
def _serialize(self, value: Optional[List[_T]]) -> str:
478490
"""See base class."""
@@ -483,7 +495,8 @@ def _serialize(self, value: Optional[List[_T]]) -> str:
483495
return ''
484496

485497
serialized_items = [
486-
super(MultiFlag, self)._serialize(value_item) for value_item in value
498+
super(MultiFlag, self)._serialize(value_item) # type: ignore[arg-type]
499+
for value_item in value
487500
]
488501

489502
return '\n'.join(serialized_items)
@@ -511,6 +524,8 @@ class MultiEnumClassFlag(MultiFlag[_ET]): # pytype: disable=not-indexable
511524
type.
512525
"""
513526

527+
parser: _argument_parser.EnumClassParser[_ET] # type: ignore[assignment]
528+
514529
def __init__(
515530
self,
516531
name: str,
@@ -522,6 +537,7 @@ def __init__(
522537
):
523538
p = _argument_parser.EnumClassParser(
524539
enum_class, case_sensitive=case_sensitive)
540+
g: _argument_parser.EnumClassListSerializer
525541
g = _argument_parser.EnumClassListSerializer(
526542
list_sep=',', lowercase=not case_sensitive)
527543
super(MultiEnumClassFlag, self).__init__(

0 commit comments

Comments
 (0)