From 522d3f00da1adb8b6638fd9ab3782938a54aa3dd Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 29 Dec 2024 16:07:34 +0100 Subject: [PATCH 1/6] Make `ADDITIONAL_TYPES` easier to configure --- tests/test_official/arg_type_checker.py | 20 +++++++++++++++----- tests/test_official/exceptions.py | 21 ++++++++++++--------- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/tests/test_official/arg_type_checker.py b/tests/test_official/arg_type_checker.py index 5f90abc00b5..87670e1e6ee 100644 --- a/tests/test_official/arg_type_checker.py +++ b/tests/test_official/arg_type_checker.py @@ -144,7 +144,7 @@ def check_param_type( ) # CHECKING: - # Each branch manipulates the `mapped_type` (except for 4) ) to match the `ptb_annotation`. + # Each branch manipulates the `mapped_type` (except for 5) ) to match the `ptb_annotation`. # 1) HANDLING ARRAY TYPES: # Now let's do the checking, starting with "Array of ..." types. @@ -174,12 +174,22 @@ def check_param_type( # 2) HANDLING OTHER TYPES: # Special case for send_* methods where we accept more types than the official API: - elif ptb_param.name in PTCE.ADDITIONAL_TYPES and obj.__name__.startswith("send"): - log("Checking that `%s` has an additional argument!\n", ptb_param.name) - mapped_type = mapped_type | PTCE.ADDITIONAL_TYPES[ptb_param.name] + elif mappings := [ + mapping + for pattern, mapping in PTCE.ADDITIONAL_TYPES.items() + if (re.match(pattern, obj.__name__)) + ]: + log("Checking that `%s` accepts additional types for some parameters!\n", obj.__name__) + for mapping in mappings: + for key, value in mapping.items(): + if not re.match(key, ptb_param.name): + continue + + log("Checking that `%s` is an additional type for `%s`!\n", value, ptb_param.name) + mapped_type = mapped_type | value # 3) HANDLING DATETIMES: - elif ( + if ( re.search( DATETIME_REGEX, ptb_param.name, diff --git a/tests/test_official/exceptions.py b/tests/test_official/exceptions.py index d6eb421e8ba..c16cff100e0 100644 --- a/tests/test_official/exceptions.py +++ b/tests/test_official/exceptions.py @@ -35,16 +35,19 @@ class ParamTypeCheckingExceptions: # Types for certain parameters accepted by PTB but not in the official API + # structure: method/class_name/regex: {param_name/regex: type} ADDITIONAL_TYPES = { - "photo": PhotoSize, - "video": Video, - "video_note": VideoNote, - "audio": Audio, - "document": Document, - "animation": Animation, - "voice": Voice, - "sticker": Sticker, - "gift_id": Gift, + "send_*": { + "photo$": PhotoSize, + "video$": Video, + "video_note": VideoNote, + "audio": Audio, + "document": Document, + "animation": Animation, + "voice": Voice, + "sticker": Sticker, + "gift_id": Gift, + } } # TODO: Look into merging this with COMPLEX_TYPES From b2b8d573163d75c421820345145d02bfc7f5a958 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 29 Dec 2024 16:37:36 +0100 Subject: [PATCH 2/6] Allow Input of Type `Sticker` for Several Methods --- telegram/_bot.py | 72 +++++++++++++++++++++++-------- telegram/ext/_extbot.py | 12 +++--- tests/_files/test_sticker.py | 48 +++++++++++++++++++++ tests/test_official/exceptions.py | 8 +++- 4 files changed, 115 insertions(+), 25 deletions(-) diff --git a/telegram/_bot.py b/telegram/_bot.py index dc1ef7ff43f..5a083c31b0e 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -6622,7 +6622,7 @@ async def add_sticker_to_set( async def set_sticker_position_in_set( self, - sticker: str, + sticker: Union[str, "Sticker"], position: int, *, read_timeout: ODVInput[float] = DEFAULT_NONE, @@ -6634,7 +6634,11 @@ async def set_sticker_position_in_set( """Use this method to move a sticker in a set created by the bot to a specific position. Args: - sticker (:obj:`str`): File identifier of the sticker. + sticker (:obj:`str` | :class:`~telegram.Sticker`): File identifier of the sticker or + the sticker object. + + .. versionchanged:: NEXT.VERSION + Accepts also :class:`telegram.Sticker` instances. position (:obj:`int`): New sticker position in the set, zero-based. Returns: @@ -6644,7 +6648,10 @@ async def set_sticker_position_in_set( :class:`telegram.error.TelegramError` """ - data: JSONDict = {"sticker": sticker, "position": position} + data: JSONDict = { + "sticker": sticker if isinstance(sticker, str) else sticker.file_id, + "position": position, + } return await self._post( "setStickerPositionInSet", data, @@ -6749,7 +6756,7 @@ async def create_new_sticker_set( async def delete_sticker_from_set( self, - sticker: str, + sticker: Union[str, "Sticker"], *, read_timeout: ODVInput[float] = DEFAULT_NONE, write_timeout: ODVInput[float] = DEFAULT_NONE, @@ -6760,7 +6767,11 @@ async def delete_sticker_from_set( """Use this method to delete a sticker from a set created by the bot. Args: - sticker (:obj:`str`): File identifier of the sticker. + sticker (:obj:`str` | :class:`telegram.Sticker`): File identifier of the sticker or + the sticker object. + + .. versionchanged:: NEXT.VERSION + Accepts also :class:`telegram.Sticker` instances. Returns: :obj:`bool`: On success, :obj:`True` is returned. @@ -6769,7 +6780,7 @@ async def delete_sticker_from_set( :class:`telegram.error.TelegramError` """ - data: JSONDict = {"sticker": sticker} + data: JSONDict = {"sticker": sticker if isinstance(sticker, str) else sticker.file_id} return await self._post( "deleteStickerFromSet", data, @@ -6937,7 +6948,7 @@ async def set_sticker_set_title( async def set_sticker_emoji_list( self, - sticker: str, + sticker: Union[str, "Sticker"], emoji_list: Sequence[str], *, read_timeout: ODVInput[float] = DEFAULT_NONE, @@ -6953,7 +6964,11 @@ async def set_sticker_emoji_list( .. versionadded:: 20.2 Args: - sticker (:obj:`str`): File identifier of the sticker. + sticker (:obj:`str` | :class:`~telegram.Sticker`): File identifier of the sticker or + the sticker object. + + .. versionchanged:: NEXT.VERSION + Accepts also :class:`telegram.Sticker` instances. emoji_list (Sequence[:obj:`str`]): A sequence of :tg-const:`telegram.constants.StickerLimit.MIN_STICKER_EMOJI`- :tg-const:`telegram.constants.StickerLimit.MAX_STICKER_EMOJI` emoji associated with @@ -6965,7 +6980,10 @@ async def set_sticker_emoji_list( Raises: :class:`telegram.error.TelegramError` """ - data: JSONDict = {"sticker": sticker, "emoji_list": emoji_list} + data: JSONDict = { + "sticker": sticker if isinstance(sticker, str) else sticker.file_id, + "emoji_list": emoji_list, + } return await self._post( "setStickerEmojiList", data, @@ -6978,7 +6996,7 @@ async def set_sticker_emoji_list( async def set_sticker_keywords( self, - sticker: str, + sticker: Union[str, "Sticker"], keywords: Optional[Sequence[str]] = None, *, read_timeout: ODVInput[float] = DEFAULT_NONE, @@ -6994,7 +7012,11 @@ async def set_sticker_keywords( .. versionadded:: 20.2 Args: - sticker (:obj:`str`): File identifier of the sticker. + sticker (:obj:`str` | :class:`~telegram.Sticker`): File identifier of the sticker or + the sticker object. + + .. versionchanged:: NEXT.VERSION + Accepts also :class:`telegram.Sticker` instances. keywords (Sequence[:obj:`str`]): A sequence of 0-:tg-const:`telegram.constants.StickerLimit.MAX_SEARCH_KEYWORDS` search keywords for the sticker with total length up to @@ -7006,7 +7028,10 @@ async def set_sticker_keywords( Raises: :class:`telegram.error.TelegramError` """ - data: JSONDict = {"sticker": sticker, "keywords": keywords} + data: JSONDict = { + "sticker": sticker if isinstance(sticker, str) else sticker.file_id, + "keywords": keywords, + } return await self._post( "setStickerKeywords", data, @@ -7019,7 +7044,7 @@ async def set_sticker_keywords( async def set_sticker_mask_position( self, - sticker: str, + sticker: Union[str, "Sticker"], mask_position: Optional[MaskPosition] = None, *, read_timeout: ODVInput[float] = DEFAULT_NONE, @@ -7035,7 +7060,11 @@ async def set_sticker_mask_position( .. versionadded:: 20.2 Args: - sticker (:obj:`str`): File identifier of the sticker. + sticker (:obj:`str` | :class:`~telegram.Sticker`): File identifier of the sticker or + the sticker object. + + .. versionchanged:: NEXT.VERSION + Accepts also :class:`telegram.Sticker` instances. mask_position (:class:`telegram.MaskPosition`, optional): A object with the position where the mask should be placed on faces. Omit the parameter to remove the mask position. @@ -7046,7 +7075,10 @@ async def set_sticker_mask_position( Raises: :class:`telegram.error.TelegramError` """ - data: JSONDict = {"sticker": sticker, "mask_position": mask_position} + data: JSONDict = { + "sticker": sticker if isinstance(sticker, str) else sticker.file_id, + "mask_position": mask_position, + } return await self._post( "setStickerMaskPosition", data, @@ -9248,7 +9280,7 @@ async def replace_sticker_in_set( self, user_id: int, name: str, - old_sticker: str, + old_sticker: Union[str, Sticker], sticker: "InputSticker", *, read_timeout: ODVInput[float] = DEFAULT_NONE, @@ -9266,7 +9298,11 @@ async def replace_sticker_in_set( Args: user_id (:obj:`int`): User identifier of the sticker set owner. name (:obj:`str`): Sticker set name. - old_sticker (:obj:`str`): File identifier of the replaced sticker. + old_sticker (:obj:`str` | :class:`~telegram.Sticker`): File identifier of the replaced + sticker or the sticker object itself. + + .. versionchanged:: NEXT.VERSION + Accepts also :class:`telegram.Sticker` instances. sticker (:class:`telegram.InputSticker`): An object with information about the added sticker. If exactly the same sticker had already been added to the set, then the set remains unchanged. @@ -9280,7 +9316,7 @@ async def replace_sticker_in_set( data: JSONDict = { "user_id": user_id, "name": name, - "old_sticker": old_sticker, + "old_sticker": old_sticker if isinstance(old_sticker, str) else old_sticker.file_id, "sticker": sticker, } diff --git a/telegram/ext/_extbot.py b/telegram/ext/_extbot.py index 66ec43c49f6..910cff98157 100644 --- a/telegram/ext/_extbot.py +++ b/telegram/ext/_extbot.py @@ -1426,7 +1426,7 @@ async def delete_my_commands( async def delete_sticker_from_set( self, - sticker: str, + sticker: Union[str, "Sticker"], *, read_timeout: ODVInput[float] = DEFAULT_NONE, write_timeout: ODVInput[float] = DEFAULT_NONE, @@ -3660,7 +3660,7 @@ async def set_passport_data_errors( async def set_sticker_position_in_set( self, - sticker: str, + sticker: Union[str, "Sticker"], position: int, *, read_timeout: ODVInput[float] = DEFAULT_NONE, @@ -4114,7 +4114,7 @@ async def delete_sticker_set( async def set_sticker_emoji_list( self, - sticker: str, + sticker: Union[str, "Sticker"], emoji_list: Sequence[str], *, read_timeout: ODVInput[float] = DEFAULT_NONE, @@ -4136,7 +4136,7 @@ async def set_sticker_emoji_list( async def set_sticker_keywords( self, - sticker: str, + sticker: Union[str, "Sticker"], keywords: Optional[Sequence[str]] = None, *, read_timeout: ODVInput[float] = DEFAULT_NONE, @@ -4158,7 +4158,7 @@ async def set_sticker_keywords( async def set_sticker_mask_position( self, - sticker: str, + sticker: Union[str, "Sticker"], mask_position: Optional[MaskPosition] = None, *, read_timeout: ODVInput[float] = DEFAULT_NONE, @@ -4250,7 +4250,7 @@ async def replace_sticker_in_set( self, user_id: int, name: str, - old_sticker: str, + old_sticker: Union[str, "Sticker"], sticker: "InputSticker", *, read_timeout: ODVInput[float] = DEFAULT_NONE, diff --git a/tests/_files/test_sticker.py b/tests/_files/test_sticker.py index d77f93ac776..07cf2e932c3 100644 --- a/tests/_files/test_sticker.py +++ b/tests/_files/test_sticker.py @@ -699,6 +699,54 @@ async def make_assertion(*_, **kwargs): monkeypatch.setattr(sticker.get_bot(), "get_file", make_assertion) assert await sticker.get_file() + async def test_delete_sticker_from_set_sticker_input(self, offline_bot, sticker, monkeypatch): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.json_parameters["sticker"] == sticker.file_id + + monkeypatch.setattr(offline_bot.request, "post", make_assertion) + assert await offline_bot.delete_sticker_from_set(sticker) + + async def test_replace_sticker_in_set_sticker_input(self, offline_bot, sticker, monkeypatch): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.json_parameters["old_sticker"] == sticker.file_id + + monkeypatch.setattr(offline_bot.request, "post", make_assertion) + assert await offline_bot.replace_sticker_in_set( + user_id=1, name="name", sticker="sticker", old_sticker=sticker + ) + + async def test_set_sticker_emoji_list_sticker_input(self, offline_bot, sticker, monkeypatch): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.json_parameters["sticker"] == sticker.file_id + + monkeypatch.setattr(offline_bot.request, "post", make_assertion) + assert await offline_bot.set_sticker_emoji_list(sticker, ["emoji"]) + + async def test_set_sticker_mask_position_sticker_input( + self, offline_bot, sticker, monkeypatch + ): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.json_parameters["sticker"] == sticker.file_id + + monkeypatch.setattr(offline_bot.request, "post", make_assertion) + assert await offline_bot.set_sticker_mask_position(sticker, MaskPosition("eyes", 1, 2, 3)) + + async def test_set_sticker_position_in_set_sticker_input( + self, offline_bot, sticker, monkeypatch + ): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.json_parameters["sticker"] == sticker.file_id + + monkeypatch.setattr(offline_bot.request, "post", make_assertion) + assert await offline_bot.set_sticker_position_in_set(sticker, 1) + + async def test_set_sticker_keywords_sticker_input(self, offline_bot, sticker, monkeypatch): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.json_parameters["sticker"] == sticker.file_id + + monkeypatch.setattr(offline_bot.request, "post", make_assertion) + assert await offline_bot.set_sticker_keywords(sticker, ["keyword"]) + @pytest.mark.xdist_group("stickerset") class TestStickerSetWithRequest: diff --git a/tests/test_official/exceptions.py b/tests/test_official/exceptions.py index c16cff100e0..4a059196012 100644 --- a/tests/test_official/exceptions.py +++ b/tests/test_official/exceptions.py @@ -47,7 +47,13 @@ class ParamTypeCheckingExceptions: "voice": Voice, "sticker": Sticker, "gift_id": Gift, - } + }, + "(delete|set)_sticker.*": { + "sticker$": Sticker, + }, + "replace_sticker_in_set": { + "old_sticker$": Sticker, + }, } # TODO: Look into merging this with COMPLEX_TYPES From 67fefebd0bc64e6cdefcf3b0b0c2b2776407900e Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 29 Dec 2024 16:50:57 +0100 Subject: [PATCH 3/6] Fix defaults testing --- tests/auxil/bot_method_checks.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/auxil/bot_method_checks.py b/tests/auxil/bot_method_checks.py index 9a0751dae5a..dff21b0b440 100644 --- a/tests/auxil/bot_method_checks.py +++ b/tests/auxil/bot_method_checks.py @@ -37,6 +37,7 @@ InputTextMessageContent, LinkPreviewOptions, ReplyParameters, + Sticker, TelegramObject, ) from telegram._utils.defaultvalue import DEFAULT_NONE, DefaultValue @@ -317,6 +318,16 @@ def build_kwargs( kws["error_message"] = "error" elif name == "options": kws[name] = ["option1", "option2"] + elif name in ("sticker", "old_sticker"): + kws[name] = Sticker( + file_id="file_id", + file_unique_id="file_unique_id", + width=1, + height=1, + is_animated=False, + is_video=False, + type="regular", + ) else: kws[name] = True From 7909fffc7c121a30c82694c24d6285d86765680c Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 29 Dec 2024 17:29:40 +0100 Subject: [PATCH 4/6] Fix test official --- pyproject.toml | 4 +-- tests/test_official/arg_type_checker.py | 16 ++++++++---- tests/test_official/exceptions.py | 34 ++++++++++++++++--------- 3 files changed, 35 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 58752295610..6fba965299d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -181,8 +181,8 @@ markers = [ "req", ] asyncio_mode = "auto" -log_format = "%(funcName)s - Line %(lineno)d - %(message)s" -# log_level = "DEBUG" # uncomment to see DEBUG logs +log_cli_format = "%(funcName)s - Line %(lineno)d - %(message)s" +# log_cli_level = "DEBUG" # uncomment to see DEBUG logs # MYPY: [tool.mypy] diff --git a/tests/test_official/arg_type_checker.py b/tests/test_official/arg_type_checker.py index 87670e1e6ee..c90df757bcd 100644 --- a/tests/test_official/arg_type_checker.py +++ b/tests/test_official/arg_type_checker.py @@ -215,15 +215,21 @@ def check_param_type( # 5) COMPLEX TYPES: # Some types are too complicated, so we replace our annotation with a simpler type: - elif any(ptb_param.name in key for key in PTCE.COMPLEX_TYPES): - log("Converting `%s` to a simpler type!\n", ptb_param.name) - for (param_name, is_expected_class), exception_type in PTCE.COMPLEX_TYPES.items(): - if ptb_param.name == param_name and is_class is is_expected_class: + elif mappings := [ + mapping + for pattern, mapping in PTCE.COMPLEX_TYPES.items() + if (re.match(pattern, obj.__name__)) + ]: + for mapping in mappings: + for key, exception_type in mapping.items(): + if not re.match(key, ptb_param.name): + continue + log("Converting `%s` to a simpler type!\n", ptb_param.name) ptb_annotation = wrap_with_none(tg_parameter, exception_type, obj) # 6) HANDLING DEFAULTS PARAMETERS: # Classes whose parameters are all ODVInput should be converted and checked. - elif obj.__name__ in PTCE.IGNORED_DEFAULTS_CLASSES: + if obj.__name__ in PTCE.IGNORED_DEFAULTS_CLASSES: log("Checking that `%s`'s param is ODVInput:\n", obj.__name__) mapped_type = ODVInput[mapped_type] elif not ( diff --git a/tests/test_official/exceptions.py b/tests/test_official/exceptions.py index 4a059196012..2b732cce93d 100644 --- a/tests/test_official/exceptions.py +++ b/tests/test_official/exceptions.py @@ -37,7 +37,7 @@ class ParamTypeCheckingExceptions: # Types for certain parameters accepted by PTB but not in the official API # structure: method/class_name/regex: {param_name/regex: type} ADDITIONAL_TYPES = { - "send_*": { + r"send_\w*": { "photo$": PhotoSize, "video$": Video, "video_note": VideoNote, @@ -70,19 +70,29 @@ class ParamTypeCheckingExceptions: } # Special cases for other parameters that accept more types than the official API, and are - # too complex to compare/predict with official API: + # too complex to compare/predict with official API + # structure: class/method_name: {param_name: reduced form of annotation} COMPLEX_TYPES = ( { # (param_name, is_class (i.e appears in a class?)): reduced form of annotation - ("correct_option_id", False): int, # actual: Literal - ("file_id", False): str, # actual: Union[str, objs_with_file_id_attr] - ("invite_link", False): str, # actual: Union[str, ChatInviteLink] - ("provider_data", False): str, # actual: Union[str, obj] - ("callback_data", True): str, # actual: Union[str, obj] - ("media", True): str, # actual: Union[str, InputMedia*, FileInput] - ( - "data", - True, - ): str, # actual: Union[IdDocumentData, PersonalDetails, ResidentialAddress] + "send_poll": {"correct_option_id": int}, # actual: Literal + "get_file": { + "file_id": str, # actual: Union[str, objs_with_file_id_attr] + }, + r"\w+invite_link": { + "invite_link": str, # actual: Union[str, ChatInviteLink] + }, + "send_invoice|create_invoice_link": { + "provider_data": str, # actual: Union[str, obj] + }, + "InlineKeyboardButton": { + "callback_data": str, # actual: Union[str, obj] + }, + "Input(Paid)?Media.*": { + "media": str, # actual: Union[str, InputMedia*, FileInput] + }, + "EncryptedPassportElement": { + "data": str, # actual: Union[IdDocumentData, PersonalDetails, ResidentialAddress] + }, } ) From ea74f083006de1209f0821c939065a883b9752dc Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 29 Dec 2024 17:36:43 +0100 Subject: [PATCH 5/6] another failing test :) --- telegram/_bot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/telegram/_bot.py b/telegram/_bot.py index 5a083c31b0e..4ebf9218c94 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -9280,7 +9280,7 @@ async def replace_sticker_in_set( self, user_id: int, name: str, - old_sticker: Union[str, Sticker], + old_sticker: Union[str, "Sticker"], sticker: "InputSticker", *, read_timeout: ODVInput[float] = DEFAULT_NONE, From af4cdda683f77f947377efd1f89633fcc3282837 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 29 Dec 2024 18:00:10 +0100 Subject: [PATCH 6/6] Introduce helper method & restore previous (el)if behavior --- tests/test_official/arg_type_checker.py | 36 ++++++++----------------- tests/test_official/helpers.py | 21 ++++++++++++++- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/tests/test_official/arg_type_checker.py b/tests/test_official/arg_type_checker.py index c90df757bcd..676b8ba2106 100644 --- a/tests/test_official/arg_type_checker.py +++ b/tests/test_official/arg_type_checker.py @@ -38,6 +38,7 @@ _get_params_base, _unionizer, cached_type_hints, + extract_mappings, resolve_forward_refs_in_type, wrap_with_none, ) @@ -174,22 +175,14 @@ def check_param_type( # 2) HANDLING OTHER TYPES: # Special case for send_* methods where we accept more types than the official API: - elif mappings := [ - mapping - for pattern, mapping in PTCE.ADDITIONAL_TYPES.items() - if (re.match(pattern, obj.__name__)) - ]: + elif additional_types := extract_mappings(PTCE.ADDITIONAL_TYPES, obj, ptb_param.name): log("Checking that `%s` accepts additional types for some parameters!\n", obj.__name__) - for mapping in mappings: - for key, value in mapping.items(): - if not re.match(key, ptb_param.name): - continue - - log("Checking that `%s` is an additional type for `%s`!\n", value, ptb_param.name) - mapped_type = mapped_type | value + for at in additional_types: + log("Checking that `%s` is an additional type for `%s`!\n", at, ptb_param.name) + mapped_type = mapped_type | at # 3) HANDLING DATETIMES: - if ( + elif ( re.search( DATETIME_REGEX, ptb_param.name, @@ -215,21 +208,14 @@ def check_param_type( # 5) COMPLEX TYPES: # Some types are too complicated, so we replace our annotation with a simpler type: - elif mappings := [ - mapping - for pattern, mapping in PTCE.COMPLEX_TYPES.items() - if (re.match(pattern, obj.__name__)) - ]: - for mapping in mappings: - for key, exception_type in mapping.items(): - if not re.match(key, ptb_param.name): - continue - log("Converting `%s` to a simpler type!\n", ptb_param.name) - ptb_annotation = wrap_with_none(tg_parameter, exception_type, obj) + elif overrides := extract_mappings(PTCE.COMPLEX_TYPES, obj, ptb_param.name): + exception_type = overrides[0] + log("Converting `%s` to a simpler type!\n", ptb_param.name) + ptb_annotation = wrap_with_none(tg_parameter, exception_type, obj) # 6) HANDLING DEFAULTS PARAMETERS: # Classes whose parameters are all ODVInput should be converted and checked. - if obj.__name__ in PTCE.IGNORED_DEFAULTS_CLASSES: + elif obj.__name__ in PTCE.IGNORED_DEFAULTS_CLASSES: log("Checking that `%s`'s param is ODVInput:\n", obj.__name__) mapped_type = ODVInput[mapped_type] elif not ( diff --git a/tests/test_official/helpers.py b/tests/test_official/helpers.py index 68ffffa09e3..7573d76b7be 100644 --- a/tests/test_official/helpers.py +++ b/tests/test_official/helpers.py @@ -21,7 +21,7 @@ import functools import re from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, _eval_type, get_type_hints +from typing import TYPE_CHECKING, Any, Optional, TypeVar, _eval_type, get_type_hints from bs4 import PageElement, Tag @@ -110,3 +110,22 @@ def cached_type_hints(obj: Any, is_class: bool) -> dict[str, Any]: def resolve_forward_refs_in_type(obj: type) -> type: """Resolves forward references in a type hint.""" return _eval_type(obj, localns=tg_objects, globalns=None) + + +T = TypeVar("T") + + +def extract_mappings( + exceptions: dict[str, dict[str, T]], obj: object, param_name: str +) -> Optional[list[T]]: + mappings = ( + mapping for pattern, mapping in exceptions.items() if (re.match(pattern, obj.__name__)) + ) + out = [ + value + for mapping in mappings + for key, value in mapping.items() + if re.match(key, param_name) + ] + + return None or out