diff --git a/telegram/ext/conversationhandler.py b/telegram/ext/conversationhandler.py index ae615935cab..b2f1c91bf97 100644 --- a/telegram/ext/conversationhandler.py +++ b/telegram/ext/conversationhandler.py @@ -21,8 +21,10 @@ import logging import warnings +import functools +import datetime from threading import Lock -from typing import TYPE_CHECKING, Dict, List, NoReturn, Optional, Tuple, cast, ClassVar +from typing import TYPE_CHECKING, Dict, List, NoReturn, Optional, Union, Tuple, cast, ClassVar from telegram import Update from telegram.ext import ( @@ -143,6 +145,13 @@ class ConversationHandler(Handler[Update]): received update and the corresponding ``context`` will be handled by ALL the handler's who's :attr:`check_update` method returns :obj:`True` that are in the state :attr:`ConversationHandler.TIMEOUT`. + + Note: + Using `conversation_timeout` with nested conversations is currently not + supported. You can still try to use it, but it will likely behave differently + from what you expect. + + name (:obj:`str`, optional): The name for this conversationhandler. Required for persistence. persistent (:obj:`bool`, optional): If the conversations dict for this handler should be @@ -215,7 +224,7 @@ def __init__( per_chat: bool = True, per_user: bool = True, per_message: bool = False, - conversation_timeout: int = None, + conversation_timeout: Union[float, datetime.timedelta] = None, name: str = None, persistent: bool = False, map_to_parent: Dict[object, object] = None, @@ -291,6 +300,16 @@ def __init__( ) break + if self.conversation_timeout: + for handler in all_handlers: + if isinstance(handler, self.__class__): + warnings.warn( + "Using `conversation_timeout` with nested conversations is currently not " + "supported. You can still try to use it, but it will likely behave " + "differently from what you expect." + ) + break + if self.run_async: for handler in all_handlers: handler.run_async = True @@ -352,7 +371,9 @@ def per_message(self, value: object) -> NoReturn: raise ValueError('You can not assign a new value to per_message after initialization.') @property - def conversation_timeout(self) -> Optional[int]: + def conversation_timeout( + self, + ) -> Optional[Union[float, datetime.timedelta]]: return self._conversation_timeout @conversation_timeout.setter @@ -423,6 +444,45 @@ def _get_key(self, update: Update) -> Tuple[int, ...]: return tuple(key) + def _resolve_promise(self, state: Tuple) -> object: + old_state, new_state = state + try: + res = new_state.result(0) + res = res if res is not None else old_state + except Exception as exc: + self.logger.exception("Promise function raised exception") + self.logger.exception("%s", exc) + res = old_state + finally: + if res is None and old_state is None: + res = self.END + return res + + def _schedule_job( + self, + new_state: object, + dispatcher: 'Dispatcher', + update: Update, + context: Optional[CallbackContext], + conversation_key: Tuple[int, ...], + ) -> None: + if new_state != self.END: + try: + # both job_queue & conversation_timeout are checked before calling _schedule_job + j_queue = dispatcher.job_queue + self.timeout_jobs[conversation_key] = j_queue.run_once( # type: ignore[union-attr] + self._trigger_timeout, + self.conversation_timeout, # type: ignore[arg-type] + context=_ConversationTimeoutContext( + conversation_key, update, dispatcher, context + ), + ) + except Exception as exc: + self.logger.exception( + "Failed to schedule timeout job due to the following exception:" + ) + self.logger.exception("%s", exc) + def check_update(self, update: object) -> CheckUpdateType: # pylint: disable=R0911 """ Determines whether an update should be handled by this conversationhandler, and if so in @@ -455,21 +515,14 @@ def check_update(self, update: object) -> CheckUpdateType: # pylint: disable=R0 if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], Promise): self.logger.debug('waiting for promise...') - old_state, new_state = state - if new_state.done.wait(0): - try: - res = new_state.result(0) - res = res if res is not None else old_state - except Exception as exc: - self.logger.exception("Promise function raised exception") - self.logger.exception("%s", exc) - res = old_state - finally: - if res is None and old_state is None: - res = self.END - self.update_state(res, key) - with self._conversations_lock: - state = self.conversations.get(key) + # check if promise is finished or not + if state[1].done.wait(0): + res = self._resolve_promise(state) + self.update_state(res, key) + with self._conversations_lock: + state = self.conversations.get(key) + + # if not then handle WAITING state instead else: hdlrs = self.states.get(self.WAITING, []) for hdlr in hdlrs: @@ -551,15 +604,27 @@ def handle_update( # type: ignore[override] new_state = exception.state raise_dp_handler_stop = True with self._timeout_jobs_lock: - if self.conversation_timeout and new_state != self.END and dispatcher.job_queue: - # Add the new timeout job - self.timeout_jobs[conversation_key] = dispatcher.job_queue.run_once( - self._trigger_timeout, # type: ignore[arg-type] - self.conversation_timeout, - context=_ConversationTimeoutContext( - conversation_key, update, dispatcher, context - ), - ) + if self.conversation_timeout: + if dispatcher.job_queue is not None: + # Add the new timeout job + if isinstance(new_state, Promise): + new_state.add_done_callback( + functools.partial( + self._schedule_job, + dispatcher=dispatcher, + update=update, + context=context, + conversation_key=conversation_key, + ) + ) + elif new_state != self.END: + self._schedule_job( + new_state, dispatcher, update, context, conversation_key + ) + else: + self.logger.warning( + "Ignoring `conversation_timeout` because the Dispatcher has no JobQueue." + ) if isinstance(self.map_to_parent, dict) and new_state in self.map_to_parent: self.update_state(self.END, conversation_key) @@ -597,35 +662,35 @@ def update_state(self, new_state: object, key: Tuple[int, ...]) -> None: if self.persistent and self.persistence and self.name: self.persistence.update_conversation(self.name, key, new_state) - def _trigger_timeout(self, context: _ConversationTimeoutContext, job: 'Job' = None) -> None: + def _trigger_timeout(self, context: CallbackContext, job: 'Job' = None) -> None: self.logger.debug('conversation timeout was triggered!') # Backward compatibility with bots that do not use CallbackContext - callback_context = None if isinstance(context, CallbackContext): job = context.job + ctxt = cast(_ConversationTimeoutContext, job.context) # type: ignore[union-attr] + else: + ctxt = cast(_ConversationTimeoutContext, job.context) - context = job.context # type:ignore[union-attr,assignment] - callback_context = context.callback_context + callback_context = ctxt.callback_context with self._timeout_jobs_lock: - found_job = self.timeout_jobs[context.conversation_key] + found_job = self.timeout_jobs[ctxt.conversation_key] if found_job is not job: - # The timeout has been canceled in handle_update + # The timeout has been cancelled in handle_update return - del self.timeout_jobs[context.conversation_key] + del self.timeout_jobs[ctxt.conversation_key] handlers = self.states.get(self.TIMEOUT, []) for handler in handlers: - check = handler.check_update(context.update) + check = handler.check_update(ctxt.update) if check is not None and check is not False: try: - handler.handle_update( - context.update, context.dispatcher, check, callback_context - ) + handler.handle_update(ctxt.update, ctxt.dispatcher, check, callback_context) except DispatcherHandlerStop: self.logger.warning( 'DispatcherHandlerStop in TIMEOUT state of ' 'ConversationHandler has no effect. Ignoring.' ) - self.update_state(self.END, context.conversation_key) + + self.update_state(self.END, ctxt.conversation_key) diff --git a/telegram/ext/utils/promise.py b/telegram/ext/utils/promise.py index 60442686af5..48508e0747d 100644 --- a/telegram/ext/utils/promise.py +++ b/telegram/ext/utils/promise.py @@ -69,6 +69,7 @@ def __init__( self.update = update self.error_handling = error_handling self.done = Event() + self._done_callback: Optional[Callable] = None self._result: Optional[RT] = None self._exception: Optional[Exception] = None @@ -83,6 +84,15 @@ def run(self) -> None: finally: self.done.set() + if self._done_callback: + try: + self._done_callback(self.result()) + except Exception as exc: + logger.warning( + "`done_callback` of a Promise raised the following exception." + " The exception won't be handled by error handlers." + ) + logger.warning("Full traceback:", exc_info=exc) def __call__(self) -> None: self.run() @@ -106,6 +116,20 @@ def result(self, timeout: float = None) -> Optional[RT]: raise self._exception # pylint: disable=raising-bad-type return self._result + def add_done_callback(self, callback: Callable) -> None: + """ + Callback to be run when :class:`telegram.ext.utils.promise.Promise` becomes done. + + Args: + callback (:obj:`callable`): The callable that will be called when promise is done. + callback will be called by passing ``Promise.result()`` as only positional argument. + + """ + if self.done.wait(0): + callback(self.result()) + else: + self._done_callback = callback + @property def exception(self) -> Optional[Exception]: """The exception raised by :attr:`pooled_function` or ``None`` if no exception has been diff --git a/tests/test_conversationhandler.py b/tests/test_conversationhandler.py index f8db5dafa4e..f8e73dc4346 100644 --- a/tests/test_conversationhandler.py +++ b/tests/test_conversationhandler.py @@ -753,6 +753,125 @@ def test_all_update_types(self, dp, bot, user1): assert not handler.check_update(Update(0, pre_checkout_query=pre_checkout_query)) assert not handler.check_update(Update(0, shipping_query=shipping_query)) + def test_no_jobqueue_warning(self, dp, bot, user1, caplog): + handler = ConversationHandler( + entry_points=self.entry_points, + states=self.states, + fallbacks=self.fallbacks, + conversation_timeout=0.5, + ) + # save dp.job_queue in temp variable jqueue + # and then set dp.job_queue to None. + jqueue = dp.job_queue + dp.job_queue = None + dp.add_handler(handler) + + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + + with caplog.at_level(logging.WARNING): + dp.process_update(Update(update_id=0, message=message)) + sleep(0.5) + assert len(caplog.records) == 1 + assert ( + caplog.records[0].message + == "Ignoring `conversation_timeout` because the Dispatcher has no JobQueue." + ) + # now set dp.job_queue back to it's original value + dp.job_queue = jqueue + + def test_schedule_job_exception(self, dp, bot, user1, monkeypatch, caplog): + def mocked_run_once(*a, **kw): + raise Exception("job error") + + monkeypatch.setattr(dp.job_queue, "run_once", mocked_run_once) + handler = ConversationHandler( + entry_points=self.entry_points, + states=self.states, + fallbacks=self.fallbacks, + conversation_timeout=100, + ) + dp.add_handler(handler) + + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + + with caplog.at_level(logging.ERROR): + dp.process_update(Update(update_id=0, message=message)) + sleep(0.5) + assert len(caplog.records) == 2 + assert ( + caplog.records[0].message + == "Failed to schedule timeout job due to the following exception:" + ) + assert caplog.records[1].message == "job error" + + def test_promise_exception(self, dp, bot, user1, caplog): + """ + Here we make sure that when a run_async handle raises an + exception, the state isn't changed. + """ + + def conv_entry(*a, **kw): + return 1 + + def raise_error(*a, **kw): + raise Exception("promise exception") + + handler = ConversationHandler( + entry_points=[CommandHandler("start", conv_entry)], + states={1: [MessageHandler(Filters.all, raise_error)]}, + fallbacks=self.fallbacks, + run_async=True, + ) + dp.add_handler(handler) + + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + # start the conversation + dp.process_update(Update(update_id=0, message=message)) + sleep(0.1) + message.text = "error" + dp.process_update(Update(update_id=0, message=message)) + sleep(0.1) + message.text = "resolve promise pls" + caplog.clear() + with caplog.at_level(logging.ERROR): + dp.process_update(Update(update_id=0, message=message)) + sleep(0.5) + assert len(caplog.records) == 3 + assert caplog.records[0].message == "Promise function raised exception" + assert caplog.records[1].message == "promise exception" + # assert res is old state + assert handler.conversations.get((self.group.id, user1.id))[0] == 1 + def test_conversation_timeout(self, dp, bot, user1): handler = ConversationHandler( entry_points=self.entry_points, @@ -789,6 +908,49 @@ def test_conversation_timeout(self, dp, bot, user1): sleep(0.7) assert handler.conversations.get((self.group.id, user1.id)) is None + def test_timeout_not_triggered_on_conv_end_async(self, bot, dp, user1): + def timeout(*a, **kw): + self.test_flag = True + + self.states.update({ConversationHandler.TIMEOUT: [TypeHandler(Update, timeout)]}) + handler = ConversationHandler( + entry_points=self.entry_points, + states=self.states, + fallbacks=self.fallbacks, + conversation_timeout=0.5, + run_async=True, + ) + dp.add_handler(handler) + + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + # start the conversation + dp.process_update(Update(update_id=0, message=message)) + sleep(0.1) + message.text = '/brew' + message.entities[0].length = len('/brew') + dp.process_update(Update(update_id=1, message=message)) + sleep(0.1) + message.text = '/pourCoffee' + message.entities[0].length = len('/pourCoffee') + dp.process_update(Update(update_id=2, message=message)) + sleep(0.1) + message.text = '/end' + message.entities[0].length = len('/end') + dp.process_update(Update(update_id=3, message=message)) + sleep(1) + # assert timeout handler didn't got called + assert self.test_flag is False + def test_conversation_timeout_dispatcher_handler_stop(self, dp, bot, user1, caplog): handler = ConversationHandler( entry_points=self.entry_points, @@ -1126,6 +1288,39 @@ def slowbrew(_bot, update): assert handler.conversations.get((self.group.id, user1.id)) is None assert self.is_timeout + def test_conversation_timeout_warning_only_shown_once(self, recwarn): + ConversationHandler( + entry_points=self.entry_points, + states={ + self.THIRSTY: [ + ConversationHandler( + entry_points=self.entry_points, + states={ + self.BREWING: [CommandHandler('pourCoffee', self.drink)], + }, + fallbacks=self.fallbacks, + ) + ], + self.DRINKING: [ + ConversationHandler( + entry_points=self.entry_points, + states={ + self.CODING: [CommandHandler('startCoding', self.code)], + }, + fallbacks=self.fallbacks, + ) + ], + }, + fallbacks=self.fallbacks, + conversation_timeout=100, + ) + assert len(recwarn) == 1 + assert str(recwarn[0].message) == ( + "Using `conversation_timeout` with nested conversations is currently not " + "supported. You can still try to use it, but it will likely behave " + "differently from what you expect." + ) + def test_per_message_warning_is_only_shown_once(self, recwarn): ConversationHandler( entry_points=self.entry_points, diff --git a/tests/test_promise.py b/tests/test_promise.py index 46e3d29b65b..a0768b5c63e 100644 --- a/tests/test_promise.py +++ b/tests/test_promise.py @@ -16,6 +16,7 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. +import logging import pytest from telegram import TelegramError @@ -63,3 +64,66 @@ def callback(): with pytest.raises(TelegramError, match='Error'): promise.result() + + def test_done_cb_after_run(self): + def callback(): + return "done!" + + def done_callback(_): + self.test_flag = True + + promise = Promise(callback, [], {}) + promise.run() + promise.add_done_callback(done_callback) + assert promise.result() == "done!" + assert self.test_flag is True + + def test_done_cb_after_run_excp(self): + def callback(): + return "done!" + + def done_callback(_): + raise Exception("Error!") + + promise = Promise(callback, [], {}) + promise.run() + assert promise.result() == "done!" + with pytest.raises(Exception) as err: + promise.add_done_callback(done_callback) + assert str(err) == "Error!" + + def test_done_cb_before_run(self): + def callback(): + return "done!" + + def done_callback(_): + self.test_flag = True + + promise = Promise(callback, [], {}) + promise.add_done_callback(done_callback) + assert promise.result(0) != "done!" + assert self.test_flag is False + promise.run() + assert promise.result() == "done!" + assert self.test_flag is True + + def test_done_cb_before_run_excp(self, caplog): + def callback(): + return "done!" + + def done_callback(_): + raise Exception("Error!") + + promise = Promise(callback, [], {}) + promise.add_done_callback(done_callback) + assert promise.result(0) != "done!" + caplog.clear() + with caplog.at_level(logging.WARNING): + promise.run() + assert len(caplog.records) == 2 + assert caplog.records[0].message == ( + "`done_callback` of a Promise raised the following exception." + " The exception won't be handled by error handlers." + ) + assert caplog.records[1].message.startswith("Full traceback:") + assert promise.result() == "done!"