diff --git a/telegram/ext/updater.py b/telegram/ext/updater.py index 02f2ed6e35c..897d5fefa97 100644 --- a/telegram/ext/updater.py +++ b/telegram/ext/updater.py @@ -22,8 +22,10 @@ import ssl from threading import Thread, Lock, current_thread, Event from time import sleep +from datetime import datetime, timedelta, timezone from signal import signal, SIGINT, SIGTERM, SIGABRT from queue import Queue +from functools import partial from telegram import Bot, TelegramError from telegram.ext import Dispatcher, JobQueue @@ -233,8 +235,25 @@ def start_polling(self, poll_interval (:obj:`float`, optional): Time to wait between polling updates from Telegram in seconds. Default is 0.0. timeout (:obj:`float`, optional): Passed to :attr:`telegram.Bot.get_updates`. - clean (:obj:`bool`, optional): Whether to clean any pending updates on Telegram servers - before actually starting to poll. Default is False. + clean (:obj:`bool` | :obj:`datetime.timedelta` | :obj:`datetime.datetime`, optional): + Whether to clean any pending updates on Telegram servers before actually starting + to poll. This parameter will be interpreted depending on its type. + + * :obj:`bool` ``True`` cleans all updates. Default is ``False``. + * :obj:`datetime.timedelta` will be interpreted as "time before now" cut off. + Pending updates older than the cut off will be cleaned up. + :obj:`datetime.timedelta` is sign independent, both positive and negative deltas + are interpreted as "in the past". + * :obj:`datetime.datetime` will be interpreted as a specific date and time as + cut off. Pending updates older than the cut off will be cleaned up. + If the timezone (``datetime.tzinfo``) is ``None``, UTC will be assumed. 
+ + Note: + If :attr:`clean` is :obj:`datetime.timedelta` or :obj:`datetime.datetime` and + if a :class:`telegram.Update.effective_message` is found with + :attr:`telegram.Message.date` is ``None``, before the :obj:`datetime.timedelta` + or :obj:`datetime.datetime` condition is met, all updates will pass through. + bootstrap_retries (:obj:`int`, optional): Whether the bootstrapping phase of the `Updater` will retry on failures on the Telegram server. @@ -251,11 +270,38 @@ def start_polling(self, Returns: :obj:`Queue`: The update queue that can be filled from the main thread. + Raises: + ValueError: if :attr:`clean` is :obj:`datetime.timedelta` and is < 1 second. + ValueError: if :attr:`clean` is :obj:`datetime.datetime` is not a least 1 second older + than `now()`. + """ with self.__lock: if not self.running: self.running = True + if isinstance(clean, timedelta): + if clean.total_seconds() < 0: + clean = clean * -1 + + if clean.total_seconds() < 1: + raise ValueError('Clean as timedelta needs to be >= 1 second') + else: + # convert to datetime + clean = datetime.now(tz=timezone.utc) - clean + elif isinstance(clean, datetime): + + if ( + clean.tzinfo is None or + (clean.tzinfo is not None and clean.tzinfo.utcoffset(clean) is None) + ): + clean=clean.replace(tzinfo=timezone.utc) + + if clean > (datetime.now(tz=timezone.utc) - timedelta(seconds=1)): + raise ValueError('Clean as datetime ("%s") needs to be at least 1 second' + ' older than "now"("%s")' % (clean, + datetime.now(tz=timezone.utc))) + # Create & start threads self.job_queue.start() dispatcher_ready = Event() @@ -291,8 +337,25 @@ def start_webhook(self, url_path (:obj:`str`, optional): Path inside url. cert (:obj:`str`, optional): Path to the SSL certificate file. key (:obj:`str`, optional): Path to the SSL key file. - clean (:obj:`bool`, optional): Whether to clean any pending updates on Telegram servers - before actually starting the webhook. Default is ``False``. 
+ clean (:obj:`bool` | :obj:`datetime.timedelta` | :obj:`datetime.datetime`, optional): + Whether to clean any pending updates on Telegram servers before actually starting + the webhook. This parameter will be interpreted depending on its type. + + * :obj:`bool` ``True`` cleans all updates. Default is ``False``. + * :obj:`datetime.timedelta` will be interpreted as "time before now" cut off. + Pending updates older than the cut off will be cleaned up. + :obj:`datetime.timedelta` is sign independent, both positive and negative deltas + are interpreted as "in the past". + * :obj:`datetime.datetime` will be interpreted as a specific date and time as + cut off. Pending updates older than the cut off will be cleaned up. + If the timezone (``datetime.tzinfo``) is ``None``, UTC will be assumed. + + Note: + If :attr:`clean` is :obj:`datetime.timedelta` or :obj:`datetime.datetime` and + if a :attr:`telegram.Update.effective_message` is found whose + :attr:`telegram.Message.date` is ``None`` before the :obj:`datetime.timedelta` + or :obj:`datetime.datetime` condition is met, all updates will pass through. + bootstrap_retries (:obj:`int`, optional): Whether the bootstrapping phase of the `Updater` will retry on failures on the Telegram server. @@ -308,11 +371,37 @@ def start_webhook(self, Returns: :obj:`Queue`: The update queue that can be filled from the main thread. + Raises: + ValueError: if :attr:`clean` is :obj:`datetime.timedelta` and is < 1 second. + ValueError: if :attr:`clean` is :obj:`datetime.datetime` and is not at least 1 second + older than `now()`. 
+ """ with self.__lock: if not self.running: self.running = True + if isinstance(clean, timedelta): + if clean.total_seconds() < 0: + clean = clean * -1 + + if clean.total_seconds() < 1: + raise ValueError('Clean as timedelta needs to be >= 1 second') + else: + # convert to datetime + clean = datetime.now(tz=timezone.utc) - clean + elif isinstance(clean, datetime): + if ( + clean.tzinfo is None or + (clean.tzinfo is not None and clean.tzinfo.utcoffset(clean) is None) + ): + clean=clean.replace(tzinfo=timezone.utc) + + if clean > (datetime.now(tz=timezone.utc) - timedelta(seconds=1)): + raise ValueError('Clean as datetime ("%s") needs to be at least 1 second' + ' older than "now"("%s")' % (clean, + datetime.now(tz=timezone.utc))) + # Create & start threads self.job_queue.start() self._init_thread(self.dispatcher.start, "dispatcher"), @@ -475,6 +564,23 @@ def bootstrap_clean_updates(): updates = self.bot.get_updates(updates[-1].update_id + 1) return False + def bootstrap_clean_updates_datetime(datetime_cutoff): + self.logger.debug('Cleaning updates from Telegram server with datetime "%s"', + datetime_cutoff) + updates = self.bot.get_updates() + + # reversed as we just need to find the first msg that's too old + for up in reversed(updates): + if up.effective_message.date is None: + # break out and leave all updates as is + return False + elif up.effective_message and (up.effective_message.date < datetime_cutoff): + # break out, we want to process the 'next' and all following msg's + updates = self.bot.get_updates(up.update_id + 1) + return False + + return False + def bootstrap_set_webhook(): self.bot.set_webhook(url=webhook_url, certificate=cert, @@ -500,11 +606,16 @@ def bootstrap_onerr_cb(exc): retries[0] = 0 # Clean pending messages, if requested. 
- if clean: + if isinstance(clean, bool) and clean: self._network_loop_retry(bootstrap_clean_updates, bootstrap_onerr_cb, 'bootstrap clean updates', bootstrap_interval) retries[0] = 0 - sleep(1) + elif isinstance(clean, datetime): + bootstrap_clean_updates_datetime_p = partial(bootstrap_clean_updates_datetime, + datetime_cutoff=clean) + self._network_loop_retry(bootstrap_clean_updates_datetime_p, bootstrap_onerr_cb, + 'bootstrap clean updates datetime', bootstrap_interval) + retries[0] = 0 # Restore/set webhook settings, if needed. Again, we don't know ahead if a webhook is set, # so we set it anyhow. diff --git a/tests/test_updater.py b/tests/test_updater.py index 59009cb5236..9c400853f9e 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -81,6 +81,7 @@ class TestUpdater(object): attempts = 0 err_handler_called = Event() cb_handler_called = Event() + update_id = 0 @pytest.fixture(autouse=True) def reset(self): @@ -98,208 +99,17 @@ def callback(self, bot, update): self.received = update.message.text self.cb_handler_called.set() - # TODO: test clean= argument of Updater._bootstrap + # TODO: test clean= argument, both bool and timedelta, of Updater._bootstrap + - @pytest.mark.parametrize(('error',), - argvalues=[(TelegramError('Test Error 2'),), - (Unauthorized('Test Unauthorized'),)], - ids=('TelegramError', 'Unauthorized')) - def test_get_updates_normal_err(self, monkeypatch, updater, error): - def test(*args, **kwargs): - raise error - monkeypatch.setattr(updater.bot, 'get_updates', test) - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - updater.dispatcher.add_error_handler(self.error_handler) - updater.start_polling(0.01) - # Make sure that the error handler was called - self.err_handler_called.wait() - assert self.received == error.message - # Make sure that Updater polling thread keeps running - self.err_handler_called.clear() - self.err_handler_called.wait() - def test_get_updates_bailout_err(self, monkeypatch, 
updater, caplog): - error = InvalidToken() - def test(*args, **kwargs): - raise error - with caplog.at_level(logging.DEBUG): - monkeypatch.setattr(updater.bot, 'get_updates', test) - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - updater.dispatcher.add_error_handler(self.error_handler) - updater.start_polling(0.01) - assert self.err_handler_called.wait(1) is not True - - sleep(1) - # NOTE: This test might hit a race condition and fail (though the 1 seconds delay above - # should work around it). - # NOTE: Checking Updater.running is problematic because it is not set to False when there's - # an unhandled exception. - # TODO: We should have a way to poll Updater status and decide if it's running or not. - import pprint - pprint.pprint([rec.getMessage() for rec in caplog.get_records('call')]) - assert any('unhandled exception in Bot:{}:updater'.format(updater.bot.id) in - rec.getMessage() for rec in caplog.get_records('call')) - @pytest.mark.parametrize(('error',), - argvalues=[(RetryAfter(0.01),), - (TimedOut(),)], - ids=('RetryAfter', 'TimedOut')) - def test_get_updates_retries(self, monkeypatch, updater, error): - event = Event() - - def test(*args, **kwargs): - event.set() - raise error - monkeypatch.setattr(updater.bot, 'get_updates', test) - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - updater.dispatcher.add_error_handler(self.error_handler) - updater.start_polling(0.01) - - # Make sure that get_updates was called, but not the error handler - event.wait() - assert self.err_handler_called.wait(0.5) is not True - assert self.received != error.message - - # Make sure that Updater polling thread keeps running - event.clear() - event.wait() - assert self.err_handler_called.wait(0.5) is not True - - def test_webhook(self, monkeypatch, updater): - q = Queue() - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr(updater.bot, 'delete_webhook', lambda 
*args, **kwargs: True) - monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) - - ip = '127.0.0.1' - port = randrange(1024, 49152) # Select random port - updater.start_webhook( - ip, - port, - url_path='TOKEN') - sleep(.2) - try: - # Now, we send an update to the server via urlopen - update = Update(1, message=Message(1, User(1, '', False), None, Chat(1, ''), - text='Webhook')) - self._send_webhook_msg(ip, port, update.to_json(), 'TOKEN') - sleep(.2) - assert q.get(False) == update - - # Returns 404 if path is incorrect - with pytest.raises(HTTPError) as excinfo: - self._send_webhook_msg(ip, port, None, 'webookhandler.py') - assert excinfo.value.code == 404 - - with pytest.raises(HTTPError) as excinfo: - self._send_webhook_msg(ip, port, None, 'webookhandler.py', - get_method=lambda: 'HEAD') - assert excinfo.value.code == 404 - - # Test multiple shutdown() calls - updater.httpd.shutdown() - finally: - updater.httpd.shutdown() - sleep(.2) - assert not updater.httpd.is_running - updater.stop() - - def test_webhook_ssl(self, monkeypatch, updater): - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - ip = '127.0.0.1' - port = randrange(1024, 49152) # Select random port - tg_err = False - try: - updater._start_webhook( - ip, - port, - url_path='TOKEN', - cert='./tests/test_updater.py', - key='./tests/test_updater.py', - bootstrap_retries=0, - clean=False, - webhook_url=None, - allowed_updates=None) - except TelegramError: - tg_err = True - assert tg_err - - def test_webhook_no_ssl(self, monkeypatch, updater): - q = Queue() - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) - - ip = '127.0.0.1' - port = randrange(1024, 49152) # Select 
random port - updater.start_webhook(ip, port, webhook_url=None) - sleep(.2) - - # Now, we send an update to the server via urlopen - update = Update(1, message=Message(1, User(1, '', False), None, Chat(1, ''), - text='Webhook 2')) - self._send_webhook_msg(ip, port, update.to_json()) - sleep(.2) - assert q.get(False) == update - updater.stop() - - def test_webhook_default_quote(self, monkeypatch, updater): - updater._default_quote = True - q = Queue() - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) - - ip = '127.0.0.1' - port = randrange(1024, 49152) # Select random port - updater.start_webhook( - ip, - port, - url_path='TOKEN') - sleep(.2) - - # Now, we send an update to the server via urlopen - update = Update(1, message=Message(1, User(1, '', False), None, Chat(1, ''), - text='Webhook')) - self._send_webhook_msg(ip, port, update.to_json(), 'TOKEN') - sleep(.2) - # assert q.get(False) == update - assert q.get(False).message.default_quote is True - updater.stop() - - @pytest.mark.skipif(not (sys.platform.startswith("win") and sys.version_info >= (3, 8)), - reason="only relevant on win with py>=3.8") - def test_webhook_tornado_win_py38_workaround(self, updater, monkeypatch): - updater._default_quote = True - q = Queue() - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) - - ip = '127.0.0.1' - port = randrange(1024, 49152) # Select random port - updater.start_webhook( - ip, - port, - url_path='TOKEN') - sleep(.2) - - try: - from asyncio import (WindowsSelectorEventLoopPolicy) - except ImportError: - pass - # not affected - else: - assert 
isinstance(asyncio.get_event_loop_policy(), WindowsSelectorEventLoopPolicy) - - updater.stop() @pytest.mark.parametrize(('error',), argvalues=[(TelegramError(''),)], @@ -337,147 +147,54 @@ def attempt(*args, **kwargs): updater._bootstrap(retries, False, 'path', None, bootstrap_interval=0) assert self.attempts == attempts - @flaky(3, 1) - def test_webhook_invalid_posts(self, updater): - ip = '127.0.0.1' - port = randrange(1024, 49152) # select random port for travis - thr = Thread( - target=updater._start_webhook, - args=(ip, port, '', None, None, 0, False, None, None)) - thr.start() - - sleep(.2) - - try: - with pytest.raises(HTTPError) as excinfo: - self._send_webhook_msg(ip, port, 'data', - content_type='application/xml') - assert excinfo.value.code == 403 - - with pytest.raises(HTTPError) as excinfo: - self._send_webhook_msg(ip, port, 'dummy-payload', content_len=-2) - assert excinfo.value.code == 500 - - # TODO: prevent urllib or the underlying from adding content-length - # with pytest.raises(HTTPError) as excinfo: - # self._send_webhook_msg(ip, port, 'dummy-payload', content_len=None) - # assert excinfo.value.code == 411 - - with pytest.raises(HTTPError): - self._send_webhook_msg(ip, port, 'dummy-payload', content_len='not-a-number') - assert excinfo.value.code == 500 - - finally: - updater.httpd.shutdown() - thr.join() - - def _send_webhook_msg(self, - ip, - port, - payload_str, - url_path='', - content_len=-1, - content_type='application/json', - get_method=None): - headers = {'content-type': content_type, } - - if not payload_str: - content_len = None - payload = None - else: - payload = bytes(payload_str, encoding='utf-8') - - if content_len == -1: - content_len = len(payload) - - if content_len is not None: - headers['content-length'] = str(content_len) - - url = 'http://{ip}:{port}/{path}'.format(ip=ip, port=port, path=url_path) - - req = Request(url, data=payload, headers=headers) - - if get_method is not None: - req.get_method = get_method - - return 
urlopen(req) - - def signal_sender(self, updater): - sleep(0.2) - while not updater.running: - sleep(0.2) - - os.kill(os.getpid(), signal.SIGTERM) - - @signalskip - def test_idle(self, updater, caplog): - updater.start_polling(0.01) - Thread(target=partial(self.signal_sender, updater=updater)).start() - - with caplog.at_level(logging.INFO): - updater.idle() - - rec = caplog.records[-1] - assert rec.msg.startswith('Received signal {}'.format(signal.SIGTERM)) - assert rec.levelname == 'INFO' - - # If we get this far, idle() ran through - sleep(.5) - assert updater.running is False - - @signalskip - def test_user_signal(self, updater): - temp_var = {'a': 0} - - def user_signal_inc(signum, frame): - temp_var['a'] = 1 - - updater.user_sig_handler = user_signal_inc - updater.start_polling(0.01) - Thread(target=partial(self.signal_sender, updater=updater)).start() - updater.idle() - # If we get this far, idle() ran through - sleep(.5) - assert updater.running is False - assert temp_var['a'] != 0 - - def test_create_bot(self): - updater = Updater('123:abcd') - assert updater.bot is not None - - def test_mutual_exclude_token_bot(self): - bot = Bot('123:zyxw') - with pytest.raises(ValueError): - Updater(token='123:abcd', bot=bot) - - def test_no_token_or_bot_or_dispatcher(self): - with pytest.raises(ValueError): - Updater() - - def test_mutual_exclude_bot_private_key(self): - bot = Bot('123:zyxw') - with pytest.raises(ValueError): - Updater(bot=bot, private_key=b'key') - - def test_mutual_exclude_bot_dispatcher(self): - dispatcher = Dispatcher(None, None) - bot = Bot('123:zyxw') - with pytest.raises(ValueError): - Updater(bot=bot, dispatcher=dispatcher) - - def test_mutual_exclude_persistence_dispatcher(self): - dispatcher = Dispatcher(None, None) - persistence = DictPersistence() - with pytest.raises(ValueError): - Updater(dispatcher=dispatcher, persistence=persistence) - - def test_mutual_exclude_workers_dispatcher(self): - dispatcher = Dispatcher(None, None) - with 
pytest.raises(ValueError): - Updater(dispatcher=dispatcher, workers=8) - - def test_mutual_exclude_use_context_dispatcher(self): - dispatcher = Dispatcher(None, None) - use_context = not dispatcher.use_context - with pytest.raises(ValueError): - Updater(dispatcher=dispatcher, use_context=use_context) + @pytest.mark.parametrize(('error', ), + argvalues=[(TelegramError(''),)], + ids=('TelegramError', )) + def test_bootstrap_clean_bool(self, monkeypatch, updater, error): + clean = True + expected_id = 4 # max 9 otherwise we hit our inf loop protection + self.update_id = 0 + + def updates(*args, **kwargs): + # we're hitting this func twice + # 1. no args, return list of updates + # 2. with 1 arg, int => if int == expected_id => test successful + + # case inf loop protection + if self.update_id>10: + raise ValueError + + # case 2 + if len(args) > 0: + # we expect to get int(4) + self.update_id = int(args[0]) + raise error + + class fakeUpdate(object): + pass + + # case 1 + # return list of obj's + + # inf loop protection + self.update_id+=1 + + # build list of fake updates + # returns list of 3 objects with + # update_id's 1, 2 and 3 + i=1 + ls = [] + while i < (expected_id): + o = fakeUpdate() + o.update_id = i + ls.append(o) + i+=1 + return ls + + monkeypatch.setattr(updater.bot, 'get_updates', updates) + + updater.running = True + with pytest.raises(type(error)): + updater._bootstrap(1, clean, None, None, bootstrap_interval=0) + assert self.update_id == expected_id +