Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ show_error_codes = True
[mypy-telegram.vendor.*]
ignore_errors = True

# Disable strict optional for telegram objects with class methods
# We don't want to clutter the code with 'if self.text is None: raise RuntimeError()'
[mypy-telegram._callbackquery,telegram._chat,telegram._message,telegram._user,telegram._files.*,telegram._inline.inlinequery,telegram._payment.precheckoutquery,telegram._payment.shippingquery,telegram._passport.passportdata,telegram._passport.credentials,telegram._passport.passportfile,telegram.ext.filters,telegram._chatjoinrequest]
# For some files, it's easier to just disable strict-optional all together instead of
# cluttering the code with `# type: ignore`s or stuff like
# `if self.text is None: raise RuntimeError()`
[mypy-telegram._callbackquery,telegram._file,telegram._message,telegram._files.file]
strict_optional = False

# type hinting for asyncio in webhookhandler is a bit tricky because it depends on the OS
Expand Down
18 changes: 9 additions & 9 deletions telegram/_files/inputfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,28 +54,28 @@ class InputFile:
__slots__ = ('filename', 'attach', 'input_file_content', 'mimetype')

def __init__(self, obj: Union[IO, bytes], filename: str = None, attach: bool = None):
self.filename = None
if isinstance(obj, bytes):
self.input_file_content = obj
else:
self.input_file_content = obj.read()
self.attach = 'attached' + uuid4().hex if attach else None

if filename:
self.filename = filename
elif hasattr(obj, 'name') and not isinstance(obj.name, int): # type: ignore[union-attr]
self.filename = Path(obj.name).name # type: ignore[union-attr]
if (
not filename
and hasattr(obj, 'name')
and not isinstance(obj.name, int) # type: ignore[union-attr]
):
filename = Path(obj.name).name # type: ignore[union-attr]

image_mime_type = self.is_image(self.input_file_content)
if image_mime_type:
self.mimetype = image_mime_type
elif self.filename:
self.mimetype = mimetypes.guess_type(self.filename)[0] or DEFAULT_MIME_TYPE
elif filename:
self.mimetype = mimetypes.guess_type(filename)[0] or DEFAULT_MIME_TYPE
else:
self.mimetype = DEFAULT_MIME_TYPE

if not self.filename:
self.filename = self.mimetype.replace('/', '.')
self.filename = filename or self.mimetype.replace('/', '.')

@property
def field_tuple(self) -> Tuple[str, bytes, str]: # skipcq: PY-D0003
Expand Down
2 changes: 1 addition & 1 deletion telegram/_files/sticker.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def to_dict(self) -> JSONDict:
"""See :meth:`telegram.TelegramObject.to_dict`."""
data = super().to_dict()

data['stickers'] = [s.to_dict() for s in data.get('stickers')]
data['stickers'] = [s.to_dict() for s in data.get('stickers')] # type: ignore[union-attr]

return data

Expand Down
16 changes: 8 additions & 8 deletions telegram/_passport/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@

CRYPTO_INSTALLED = True
except ImportError:
default_backend = None
MGF1, OAEP, Cipher, AES, CBC = (None, None, None, None, None) # type: ignore[misc]
SHA1, SHA256, SHA512, Hash = (None, None, None, None) # type: ignore[misc]
default_backend = None # type: ignore[assignment]
MGF1, OAEP, Cipher, AES, CBC = (None, None, None, None, None) # type: ignore[misc,assignment]
SHA1, SHA256, SHA512, Hash = (None, None, None, None) # type: ignore[misc,assignment]

CRYPTO_INSTALLED = False

Expand Down Expand Up @@ -150,7 +150,7 @@ def __init__(self, data: str, hash: str, secret: str, bot: 'Bot' = None, **_kwar
self._id_attrs = (self.data, self.hash, self.secret)

self.set_bot(bot)
self._decrypted_secret = None
self._decrypted_secret: Optional[str] = None
self._decrypted_data: Optional['Credentials'] = None

@property
Expand All @@ -175,7 +175,7 @@ def decrypted_secret(self) -> str:
# is the default for OAEP, the algorithm is the default for PHP which is what
# Telegram's backend servers run.
try:
self._decrypted_secret = self.get_bot().private_key.decrypt(
self._decrypted_secret = self.get_bot().private_key.decrypt( # type: ignore
b64decode(self.secret),
OAEP(mgf=MGF1(algorithm=SHA1()), algorithm=SHA1(), label=None), # skipcq
)
Expand All @@ -200,7 +200,7 @@ def decrypted_data(self) -> 'Credentials':
decrypt_json(self.decrypted_secret, b64decode(self.hash), b64decode(self.data)),
self.get_bot(),
)
return self._decrypted_data
return self._decrypted_data # type: ignore[return-value]


class Credentials(TelegramObject):
Expand Down Expand Up @@ -403,8 +403,8 @@ def to_dict(self) -> JSONDict:
"""See :meth:`telegram.TelegramObject.to_dict`."""
data = super().to_dict()

data['files'] = [p.to_dict() for p in self.files]
data['translation'] = [p.to_dict() for p in self.translation]
data['files'] = [p.to_dict() for p in self.files] # type: ignore[union-attr]
data['translation'] = [p.to_dict() for p in self.translation] # type: ignore[union-attr]

return data

Expand Down
2 changes: 1 addition & 1 deletion telegram/_passport/passportdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def decrypted_data(self) -> List[EncryptedPassportElement]:
"""
if self._decrypted_data is None:
self._decrypted_data = [
EncryptedPassportElement.de_json_decrypted(
EncryptedPassportElement.de_json_decrypted( # type: ignore[misc]
element.to_dict(), self.get_bot(), self.decrypted_credentials
)
for element in self.data
Expand Down
2 changes: 1 addition & 1 deletion telegram/ext/_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def get_instance(cls) -> 'Dispatcher':
raise RuntimeError(f'{cls.__name__} not initialized or multiple instances exist')

def _pooled(self) -> None:
thr_name = current_thread().getName()
thr_name = current_thread().name
while 1:
promise = self.__async_queue.get()

Expand Down
45 changes: 24 additions & 21 deletions telegram/ext/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ class MessageFilter(BaseFilter):
__slots__ = ()

def check_update(self, update: Update) -> Optional[Union[bool, DataDict]]:
return self.filter(update.effective_message) if super().check_update(update) else False
if super().check_update(update):
return self.filter(update.effective_message) # type: ignore[arg-type]
return False

@abstractmethod
def filter(self, message: Message) -> Optional[Union[bool, DataDict]]:
Expand Down Expand Up @@ -559,22 +561,22 @@ def get_chat_or_user(self, message: Message) -> Union[TGChat, TGUser, None]:
...

@staticmethod
def _parse_chat_id(chat_id: SLT[int]) -> Set[int]:
def _parse_chat_id(chat_id: Optional[SLT[int]]) -> Set[int]:
if chat_id is None:
return set()
if isinstance(chat_id, int):
return {chat_id}
return set(chat_id)

@staticmethod
def _parse_username(username: SLT[str]) -> Set[str]:
def _parse_username(username: Optional[SLT[str]]) -> Set[str]:
if username is None:
return set()
if isinstance(username, str):
return {username[1:] if username.startswith('@') else username}
return {chat[1:] if chat.startswith('@') else chat for chat in username}

def _set_chat_ids(self, chat_id: SLT[int]) -> None:
def _set_chat_ids(self, chat_id: Optional[SLT[int]]) -> None:
with self.__lock:
if chat_id and self._usernames:
raise RuntimeError(
Expand All @@ -583,7 +585,7 @@ def _set_chat_ids(self, chat_id: SLT[int]) -> None:
)
self._chat_ids = self._parse_chat_id(chat_id)

def _set_usernames(self, username: SLT[str]) -> None:
def _set_usernames(self, username: Optional[SLT[str]]) -> None:
with self.__lock:
if username and self._chat_ids:
raise RuntimeError(
Expand Down Expand Up @@ -1077,7 +1079,7 @@ def __init__(self, category: str):
super().__init__(name=f"filters.Document.Category('{self._category}')")

def filter(self, message: Message) -> bool:
if message.document:
if message.document and message.document.mime_type:
return message.document.mime_type.startswith(self._category)
return False

Expand Down Expand Up @@ -1141,7 +1143,7 @@ def __init__(self, file_extension: Optional[str], case_sensitive: bool = False):
self.name = f"filters.Document.FileExtension({file_extension.lower()!r})"

def filter(self, message: Message) -> bool:
if message.document is None:
if message.document is None or message.document.file_name is None:
return False
if self._file_extension is None:
return "." not in message.document.file_name
Expand Down Expand Up @@ -1179,35 +1181,35 @@ def filter(self, message: Message) -> bool:

APK = MimeType('application/vnd.android.package-archive')
"""Use as ``filters.Document.APK``."""
DOC = MimeType(mimetypes.types_map.get('.doc'))
DOC = MimeType(mimetypes.types_map['.doc'])
"""Use as ``filters.Document.DOC``."""
DOCX = MimeType('application/vnd.openxmlformats-officedocument.wordprocessingml.document')
"""Use as ``filters.Document.DOCX``."""
EXE = MimeType(mimetypes.types_map.get('.exe'))
EXE = MimeType(mimetypes.types_map['.exe'])
"""Use as ``filters.Document.EXE``."""
MP4 = MimeType(mimetypes.types_map.get('.mp4'))
MP4 = MimeType(mimetypes.types_map['.mp4'])
"""Use as ``filters.Document.MP4``."""
GIF = MimeType(mimetypes.types_map.get('.gif'))
GIF = MimeType(mimetypes.types_map['.gif'])
"""Use as ``filters.Document.GIF``."""
JPG = MimeType(mimetypes.types_map.get('.jpg'))
JPG = MimeType(mimetypes.types_map['.jpg'])
"""Use as ``filters.Document.JPG``."""
MP3 = MimeType(mimetypes.types_map.get('.mp3'))
MP3 = MimeType(mimetypes.types_map['.mp3'])
"""Use as ``filters.Document.MP3``."""
PDF = MimeType(mimetypes.types_map.get('.pdf'))
PDF = MimeType(mimetypes.types_map['.pdf'])
"""Use as ``filters.Document.PDF``."""
PY = MimeType(mimetypes.types_map.get('.py'))
PY = MimeType(mimetypes.types_map['.py'])
"""Use as ``filters.Document.PY``."""
SVG = MimeType(mimetypes.types_map.get('.svg'))
SVG = MimeType(mimetypes.types_map['.svg'])
"""Use as ``filters.Document.SVG``."""
TXT = MimeType(mimetypes.types_map.get('.txt'))
TXT = MimeType(mimetypes.types_map['.txt'])
"""Use as ``filters.Document.TXT``."""
TARGZ = MimeType('application/x-compressed-tar')
"""Use as ``filters.Document.TARGZ``."""
WAV = MimeType(mimetypes.types_map.get('.wav'))
WAV = MimeType(mimetypes.types_map['.wav'])
"""Use as ``filters.Document.WAV``."""
XML = MimeType(mimetypes.types_map.get('.xml'))
XML = MimeType(mimetypes.types_map['.xml'])
"""Use as ``filters.Document.XML``."""
ZIP = MimeType(mimetypes.types_map.get('.zip'))
ZIP = MimeType(mimetypes.types_map['.zip'])
"""Use as ``filters.Document.ZIP``."""

def filter(self, message: Message) -> bool:
Expand Down Expand Up @@ -1374,7 +1376,8 @@ def __init__(self, lang: SLT[str]):

def filter(self, message: Message) -> bool:
return bool(
message.from_user.language_code
message.from_user
and message.from_user.language_code
and any(message.from_user.language_code.startswith(x) for x in self.lang)
)

Expand Down
4 changes: 2 additions & 2 deletions telegram/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def post(self, url: str, data: JSONDict, timeout: float = None) -> Union[JSONDic
media_dict = med.to_dict()
media.append(media_dict)
if isinstance(med.media, InputFile):
data[med.media.attach] = med.media.field_tuple
data[med.media.attach] = med.media.field_tuple # type: ignore[index]
# if the file has a thumb, we also need to attach it to the data
if "thumb" in media_dict:
data[med.thumb.attach] = med.thumb.field_tuple
Expand All @@ -345,7 +345,7 @@ def post(self, url: str, data: JSONDict, timeout: float = None) -> Union[JSONDic
# Attach and set val to attached name
media_dict = val.to_dict()
if isinstance(val.media, InputFile):
data[val.media.attach] = val.media.field_tuple
data[val.media.attach] = val.media.field_tuple # type: ignore[index]
# if the file has a thumb, we also need to attach it to the data
if "thumb" in media_dict:
data[val.thumb.attach] = val.thumb.field_tuple
Expand Down
10 changes: 10 additions & 0 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,10 @@ def test_filters_document_type(self, update):
assert filters.Document.Category("application/").check_update(update)
assert filters.Document.MimeType("application/x-sh").check_update(update)

update.message.document.mime_type = None
assert not filters.Document.Category("application/").check_update(update)
assert not filters.Document.MimeType("application/x-sh").check_update(update)

def test_filters_file_extension_basic(self, update):
update.message.document = Document(
"file_id",
Expand All @@ -715,6 +719,9 @@ def test_filters_file_extension_basic(self, update):
assert not filters.Document.FileExtension("tgz").check_update(update)
assert not filters.Document.FileExtension("jpg").check_update(update)

update.message.document.file_name = None
assert not filters.Document.FileExtension("jpg").check_update(update)

update.message.document = None
assert not filters.Document.FileExtension("jpg").check_update(update)

Expand Down Expand Up @@ -1811,6 +1818,9 @@ def test_language_filter_single(self, update):
assert not filters.Language('en_GB').check_update(update)
assert filters.Language('da').check_update(update)

update.message.from_user = None
assert not filters.Language('da').check_update(update)

def test_language_filter_multiple(self, update):
f = filters.Language(['en_US', 'da'])
update.message.from_user.language_code = 'en_US'
Expand Down