diff --git a/telegram/_bot.py b/telegram/_bot.py index 78075a3d351..4bb85e1f5ee 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -6833,6 +6833,8 @@ async def send_poll( message_thread_id: Optional[int] = None, reply_parameters: Optional["ReplyParameters"] = None, business_connection_id: Optional[str] = None, + question_parse_mode: ODVInput[str] = DEFAULT_NONE, + question_entities: Optional[Sequence["MessageEntity"]] = None, *, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: Optional[int] = None, @@ -6917,6 +6919,16 @@ async def send_poll( business_connection_id (:obj:`str`, optional): |business_id_str| .. versionadded:: 21.1 + question_parse_mode (:obj:`str`, optional): Mode for parsing entities in the question. + See the constants in :class:`telegram.constants.ParseMode` for the available modes. + Currently, only custom emoji entities are allowed. + + .. versionadded:: NEXT.VERSION + question_entities (Sequence[:class:`telegram.Message`], optional): Special entities + that appear in the poll :paramref:`question`. It can be specified instead of + :paramref:`question_parse_mode`. + + .. versionadded:: NEXT.VERSION Keyword Args: allow_sending_without_reply (:obj:`bool`, optional): |allow_sending_without_reply| @@ -6962,6 +6974,8 @@ async def send_poll( "explanation_entities": explanation_entities, "open_period": open_period, "close_date": close_date, + "question_parse_mode": question_parse_mode, + "question_entities": question_entities, } return await self._send_message( diff --git a/telegram/_chat.py b/telegram/_chat.py index 1c832a26223..86ca956844f 100644 --- a/telegram/_chat.py +++ b/telegram/_chat.py @@ -2903,6 +2903,8 @@ async def send_poll( message_thread_id: Optional[int] = None, reply_parameters: Optional["ReplyParameters"] = None, business_connection_id: Optional[str] = None, + question_parse_mode: ODVInput[str] = DEFAULT_NONE, + question_entities: Optional[Sequence["MessageEntity"]] = None, *, reply_to_message_id: Optional[int] = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -2949,6 +2951,8 @@ async def send_poll( protect_content=protect_content, message_thread_id=message_thread_id, business_connection_id=business_connection_id, + question_parse_mode=question_parse_mode, + question_entities=question_entities, ) async def send_copy( diff --git a/telegram/_message.py b/telegram/_message.py index 586d4dd97fe..61b538c038d 100644 --- a/telegram/_message.py +++ b/telegram/_message.py @@ -65,6 +65,7 @@ from telegram._utils.argumentparsing import parse_sequence_arg from telegram._utils.datetime import extract_tzinfo_from_defaults, from_timestamp from telegram._utils.defaultvalue import DEFAULT_NONE, DefaultValue +from telegram._utils.entities import parse_message_entities, parse_message_entity from telegram._utils.types import ( CorrectOptionID, FileInput, @@ -2922,6 +2923,8 @@ async def reply_poll( protect_content: ODVInput[bool] = DEFAULT_NONE, message_thread_id: ODVInput[int] = DEFAULT_NONE, reply_parameters: Optional["ReplyParameters"] = None, + question_parse_mode: ODVInput[str] = DEFAULT_NONE, + question_entities: Optional[Sequence["MessageEntity"]] = None, *, reply_to_message_id: Optional[int] = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -2992,6 +2995,8 @@ async def reply_poll( protect_content=protect_content, message_thread_id=message_thread_id, business_connection_id=self.business_connection_id, + question_parse_mode=question_parse_mode, + question_entities=question_entities, ) async def reply_dice( @@ -4202,9 +4207,7 @@ def parse_entity(self, entity: MessageEntity) -> str: if not self.text: raise RuntimeError("This Message has no 'text'.") - entity_text = self.text.encode("utf-16-le") - entity_text = entity_text[entity.offset * 2 : (entity.offset + entity.length) * 2] - return entity_text.decode("utf-16-le") + return parse_message_entity(self.text, entity) def parse_caption_entity(self, entity: MessageEntity) -> str: """Returns the text from a given :class:`telegram.MessageEntity`. @@ -4228,9 +4231,7 @@ def parse_caption_entity(self, entity: MessageEntity) -> str: if not self.caption: raise RuntimeError("This Message has no 'caption'.") - entity_text = self.caption.encode("utf-16-le") - entity_text = entity_text[entity.offset * 2 : (entity.offset + entity.length) * 2] - return entity_text.decode("utf-16-le") + return parse_message_entity(self.caption, entity) def parse_entities(self, types: Optional[List[str]] = None) -> Dict[MessageEntity, str]: """ @@ -4255,12 +4256,7 @@ def parse_entities(self, types: Optional[List[str]] = None) -> Dict[MessageEntit the text that belongs to them, calculated based on UTF-16 codepoints. """ - if types is None: - types = MessageEntity.ALL_TYPES - - return { - entity: self.parse_entity(entity) for entity in self.entities if entity.type in types - } + return parse_message_entities(self.text, self.entities, types=types) def parse_caption_entities( self, types: Optional[List[str]] = None @@ -4287,14 +4283,7 @@ def parse_caption_entities( the text that belongs to them, calculated based on UTF-16 codepoints. """ - if types is None: - types = MessageEntity.ALL_TYPES - - return { - entity: self.parse_caption_entity(entity) - for entity in self.caption_entities - if entity.type in types - } + return parse_message_entities(self.caption, self.caption_entities, types=types) @classmethod def _parse_html( diff --git a/telegram/_poll.py b/telegram/_poll.py index fccdd8da87b..656b7e8b875 100644 --- a/telegram/_poll.py +++ b/telegram/_poll.py @@ -29,6 +29,7 @@ from telegram._utils.argumentparsing import parse_sequence_arg from telegram._utils.datetime import extract_tzinfo_from_defaults, from_timestamp from telegram._utils.defaultvalue import DEFAULT_NONE +from telegram._utils.entities import parse_message_entities, parse_message_entity from telegram._utils.types import JSONDict, ODVInput if TYPE_CHECKING: @@ -113,26 +114,101 @@ class PollOption(TelegramObject): :tg-const:`telegram.PollOption.MIN_LENGTH`-:tg-const:`telegram.PollOption.MAX_LENGTH` characters. voter_count (:obj:`int`): Number of users that voted for this option. + text_entities (Sequence[:class:`telegram.MessageEntity`], optional): Special entities + that appear in the option text. Currently, only custom emoji entities are allowed in + poll option texts. + + .. versionadded:: NEXT.VERSION Attributes: text (:obj:`str`): Option text, :tg-const:`telegram.PollOption.MIN_LENGTH`-:tg-const:`telegram.PollOption.MAX_LENGTH` characters. voter_count (:obj:`int`): Number of users that voted for this option. + text_entities (Tuple[:class:`telegram.MessageEntity`]): Special entities + that appear in the option text. Currently, only custom emoji entities are allowed in + poll option texts. + This list is empty if the question does not contain entities. + + .. versionadded:: NEXT.VERSION """ - __slots__ = ("text", "voter_count") + __slots__ = ("text", "text_entities", "voter_count") - def __init__(self, text: str, voter_count: int, *, api_kwargs: Optional[JSONDict] = None): + def __init__( + self, + text: str, + voter_count: int, + text_entities: Optional[Sequence[MessageEntity]] = None, + *, + api_kwargs: Optional[JSONDict] = None, + ): super().__init__(api_kwargs=api_kwargs) self.text: str = text self.voter_count: int = voter_count + self.text_entities: Tuple[MessageEntity, ...] = parse_sequence_arg(text_entities) self._id_attrs = (self.text, self.voter_count) self._freeze() + @classmethod + def de_json(cls, data: Optional[JSONDict], bot: "Bot") -> Optional["PollOption"]: + """See :meth:`telegram.TelegramObject.de_json`.""" + data = cls._parse_data(data) + + if not data: + return None + + data["text_entities"] = MessageEntity.de_list(data.get("text_entities"), bot) + + return super().de_json(data=data, bot=bot) + + def parse_entity(self, entity: MessageEntity) -> str: + """Returns the text in :attr:`text` + from a given :class:`telegram.MessageEntity` of :attr:`text_entities`. + + Note: + This method is present because Telegram calculates the offset and length in + UTF-16 codepoint pairs, which some versions of Python don't handle automatically. + (That is, you can't just slice ``Message.text`` with the offset and length.) + + .. versionadded:: NEXT.VERSION + + Args: + entity (:class:`telegram.MessageEntity`): The entity to extract the text from. It must + be an entity that belongs to :attr:`text_entities`. + + Returns: + :obj:`str`: The text of the given entity. + """ + return parse_message_entity(self.text, entity) + + def parse_entities(self, types: Optional[List[str]] = None) -> Dict[MessageEntity, str]: + """ + Returns a :obj:`dict` that maps :class:`telegram.MessageEntity` to :obj:`str`. + It contains entities from this polls question filtered by their ``type`` attribute as + the key, and the text that each entity belongs to as the value of the :obj:`dict`. + + Note: + This method should always be used instead of the :attr:`text_entities` + attribute, since it calculates the correct substring from the message text based on + UTF-16 codepoints. See :attr:`parse_entity` for more info. + + .. versionadded:: NEXT.VERSION + + Args: + types (List[:obj:`str`], optional): List of ``MessageEntity`` types as strings. If the + ``type`` attribute of an entity is contained in this list, it will be returned. + Defaults to :attr:`telegram.MessageEntity.ALL_TYPES`. + + Returns: + Dict[:class:`telegram.MessageEntity`, :obj:`str`]: A dictionary of entities mapped to + the text that belongs to them, calculated based on UTF-16 codepoints. + """ + return parse_message_entities(self.text, self.text_entities, types) + MIN_LENGTH: Final[int] = constants.PollLimit.MIN_OPTION_LENGTH """:const:`telegram.constants.PollLimit.MIN_OPTION_LENGTH` @@ -282,6 +358,11 @@ class Poll(TelegramObject): .. versionchanged:: 20.3 |datetime_localization| + question_entities (Sequence[:class:`telegram.MessageEntity`], optional): Special entities + that appear in the :attr:`question`. Currently, only custom emoji entities are allowed + in poll questions. + + .. versionadded:: NEXT.VERSION Attributes: id (:obj:`str`): Unique poll identifier. @@ -318,6 +399,12 @@ class Poll(TelegramObject): .. versionchanged:: 20.3 |datetime_localization| + question_entities (Tuple[:class:`telegram.MessageEntity`]): Special entities + that appear in the :attr:`question`. Currently, only custom emoji entities are allowed + in poll questions. + This list is empty if the question does not contain entities. + + .. versionadded:: NEXT.VERSION """ @@ -333,6 +420,7 @@ class Poll(TelegramObject): "open_period", "options", "question", + "question_entities", "total_voter_count", "type", ) @@ -352,6 +440,7 @@ def __init__( explanation_entities: Optional[Sequence[MessageEntity]] = None, open_period: Optional[int] = None, close_date: Optional[datetime.datetime] = None, + question_entities: Optional[Sequence[MessageEntity]] = None, *, api_kwargs: Optional[JSONDict] = None, ): @@ -371,6 +460,7 @@ def __init__( ) self.open_period: Optional[int] = open_period self.close_date: Optional[datetime.datetime] = close_date + self.question_entities: Tuple[MessageEntity, ...] = parse_sequence_arg(question_entities) self._id_attrs = (self.id,) @@ -390,11 +480,13 @@ def de_json(cls, data: Optional[JSONDict], bot: "Bot") -> Optional["Poll"]: data["options"] = [PollOption.de_json(option, bot) for option in data["options"]] data["explanation_entities"] = MessageEntity.de_list(data.get("explanation_entities"), bot) data["close_date"] = from_timestamp(data.get("close_date"), tzinfo=loc_tzinfo) + data["question_entities"] = MessageEntity.de_list(data.get("question_entities"), bot) return super().de_json(data=data, bot=bot) def parse_explanation_entity(self, entity: MessageEntity) -> str: - """Returns the text from a given :class:`telegram.MessageEntity`. + """Returns the text in :attr:`explanation` from a given :class:`telegram.MessageEntity` of + :attr:`explanation_entities`. Note: This method is present because Telegram calculates the offset and length in @@ -403,7 +495,7 @@ def parse_explanation_entity(self, entity: MessageEntity) -> str: Args: entity (:class:`telegram.MessageEntity`): The entity to extract the text from. It must - be an entity that belongs to this message. + be an entity that belongs to :attr:`explanation_entities`. Returns: :obj:`str`: The text of the given entity. @@ -415,10 +507,7 @@ def parse_explanation_entity(self, entity: MessageEntity) -> str: if not self.explanation: raise RuntimeError("This Poll has no 'explanation'.") - entity_text = self.explanation.encode("utf-16-le") - entity_text = entity_text[entity.offset * 2 : (entity.offset + entity.length) * 2] - - return entity_text.decode("utf-16-le") + return parse_message_entity(self.explanation, entity) def parse_explanation_entities( self, types: Optional[List[str]] = None @@ -442,15 +531,61 @@ def parse_explanation_entities( Dict[:class:`telegram.MessageEntity`, :obj:`str`]: A dictionary of entities mapped to the text that belongs to them, calculated based on UTF-16 codepoints. + Raises: + RuntimeError: If the poll has no explanation. + + """ + if not self.explanation: + raise RuntimeError("This Poll has no 'explanation'.") + + return parse_message_entities(self.explanation, self.explanation_entities, types) + + def parse_question_entity(self, entity: MessageEntity) -> str: + """Returns the text in :attr:`question` from a given :class:`telegram.MessageEntity` of + :attr:`question_entities`. + + .. versionadded:: NEXT.VERSION + + Note: + This method is present because Telegram calculates the offset and length in + UTF-16 codepoint pairs, which some versions of Python don't handle automatically. + (That is, you can't just slice ``Message.text`` with the offset and length.) + + Args: + entity (:class:`telegram.MessageEntity`): The entity to extract the text from. It must + be an entity that belongs to :attr:`question_entities`. + + Returns: + :obj:`str`: The text of the given entity. + """ + return parse_message_entity(self.question, entity) + + def parse_question_entities( + self, types: Optional[List[str]] = None + ) -> Dict[MessageEntity, str]: + """ + Returns a :obj:`dict` that maps :class:`telegram.MessageEntity` to :obj:`str`. + It contains entities from this polls question filtered by their ``type`` attribute as + the key, and the text that each entity belongs to as the value of the :obj:`dict`. + + .. versionadded:: NEXT.VERSION + + Note: + This method should always be used instead of the :attr:`question_entities` + attribute, since it calculates the correct substring from the message text based on + UTF-16 codepoints. See :attr:`parse_question_entity` for more info. + + Args: + types (List[:obj:`str`], optional): List of ``MessageEntity`` types as strings. If the + ``type`` attribute of an entity is contained in this list, it will be returned. + Defaults to :attr:`telegram.MessageEntity.ALL_TYPES`. + + Returns: + Dict[:class:`telegram.MessageEntity`, :obj:`str`]: A dictionary of entities mapped to + the text that belongs to them, calculated based on UTF-16 codepoints. + """ - if types is None: - types = MessageEntity.ALL_TYPES - - return { - entity: self.parse_explanation_entity(entity) - for entity in self.explanation_entities - if entity.type in types - } + return parse_message_entities(self.question, self.question_entities, types) REGULAR: Final[str] = constants.PollType.REGULAR """:const:`telegram.constants.PollType.REGULAR`""" diff --git a/telegram/_user.py b/telegram/_user.py index f783ccd0a88..17b58f2df6f 100644 --- a/telegram/_user.py +++ b/telegram/_user.py @@ -1500,6 +1500,8 @@ async def send_poll( message_thread_id: Optional[int] = None, reply_parameters: Optional["ReplyParameters"] = None, business_connection_id: Optional[str] = None, + question_parse_mode: ODVInput[str] = DEFAULT_NONE, + question_entities: Optional[Sequence["MessageEntity"]] = None, *, reply_to_message_id: Optional[int] = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -1549,6 +1551,8 @@ async def send_poll( protect_content=protect_content, message_thread_id=message_thread_id, business_connection_id=business_connection_id, + question_parse_mode=question_parse_mode, + question_entities=question_entities, ) async def send_copy( diff --git a/telegram/_utils/entities.py b/telegram/_utils/entities.py new file mode 100644 index 00000000000..a3994cd0426 --- /dev/null +++ b/telegram/_utils/entities.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2024 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains auxiliary functionality for parsing MessageEntity objects. + +Warning: + Contents of this module are intended to be used internally by the library and *not* by the + user. Changes to this module are not considered breaking changes and may not be documented in + the changelog. +""" +from typing import Dict, Optional, Sequence + +from telegram._messageentity import MessageEntity + + +def parse_message_entity(text: str, entity: MessageEntity) -> str: + """Returns the text from a given :class:`telegram.MessageEntity`. + + Args: + text (:obj:`str`): The text to extract the entity from. + entity (:class:`telegram.MessageEntity`): The entity to extract the text from. + + Returns: + :obj:`str`: The text of the given entity. + """ + entity_text = text.encode("utf-16-le") + entity_text = entity_text[entity.offset * 2 : (entity.offset + entity.length) * 2] + + return entity_text.decode("utf-16-le") + + +def parse_message_entities( + text: str, entities: Sequence[MessageEntity], types: Optional[Sequence[str]] = None +) -> Dict[MessageEntity, str]: + """ + Returns a :obj:`dict` that maps :class:`telegram.MessageEntity` to :obj:`str`. + It contains entities filtered by their ``type`` attribute as + the key, and the text that each entity belongs to as the value of the :obj:`dict`. + + Args: + text (:obj:`str`): The text to extract the entity from. + entities (List[:class:`telegram.MessageEntity`]): The entities to extract the text from. + types (List[:obj:`str`], optional): List of ``MessageEntity`` types as strings. If the + ``type`` attribute of an entity is contained in this list, it will be returned. + Defaults to :attr:`telegram.MessageEntity.ALL_TYPES`. + + Returns: + Dict[:class:`telegram.MessageEntity`, :obj:`str`]: A dictionary of entities mapped to + the text that belongs to them, calculated based on UTF-16 codepoints. + """ + if types is None: + types = MessageEntity.ALL_TYPES + + return { + entity: parse_message_entity(text, entity) for entity in entities if entity.type in types + } diff --git a/telegram/ext/_defaults.py b/telegram/ext/_defaults.py index 61aae16b248..da27fb6eb69 100644 --- a/telegram/ext/_defaults.py +++ b/telegram/ext/_defaults.py @@ -179,13 +179,14 @@ def __init__( # Gather all defaults that actually have a default value self._api_defaults = {} for kwarg in ( - "parse_mode", - "explanation_parse_mode", - "disable_notification", "allow_sending_without_reply", - "protect_content", - "link_preview_options", + "disable_notification", "do_quote", + "explanation_parse_mode", + "link_preview_options", + "parse_mode", + "protect_content", + "question_parse_mode", ): value = getattr(self, kwarg) if value is not None: @@ -267,7 +268,7 @@ def quote_parse_mode(self, _: object) -> NoReturn: @property def text_parse_mode(self) -> Optional[str]: """:obj:`str`: Optional. Alias for :attr:`parse_mode`, used for - the corresponding parameter of :meth:`telegram.InputPollOption`. + the corresponding parameter of :class:`telegram.InputPollOption`. .. versionadded:: NEXT.VERSION """ @@ -279,6 +280,21 @@ def text_parse_mode(self, _: object) -> NoReturn: "You can not assign a new value to text_parse_mode after initialization." ) + @property + def question_parse_mode(self) -> Optional[str]: + """:obj:`str`: Optional. Alias for :attr:`parse_mode`, used for + the corresponding parameter of :meth:`telegram.Bot.send_poll`. + + .. versionadded:: NEXT.VERSION + """ + return self._parse_mode + + @question_parse_mode.setter + def question_parse_mode(self, _: object) -> NoReturn: + raise AttributeError( + "You can not assign a new value to question_parse_mode after initialization." + ) + @property def disable_notification(self) -> Optional[bool]: """:obj:`bool`: Optional. Sends the message silently. Users will diff --git a/telegram/ext/_extbot.py b/telegram/ext/_extbot.py index 563276c1803..afb4400b040 100644 --- a/telegram/ext/_extbot.py +++ b/telegram/ext/_extbot.py @@ -489,6 +489,7 @@ def _insert_defaults(self, data: Dict[str, object]) -> None: data[key] = new_value + # 6) elif isinstance(val, Sequence) and all( isinstance(obj, InputPollOption) for obj in val ): @@ -2950,6 +2951,8 @@ async def send_poll( message_thread_id: Optional[int] = None, reply_parameters: Optional["ReplyParameters"] = None, business_connection_id: Optional[str] = None, + question_parse_mode: ODVInput[str] = DEFAULT_NONE, + question_entities: Optional[Sequence["MessageEntity"]] = None, *, reply_to_message_id: Optional[int] = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -2987,6 +2990,8 @@ async def send_poll( connect_timeout=connect_timeout, pool_timeout=pool_timeout, api_kwargs=self._merge_api_rl_kwargs(api_kwargs, rate_limit_args), + question_parse_mode=question_parse_mode, + question_entities=question_entities, ) async def send_sticker( diff --git a/tests/test_bot.py b/tests/test_bot.py index c216932fb67..4f1cfeff483 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -1949,42 +1949,47 @@ async def make_assertion(url, request_data: RequestData, *args, **kwargs): ], indirect=["default_bot"], ) - async def test_send_poll_default_text_parse_mode( + async def test_send_poll_default_text_question_parse_mode( self, default_bot, raw_bot, chat_id, custom, monkeypatch ): async def make_assertion(url, request_data: RequestData, *args, **kwargs): + expected = default_bot.defaults.text_parse_mode if custom == "NOTHING" else custom + option_1 = request_data.parameters["options"][0] option_2 = request_data.parameters["options"][1] assert option_1.get("text_parse_mode") == (default_bot.defaults.text_parse_mode) - assert option_2.get("text_parse_mode") == ( - default_bot.defaults.text_parse_mode if custom == "NOTHING" else custom - ) + assert option_2.get("text_parse_mode") == expected + assert request_data.parameters.get("question_parse_mode") == expected + return make_message("dummy reply").to_dict() async def make_raw_assertion(url, request_data: RequestData, *args, **kwargs): + expected = None if custom == "NOTHING" else custom + option_1 = request_data.parameters["options"][0] option_2 = request_data.parameters["options"][1] assert option_1.get("text_parse_mode") is None - assert option_2.get("text_parse_mode") == (None if custom == "NOTHING" else custom) + assert option_2.get("text_parse_mode") == expected + + assert request_data.parameters.get("question_parse_mode") == expected + return make_message("dummy reply").to_dict() if custom == "NOTHING": option_2 = InputPollOption("option2") + kwargs = {} else: option_2 = InputPollOption("option2", text_parse_mode=custom) + kwargs = {"question_parse_mode": custom} monkeypatch.setattr(default_bot.request, "post", make_assertion) await default_bot.send_poll( - chat_id, - question="question", - options=["option1", option_2], + chat_id, question="question", options=["option1", option_2], **kwargs ) monkeypatch.setattr(raw_bot.request, "post", make_raw_assertion) await raw_bot.send_poll( - chat_id, - question="question", - options=["option1", option_2], + chat_id, question="question", options=["option1", option_2], **kwargs ) @pytest.mark.parametrize( @@ -2017,6 +2022,30 @@ async def make_assertion(url, request_data: RequestData, *args, **kwargs): reply_parameters=ReplyParameters(**kwargs), ) + async def test_send_poll_question_parse_mode_entities(self, bot, monkeypatch): + # Currently only custom emoji are supported as entities which we can't test + # We just test that the correct data is passed for now + + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + assert request_data.parameters["question_entities"] == [ + {"type": "custom_emoji", "offset": 0, "length": 1}, + {"type": "custom_emoji", "offset": 2, "length": 1}, + ] + assert request_data.parameters["question_parse_mode"] == ParseMode.MARKDOWN_V2 + return make_message("dummy reply").to_dict() + + monkeypatch.setattr(bot.request, "post", make_assertion) + await bot.send_poll( + 1, + question="😀😃", + options=["option1", "option2"], + question_entities=[ + MessageEntity(MessageEntity.CUSTOM_EMOJI, 0, 1), + MessageEntity(MessageEntity.CUSTOM_EMOJI, 2, 1), + ], + question_parse_mode=ParseMode.MARKDOWN_V2, + ) + @pytest.mark.parametrize( ("default_bot", "custom"), [ diff --git a/tests/test_official/arg_type_checker.py b/tests/test_official/arg_type_checker.py index 2ccd7808cb5..24ef867ba70 100644 --- a/tests/test_official/arg_type_checker.py +++ b/tests/test_official/arg_type_checker.py @@ -148,8 +148,11 @@ def check_param_type( # Now let's do the checking, starting with "Array of ..." types. if "Array of " in tg_param_type: # For exceptions just check if they contain the annotation - if ptb_param.name in PTCE.ARRAY_OF_EXCEPTIONS: - return PTCE.ARRAY_OF_EXCEPTIONS[ptb_param.name] in str(ptb_annotation), Sequence + if any(ptb_param.name in key for key in PTCE.ARRAY_OF_EXCEPTIONS): + for (p_name, is_expected_class), exception_type in PTCE.ARRAY_OF_EXCEPTIONS.items(): + if ptb_param.name == p_name and is_class is is_expected_class: + log("Checking that `%s` is an exception!\n", ptb_param.name) + return exception_type in str(ptb_annotation), Sequence obj_match: re.Match | None = re.search(ARRAY_OF_PATTERN, tg_param_type) if obj_match is None: diff --git a/tests/test_official/exceptions.py b/tests/test_official/exceptions.py index 07fc5b07f77..9bc536c2e68 100644 --- a/tests/test_official/exceptions.py +++ b/tests/test_official/exceptions.py @@ -47,15 +47,17 @@ class ParamTypeCheckingExceptions: "sticker": Sticker, } + # TODO: Look into merging this with COMPLEX_TYPES # Exceptions to the "Array of" types, where we accept more types than the official API - # key: parameter name, value: type which must be present in the annotation + # key: (parameter name, is_class), value: type which must be present in the annotation ARRAY_OF_EXCEPTIONS = { - "results": "InlineQueryResult", # + Callable - "commands": "BotCommand", # + tuple[str, str] - "keyboard": "KeyboardButton", # + sequence[sequence[str]] - "reaction": "ReactionType", # + str + ("results", False): "InlineQueryResult", # + Callable + ("commands", False): "BotCommand", # + tuple[str, str] + ("keyboard", True): "KeyboardButton", # + sequence[sequence[str]] + ("reaction", False): "ReactionType", # + str + ("options", False): "InputPollOption", # + str # TODO: Deprecated and will be corrected (and removed) in next major PTB version: - "file_hashes": "List[str]", + ("file_hashes", True): "List[str]", } # Special cases for other parameters that accept more types than the official API, and are diff --git a/tests/test_poll.py b/tests/test_poll.py index 8e41998b254..92c58339daf 100644 --- a/tests/test_poll.py +++ b/tests/test_poll.py @@ -105,7 +105,11 @@ def test_equality(self): @pytest.fixture(scope="module") def poll_option(): - out = PollOption(text=TestPollOptionBase.text, voter_count=TestPollOptionBase.voter_count) + out = PollOption( + text=TestPollOptionBase.text, + voter_count=TestPollOptionBase.voter_count, + text_entities=TestPollOptionBase.text_entities, + ) out._unfreeze() return out @@ -113,6 +117,10 @@ def poll_option(): class TestPollOptionBase: text = "test option" voter_count = 3 + text_entities = [ + MessageEntity(MessageEntity.BOLD, 0, 4), + MessageEntity(MessageEntity.ITALIC, 5, 6), + ] class TestPollOptionWithoutRequest(TestPollOptionBase): @@ -129,12 +137,43 @@ def test_de_json(self): assert poll_option.text == self.text assert poll_option.voter_count == self.voter_count + def test_de_json_all(self): + json_dict = { + "text": self.text, + "voter_count": self.voter_count, + "text_entities": [e.to_dict() for e in self.text_entities], + } + poll_option = PollOption.de_json(json_dict, None) + assert PollOption.de_json(None, None) is None + assert poll_option.api_kwargs == {} + + assert poll_option.text == self.text + assert poll_option.voter_count == self.voter_count + assert poll_option.text_entities == tuple(self.text_entities) + def test_to_dict(self, poll_option): poll_option_dict = poll_option.to_dict() assert isinstance(poll_option_dict, dict) assert poll_option_dict["text"] == poll_option.text assert poll_option_dict["voter_count"] == poll_option.voter_count + assert poll_option_dict["text_entities"] == [ + e.to_dict() for e in poll_option.text_entities + ] + + def test_parse_entity(self, poll_option): + entity = MessageEntity(MessageEntity.BOLD, 0, 4) + poll_option.text_entities = [entity] + + assert poll_option.parse_entity(entity) == "test" + + def test_parse_entities(self, poll_option): + entity = MessageEntity(MessageEntity.BOLD, 0, 4) + entity_2 = MessageEntity(MessageEntity.ITALIC, 5, 6) + poll_option.text_entities = [entity, entity_2] + + assert poll_option.parse_entities(MessageEntity.BOLD) == {entity: "test"} + assert poll_option.parse_entities() == {entity: "test", entity_2: "option"} def test_equality(self): a = PollOption("text", 1) @@ -237,6 +276,7 @@ def poll(): explanation_entities=TestPollBase.explanation_entities, open_period=TestPollBase.open_period, close_date=TestPollBase.close_date, + question_entities=TestPollBase.question_entities, ) poll._unfreeze() return poll @@ -244,7 +284,7 @@ def poll(): class TestPollBase: id_ = "id" - question = "Test?" + question = "Test Question?" options = [PollOption("test", 10), PollOption("test2", 11)] total_voter_count = 0 is_closed = True @@ -258,6 +298,10 @@ class TestPollBase: explanation_entities = [MessageEntity(13, 17, MessageEntity.URL)] open_period = 42 close_date = datetime.now(timezone.utc) + question_entities = [ + MessageEntity(MessageEntity.BOLD, 0, 4), + MessageEntity(MessageEntity.ITALIC, 5, 8), + ] class TestPollWithoutRequest(TestPollBase): @@ -275,6 +319,7 @@ def test_de_json(self, bot): "explanation_entities": [self.explanation_entities[0].to_dict()], "open_period": self.open_period, "close_date": to_timestamp(self.close_date), + "question_entities": [e.to_dict() for e in self.question_entities], } poll = Poll.de_json(json_dict, bot) assert poll.api_kwargs == {} @@ -296,6 +341,7 @@ def test_de_json(self, bot): assert poll.open_period == self.open_period assert abs(poll.close_date - self.close_date) < timedelta(seconds=1) assert to_timestamp(poll.close_date) == to_timestamp(self.close_date) + assert poll.question_entities == tuple(self.question_entities) def test_de_json_localization(self, tz_bot, bot, raw_bot): json_dict = { @@ -311,6 +357,7 @@ def test_de_json_localization(self, tz_bot, bot, raw_bot): "explanation_entities": [self.explanation_entities[0].to_dict()], "open_period": self.open_period, "close_date": to_timestamp(self.close_date), + "question_entities": [e.to_dict() for e in self.question_entities], } poll_raw = Poll.de_json(json_dict, raw_bot) @@ -343,6 +390,7 @@ def test_to_dict(self, poll): assert poll_dict["explanation_entities"] == [poll.explanation_entities[0].to_dict()] assert poll_dict["open_period"] == poll.open_period assert poll_dict["close_date"] == to_timestamp(poll.close_date) + assert poll_dict["question_entities"] == [e.to_dict() for e in poll.question_entities] def test_equality(self): a = Poll(123, "question", ["O1", "O2"], 1, False, True, Poll.REGULAR, True) @@ -383,7 +431,7 @@ def test_enum_init(self): ) assert poll.type is PollType.QUIZ - def test_parse_entity(self, poll): + def test_parse_explanation_entity(self, poll): entity = MessageEntity(type=MessageEntity.URL, offset=13, length=17) poll.explanation_entities = [entity] @@ -401,10 +449,36 @@ def test_parse_entity(self, poll): allows_multiple_answers=False, ).parse_explanation_entity(entity) - def test_parse_entities(self, poll): + def test_parse_explanation_entities(self, poll): entity = MessageEntity(type=MessageEntity.URL, offset=13, length=17) entity_2 = MessageEntity(type=MessageEntity.BOLD, offset=13, length=1) poll.explanation_entities = [entity_2, entity] assert poll.parse_explanation_entities(MessageEntity.URL) == {entity: "http://google.com"} assert poll.parse_explanation_entities() == {entity: "http://google.com", entity_2: "h"} + + with pytest.raises(RuntimeError, match="Poll has no"): + Poll( + "id", + "question", + [PollOption("text", voter_count=0)], + total_voter_count=0, + is_closed=False, + is_anonymous=False, + type=Poll.QUIZ, + allows_multiple_answers=False, + ).parse_explanation_entities() + + def test_parse_question_entity(self, poll): + entity = MessageEntity(MessageEntity.ITALIC, 5, 8) + poll.question_entities = [entity] + + assert poll.parse_question_entity(entity) == "Question" + + def test_parse_question_entities(self, poll): + entity = MessageEntity(MessageEntity.ITALIC, 5, 8) + entity_2 = MessageEntity(MessageEntity.BOLD, 0, 4) + poll.question_entities = [entity_2, entity] + + assert poll.parse_question_entities(MessageEntity.ITALIC) == {entity: "Question"} + assert poll.parse_question_entities() == {entity: "Question", entity_2: "Test"}