Skip to content

Commit 226a4f1

Browse files
authored
Apply --strict-equality special-casing for bytes and bytearray on Python 2 (#7493)
Fixes #7465 The previous fix only worked on Python 3.
1 parent 41db9a0 commit 226a4f1

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

mypy/checkexpr.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2136,8 +2136,9 @@ def dangerous_comparison(self, left: Type, right: Type,
21362136
if isinstance(left, UnionType) and isinstance(right, UnionType):
21372137
left = remove_optional(left)
21382138
right = remove_optional(right)
2139-
if (original_container and has_bytes_component(original_container) and
2140-
has_bytes_component(left)):
2139+
py2 = self.chk.options.python_version < (3, 0)
2140+
if (original_container and has_bytes_component(original_container, py2) and
2141+
has_bytes_component(left, py2)):
21412142
# We need to special case bytes and bytearray, because 97 in b'abc', b'a' in b'abc',
21422143
# b'a' in bytearray(b'abc') etc. all return True (and we want to show the error only
21432144
# if the check can _never_ be True).
@@ -4179,13 +4180,16 @@ def custom_equality_method(typ: Type) -> bool:
41794180
return False
41804181

41814182

4182-
def has_bytes_component(typ: Type) -> bool:
4183+
def has_bytes_component(typ: Type, py2: bool = False) -> bool:
41834184
"""Is this one of builtin byte types, or a union that contains it?"""
41844185
typ = get_proper_type(typ)
4186+
if py2:
4187+
byte_types = {'builtins.str', 'builtins.bytearray'}
4188+
else:
4189+
byte_types = {'builtins.bytes', 'builtins.bytearray'}
41854190
if isinstance(typ, UnionType):
41864191
return any(has_bytes_component(t) for t in typ.items)
4187-
if isinstance(typ, Instance) and typ.type.fullname() in {'builtins.bytes',
4188-
'builtins.bytearray'}:
4192+
if isinstance(typ, Instance) and typ.type.fullname() in byte_types:
41894193
return True
41904194
return False
41914195

test-data/unit/check-expressions.test

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2066,6 +2066,12 @@ bytearray(b'abc') in b'abcde' # OK on Python 3
20662066
[builtins fixtures/primitives.pyi]
20672067
[typing fixtures/typing-full.pyi]
20682068

2069+
[case testBytesVsByteArray_python2]
2070+
# flags: --strict-equality --py2
2071+
b'hi' in bytearray(b'hi')
2072+
[builtins_py2 fixtures/python2.pyi]
2073+
[typing fixtures/typing-full.pyi]
2074+
20692075
[case testStrictEqualityNoPromotePy3]
20702076
# flags: --strict-equality
20712077
'a' == b'a' # E: Non-overlapping equality check (left operand type: "Literal['a']", right operand type: "Literal[b'a']")

test-data/unit/fixtures/python2.pyi

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Generic, Iterable, TypeVar
1+
from typing import Generic, Iterable, TypeVar, Sequence, Iterator
22

33
class object:
44
def __init__(self) -> None: pass
@@ -13,9 +13,16 @@ class function: pass
1313
class int: pass
1414
class str: pass
1515
class unicode: pass
16-
class bool: pass
16+
class bool(int): pass
17+
class bytearray(Sequence[int]):
18+
def __init__(self, string: str) -> None: pass
19+
def __contains__(self, item: object) -> bool: pass
20+
def __iter__(self) -> Iterator[int]: pass
21+
def __getitem__(self, item: int) -> int: pass
1722

1823
T = TypeVar('T')
19-
class list(Iterable[T], Generic[T]): pass
24+
class list(Iterable[T], Generic[T]):
25+
def __iter__(self) -> Iterator[T]: pass
26+
def __getitem__(self, item: int) -> T: pass
2027

2128
# Definition of None is implicit

0 commit comments

Comments
 (0)