diff --git a/examples/autowiring.py b/examples/autowiring.py new file mode 100644 index 00000000000..382e8ba1a41 --- /dev/null +++ b/examples/autowiring.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Bot that displays the autowiring functionality +# This program is dedicated to the public domain under the CC0 license. +""" +This bot shows how to use `autowire=True` in Handler definitions to save a lot of effort typing +the explicit pass_* flags. + +Usage: +Autowiring example: Try sending /start, /data, "My name is Leandro", or some random text. +Press Ctrl-C on the command line or send a signal to the process to stop the +bot. +""" + +import logging + +from telegram.ext import Updater, CommandHandler, RegexHandler + +# Enable logging +logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + level=logging.INFO) + +logger = logging.getLogger(__name__) + + +def error(bot, update, error): + logger.warning('Update "%s" caused error "%s"' % (update, error)) + + +def start(bot, update, args): + query = ' '.join(args) # `args` is magically defined + if query: + update.message.reply_text(query) + else: + update.message.reply_text("Example: /start here I am") + + +def simple_update_only(update): + """ + A simple handler that only needs an `update` object. + Useful e.g. for basic commands like /help that need to do nothing but reply with some text. + """ + update.message.reply_text("This should have produced an error " + "for the MessageHandler in group=1.") + + +def callback_with_data(bot, update, chat_data, user_data): + msg = 'Adding something to chat_data...\n' + chat_data['value'] = "I'm a chat_data value" + msg += chat_data['value'] + + msg += '\n\n' + + msg += 'Adding something to user_data...\n' + user_data['value'] = "I'm a user_data value" + msg += user_data['value'] + + update.message.reply_text(msg, quote=True) + + +def regex_with_groups(bot, update, groups, groupdict): + update.message.reply_text("Nice, your {} is {}.".format(groups[0], groups[1])) + update.message.reply_text('Groupdict: {}'.format(groupdict)) + + +def callback_undefined_arguments(bot, update, chat_data, groups): + pass + + +def main(): + # Create the Updater and pass it your bot's token. + updater = Updater("TOKEN") + + # Get the dispatcher to register handlers + dp = updater.dispatcher + + # Inject the `args` parameter automagically + dp.add_handler(CommandHandler("start", start, autowire=True)) + + # A RegexHandler example where `groups` and `groupdict` are passed automagically + # Examples: Send "My name is Leandro" or "My cat is blue". + dp.add_handler(RegexHandler(r'[Mm]y (?P.*) is (?P.*)', + regex_with_groups, + autowire=True)) + + # This will raise an error because the bot argument is missing... + dp.add_handler(CommandHandler('help', simple_update_only), group=1) + # ... but with the autowiring capability, you can have callbacks with only an `update` argument. + dp.add_handler(CommandHandler('help', simple_update_only, autowire=True), group=2) + + # Passing `chat_data` and `user_data` explicitly... + dp.add_handler(CommandHandler("data", callback_with_data, + pass_chat_data=True, + pass_user_data=True)) + # ... is equivalent to passing them automagically. + dp.add_handler(CommandHandler("data", callback_with_data, autowire=True)) + + # An example of using the `groups` parameter which is not defined for a CommandHandler. + # Uncomment the line below and you will see a warning. + # dp.add_handler(CommandHandler("erroneous", callback_undefined_arguments, autowire=True)) + + dp.add_error_handler(error) + updater.start_polling() + + # Run the bot until you press Ctrl-C or the process receives SIGINT, + # SIGTERM or SIGABRT. This should be used most of the time, since + # start_polling() is non-blocking and will stop the bot gracefully. + updater.idle() + + +if __name__ == '__main__': + main() diff --git a/telegram/ext/callbackqueryhandler.py b/telegram/ext/callbackqueryhandler.py index 626b3875dcb..081611d4078 100644 --- a/telegram/ext/callbackqueryhandler.py +++ b/telegram/ext/callbackqueryhandler.py @@ -33,6 +33,8 @@ class CallbackQueryHandler(Handler): Attributes: callback (:obj:`callable`): The callback function for this handler. + autowire (:obj:`bool`): Optional. Determines whether objects will be passed to the + callback function automatically. pass_update_queue (:obj:`bool`): Optional. Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Optional. Determines whether ``job_queue`` will be passed to @@ -58,6 +60,10 @@ class CallbackQueryHandler(Handler): callback (:obj:`callable`): A function that takes ``bot, update`` as positional arguments. It will be called when the :attr:`check_update` has determined that an update should be processed by this handler. + autowire (:obj:`bool`, optional): If set to ``True``, your callback handler will be + inspected for positional arguments and be passed objects whose names match any of the + ``pass_*`` flags of this Handler. Using any ``pass_*`` argument in conjunction with + ``autowire`` will yield a warning. pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher` @@ -79,11 +85,11 @@ class CallbackQueryHandler(Handler): ``user_data`` will be passed to the callback function. Default is ``False``. pass_chat_data (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``chat_data`` will be passed to the callback function. Default is ``False``. - """ def __init__(self, callback, + autowire=False, pass_update_queue=False, pass_job_queue=False, pattern=None, @@ -93,6 +99,7 @@ def __init__(self, pass_chat_data=False): super(CallbackQueryHandler, self).__init__( callback, + autowire=autowire, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue, pass_user_data=pass_user_data, @@ -105,6 +112,10 @@ def __init__(self, self.pass_groups = pass_groups self.pass_groupdict = pass_groupdict + if self.autowire: + self.set_autowired_flags(passable={'groups', 'groupdict', 'user_data', + 'chat_data', 'update_queue', 'job_queue'}) + def check_update(self, update): """Determines whether an update should be passed to this handlers :attr:`callback`. @@ -131,7 +142,9 @@ def handle_update(self, update, dispatcher): dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that originated the Update. """ + positional_args = self.collect_bot_update_args(dispatcher, update) optional_args = self.collect_optional_args(dispatcher, update) + if self.pattern: match = re.match(self.pattern, update.callback_query.data) @@ -140,4 +153,4 @@ def handle_update(self, update, dispatcher): if self.pass_groupdict: optional_args['groupdict'] = match.groupdict() - return self.callback(dispatcher.bot, update, **optional_args) + return self.callback(*positional_args, **optional_args) diff --git a/telegram/ext/choseninlineresulthandler.py b/telegram/ext/choseninlineresulthandler.py index 59d0ec0e640..1180ba78ed2 100644 --- a/telegram/ext/choseninlineresulthandler.py +++ b/telegram/ext/choseninlineresulthandler.py @@ -28,6 +28,8 @@ class ChosenInlineResultHandler(Handler): Attributes: callback (:obj:`callable`): The callback function for this handler. + autowire (:obj:`bool`): Optional. Determines whether objects will be passed to the + callback function automatically. pass_update_queue (:obj:`bool`): Optional. Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Optional. Determines whether ``job_queue`` will be passed to @@ -47,6 +49,10 @@ class ChosenInlineResultHandler(Handler): callback (:obj:`callable`): A function that takes ``bot, update`` as positional arguments. It will be called when the :attr:`check_update` has determined that an update should be processed by this handler. + autowire (:obj:`bool`, optional): If set to ``True``, your callback handler will be + inspected for positional arguments and be passed objects whose names match any of the + ``pass_*`` flags of this Handler. Using any ``pass_*`` argument in conjunction with + ``autowire`` will yield a warning. pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher` @@ -64,16 +70,20 @@ class ChosenInlineResultHandler(Handler): def __init__(self, callback, + autowire=False, pass_update_queue=False, pass_job_queue=False, pass_user_data=False, pass_chat_data=False): super(ChosenInlineResultHandler, self).__init__( callback, + autowire=autowire, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue, pass_user_data=pass_user_data, pass_chat_data=pass_chat_data) + if self.autowire: + self.set_autowired_flags() def check_update(self, update): """Determines whether an update should be passed to this handlers :attr:`callback`. @@ -95,9 +105,10 @@ def handle_update(self, update, dispatcher): dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that originated the Update. """ + positional_args = self.collect_bot_update_args(dispatcher, update) optional_args = self.collect_optional_args(dispatcher, update) - return self.callback(dispatcher.bot, update, **optional_args) + return self.callback(*positional_args, **optional_args) # old non-PEP8 Handler methods m = "telegram.ChosenInlineResultHandler." diff --git a/telegram/ext/commandhandler.py b/telegram/ext/commandhandler.py index b33dc7959f4..914a6a60f2b 100644 --- a/telegram/ext/commandhandler.py +++ b/telegram/ext/commandhandler.py @@ -21,8 +21,8 @@ from future.utils import string_types -from .handler import Handler from telegram import Update +from .handler import Handler class CommandHandler(Handler): @@ -39,6 +39,8 @@ class CommandHandler(Handler): Filters. allow_edited (:obj:`bool`): Optional. Determines Whether the handler should also accept edited messages. + autowire (:obj:`bool`): Optional. Determines whether objects will be passed to the + callback function automatically. pass_args (:obj:`bool`): Optional. Determines whether the handler should be passed ``args``. pass_update_queue (:obj:`bool`): Optional. Determines whether ``update_queue`` will be @@ -68,6 +70,10 @@ class CommandHandler(Handler): operators (& for and, | for or, ~ for not). allow_edited (:obj:`bool`, optional): Determines whether the handler should also accept edited messages. Default is ``False``. + autowire (:obj:`bool`, optional): If set to ``True``, your callback handler will be + inspected for positional arguments and be passed objects whose names match any of the + ``pass_*`` flags of this Handler. Using any ``pass_*`` argument in conjunction with + ``autowire`` will yield a warning. pass_args (:obj:`bool`, optional): Determines whether the handler should be passed the arguments passed to the command as a keyword argument called ``args``. It will contain a list of strings, which is the text following the command split on single or @@ -92,6 +98,7 @@ def __init__(self, callback, filters=None, allow_edited=False, + autowire=False, pass_args=False, pass_update_queue=False, pass_job_queue=False, @@ -99,18 +106,23 @@ def __init__(self, pass_chat_data=False): super(CommandHandler, self).__init__( callback, + autowire=autowire, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue, pass_user_data=pass_user_data, pass_chat_data=pass_chat_data) + self.pass_args = pass_args + if self.autowire: + self.set_autowired_flags( + {'update_queue', 'job_queue', 'user_data', 'chat_data', 'args'}) + if isinstance(command, string_types): self.command = [command.lower()] else: self.command = [x.lower() for x in command] self.filters = filters self.allow_edited = allow_edited - self.pass_args = pass_args # We put this up here instead of with the rest of checking code # in check_update since we don't wanna spam a ton @@ -129,8 +141,8 @@ def check_update(self, update): :obj:`bool` """ - if (isinstance(update, Update) - and (update.message or update.edited_message and self.allow_edited)): + if (isinstance(update, Update) and + (update.message or update.edited_message and self.allow_edited)): message = update.message or update.edited_message if message.text: @@ -161,6 +173,7 @@ def handle_update(self, update, dispatcher): dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that originated the Update. """ + positional_args = self.collect_bot_update_args(dispatcher, update) optional_args = self.collect_optional_args(dispatcher, update) message = update.message or update.edited_message @@ -168,4 +181,4 @@ def handle_update(self, update, dispatcher): if self.pass_args: optional_args['args'] = message.text.split()[1:] - return self.callback(dispatcher.bot, update, **optional_args) + return self.callback(*positional_args, **optional_args) diff --git a/telegram/ext/handler.py b/telegram/ext/handler.py index 99b89109def..7b00958da7e 100644 --- a/telegram/ext/handler.py +++ b/telegram/ext/handler.py @@ -17,13 +17,24 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the base class for handlers as used by the Dispatcher.""" +import warnings + +from telegram.utils.inspection import inspect_arguments class Handler(object): - """The base class for all update handlers. Create custom handlers by inheriting from it. + """ + The base class for all update handlers. Create custom handlers by inheriting from it. + + If your subclass needs the *autowiring* functionality, make sure to call + ``set_autowired_flags`` **after** initializing the ``pass_*`` members. The ``passable`` + argument to this method denotes all the flags your Handler supports, e.g. + ``{'update_queue', 'job_queue', 'args'}``. Attributes: callback (:obj:`callable`): The callback function for this handler. + autowire (:obj:`bool`): Optional. Determines whether objects will be passed to the + callback function automatically. pass_update_queue (:obj:`bool`): Optional. Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Optional. Determines whether ``job_queue`` will be passed to @@ -43,6 +54,10 @@ class Handler(object): callback (:obj:`callable`): A function that takes ``bot, update`` as positional arguments. It will be called when the :attr:`check_update` has determined that an update should be processed by this handler. + autowire (:obj:`bool`, optional): If set to ``True``, your callback handler will be + inspected for positional arguments and be passed objects whose names match any of the + ``pass_*`` flags of this Handler. Using any ``pass_*`` argument in conjunction with + ``autowire`` will yield a warning. pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher` @@ -58,17 +73,25 @@ class Handler(object): """ + PASSABLE_OBJECTS = {'update_queue', 'job_queue', 'user_data', 'chat_data', + 'args', 'groups', 'groupdict'} + def __init__(self, callback, + autowire=False, pass_update_queue=False, pass_job_queue=False, pass_user_data=False, pass_chat_data=False): self.callback = callback + self.autowire = autowire self.pass_update_queue = pass_update_queue self.pass_job_queue = pass_job_queue self.pass_user_data = pass_user_data self.pass_chat_data = pass_chat_data + self._autowire_initialized = False + self._callback_args = None + self._passable = None def check_update(self, update): """ @@ -99,8 +122,79 @@ def handle_update(self, update, dispatcher): """ raise NotImplementedError + def __warn_autowire(self): + """ Warn if the user has set any `pass_*` flags to True in addition to `autowire` """ + for flag in self._get_available_pass_flags(): + to_pass = bool(getattr(self, flag)) + if to_pass is True: + warnings.warn('If `autowire` is set to `True`, it is unnecessary ' + 'to provide the `{}` flag.'.format(flag)) + + def _get_available_pass_flags(self): + """ + Used to provide warnings if the user decides to use `autowire` in conjunction with + ``pass_*`` flags, and to recalculate all flags. + + Getting objects dynamically is better than hard-coding all passable objects and setting + them to False in here, because the base class should not know about the existence of + passable objects that are only relevant to subclasses (e.g. args, groups, groupdict). + """ + return [f for f in dir(self) if f.startswith('pass_')] + + def __should_pass_obj(self, name): + """ + Utility to determine whether a passable object is part of + the user handler's signature, makes sense in this context, + and is not explicitly set to `False`. + """ + is_requested = name in self.PASSABLE_OBJECTS and name in self._callback_args + if is_requested and name not in self._passable: + warnings.warn("The argument `{}` cannot be autowired since it is not available " + "on `{}s`.".format(name, type(self).__name__)) + return False + return is_requested + + def set_autowired_flags(self, + passable={'update_queue', 'job_queue', 'user_data', 'chat_data'}): + """ + This method inspects the callback handler for used arguments. If it finds arguments that + are ``passable``, i.e. types that can also be passed by the various ``pass_*`` flags, + it sets the according flags to true. + + If the handler signature is prone to change at runtime for whatever reason, you can call + this method again to recalculate the flags to use. + + The ``passable`` arguments are required to be explicit as opposed to dynamically generated + to be absolutely safe that no arguments will be passed that are not allowed. + + Args: + passable: An iterable that contains the allowed flags for this handler + """ + self._passable = passable + + if not self.autowire: + raise ValueError("This handler is not autowired.") + + if self._autowire_initialized: + # In case that users decide to change their callback signatures at runtime, give the + # possibility to recalculate all flags. + for flag in self._get_available_pass_flags(): + setattr(self, flag, False) + + self.__warn_autowire() + + self._callback_args = inspect_arguments(self.callback) + + # Actually set `pass_*` flags to True + for to_pass in self.PASSABLE_OBJECTS: + if self.__should_pass_obj(to_pass): + setattr(self, 'pass_' + to_pass, True) + + self._autowire_initialized = True + def collect_optional_args(self, dispatcher, update=None): - """Prepares the optional arguments that are the same for all types of handlers. + """ + Prepares the optional arguments that are the same for all types of handlers. Args: dispatcher (:class:`telegram.ext.Dispatcher`): The dispatcher. @@ -108,18 +202,48 @@ def collect_optional_args(self, dispatcher, update=None): """ optional_args = dict() + if self.autowire: + # Subclasses are responsible for calling `set_autowired_flags` + # at the end of their __init__ + assert self._autowire_initialized + if self.pass_update_queue: optional_args['update_queue'] = dispatcher.update_queue if self.pass_job_queue: optional_args['job_queue'] = dispatcher.job_queue - if self.pass_user_data or self.pass_chat_data: - chat = update.effective_chat + if self.pass_user_data: user = update.effective_user + optional_args['user_data'] = dispatcher.user_data[user.id if user else None] + if self.pass_chat_data: + chat = update.effective_chat + optional_args['chat_data'] = dispatcher.chat_data[chat.id if chat else None] - if self.pass_user_data: - optional_args['user_data'] = dispatcher.user_data[user.id if user else None] + return optional_args - if self.pass_chat_data: - optional_args['chat_data'] = dispatcher.chat_data[chat.id if chat else None] + def collect_bot_update_args(self, dispatcher, update): + """ + Prepares the positional arguments ``bot`` and/or ``update`` that are required for every + python-telegram-bot handler that is not **autowired**. If ``autowire`` is set to ``True``, + this method uses the inspected callback arguments to decide whether bot or update, + respectively, need to be passed. The order is always (bot, update). - return optional_args + + Args: + dispatcher (:class:`telegram.ext.Dispatcher`): The dispatcher. + update (:class:`telegram.Update`): The update. + + Returns: + A tuple of bot, update, or both + """ + if self.autowire: + # Subclasses are responsible for calling `set_autowired_flags` in their __init__ + assert self._autowire_initialized + + positional_args = [] + if 'bot' in self._callback_args: + positional_args.append(dispatcher.bot) + if 'update' in self._callback_args: + positional_args.append(update) + return tuple(positional_args) + else: + return (dispatcher.bot, update) diff --git a/telegram/ext/inlinequeryhandler.py b/telegram/ext/inlinequeryhandler.py index 3e5ec7a0566..4aeb6bbd091 100644 --- a/telegram/ext/inlinequeryhandler.py +++ b/telegram/ext/inlinequeryhandler.py @@ -33,12 +33,14 @@ class InlineQueryHandler(Handler): Attributes: callback (:obj:`callable`): The callback function for this handler. + pattern (:obj:`str` | :obj:`Pattern`): Optional. Regex pattern to test + :attr:`telegram.InlineQuery.query` against. + autowire (:obj:`bool`): Optional. Determines whether objects will be passed to the + callback function automatically. pass_update_queue (:obj:`bool`): Optional. Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Optional. Determines whether ``job_queue`` will be passed to the callback function. - pattern (:obj:`str` | :obj:`Pattern`): Optional. Regex pattern to test - :attr:`telegram.InlineQuery.query` against. pass_groups (:obj:`bool`): Optional. Determines whether ``groups`` will be passed to the callback function. pass_groupdict (:obj:`bool`): Optional. Determines whether ``groupdict``. will be passed to @@ -58,6 +60,14 @@ class InlineQueryHandler(Handler): callback (:obj:`callable`): A function that takes ``bot, update`` as positional arguments. It will be called when the :attr:`check_update` has determined that an update should be processed by this handler. + pattern (:obj:`str` | :obj:`Pattern`, optional): Regex pattern. If not ``None``, + ``re.match`` is used on :attr:`telegram.InlineQuery.query` to determine if an update + should be handled by this handler. + autowire (:obj:`bool`, optional): If set to ``True``, your callback handler will be + inspected for positional arguments and be passed objects whose names match any of the + ``pass_*`` flags of this Handler. Using any ``pass_*`` argument in conjunction with + ``autowire`` will yield + a warning. pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher` @@ -66,9 +76,6 @@ class InlineQueryHandler(Handler): ``job_queue`` will be passed to the callback function. It will be a :class:`telegram.ext.JobQueue` instance created by the :class:`telegram.ext.Updater` which can be used to schedule new jobs. Default is ``False``. - pattern (:obj:`str` | :obj:`Pattern`, optional): Regex pattern. If not ``None``, - ``re.match`` is used on :attr:`telegram.InlineQuery.query` to determine if an update - should be handled by this handler. pass_groups (:obj:`bool`, optional): If the callback should be passed the result of ``re.match(pattern, data).groups()`` as a keyword argument called ``groups``. Default is ``False`` @@ -83,15 +90,17 @@ class InlineQueryHandler(Handler): def __init__(self, callback, + pattern=None, + autowire=False, pass_update_queue=False, pass_job_queue=False, - pattern=None, pass_groups=False, pass_groupdict=False, pass_user_data=False, pass_chat_data=False): super(InlineQueryHandler, self).__init__( callback, + autowire=autowire, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue, pass_user_data=pass_user_data, @@ -103,6 +112,9 @@ def __init__(self, self.pattern = pattern self.pass_groups = pass_groups self.pass_groupdict = pass_groupdict + if self.autowire: + self.set_autowired_flags(passable={'update_queue', 'job_queue', 'user_data', + 'chat_data', 'groups', 'groupdict'}) def check_update(self, update): """ @@ -131,8 +143,9 @@ def handle_update(self, update, dispatcher): update (:class:`telegram.Update`): Incoming telegram update. dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that originated the Update. """ - + positional_args = self.collect_bot_update_args(dispatcher, update) optional_args = self.collect_optional_args(dispatcher, update) + if self.pattern: match = re.match(self.pattern, update.inline_query.query) @@ -141,7 +154,7 @@ def handle_update(self, update, dispatcher): if self.pass_groupdict: optional_args['groupdict'] = match.groupdict() - return self.callback(dispatcher.bot, update, **optional_args) + return self.callback(*positional_args, **optional_args) # old non-PEP8 Handler methods m = "telegram.InlineQueryHandler." diff --git a/telegram/ext/messagehandler.py b/telegram/ext/messagehandler.py index 11c10803ceb..9fd4da2699d 100644 --- a/telegram/ext/messagehandler.py +++ b/telegram/ext/messagehandler.py @@ -31,6 +31,8 @@ class MessageHandler(Handler): filters (:obj:`Filter`): Only allow updates with these Filters. See :mod:`telegram.ext.filters` for a full list of all available filters. callback (:obj:`callable`): The callback function for this handler. + autowire (:obj:`bool`): Optional. Determines whether objects will be passed to the + callback function automatically. pass_update_queue (:obj:`bool`): Optional. Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Optional. Determines whether ``job_queue`` will be passed to @@ -62,6 +64,11 @@ class MessageHandler(Handler): callback (:obj:`callable`): A function that takes ``bot, update`` as positional arguments. It will be called when the :attr:`check_update` has determined that an update should be processed by this handler. + autowire (:obj:`bool`, optional): If set to ``True``, your callback handler will be + inspected for positional arguments and be passed objects whose names match any of the + ``pass_*`` flags of this Handler. Using any ``pass_*`` argument in conjunction with + ``autowire`` will yield + a warning. pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher` @@ -92,6 +99,7 @@ def __init__(self, filters, callback, allow_edited=False, + autowire=False, pass_update_queue=False, pass_job_queue=False, pass_user_data=False, @@ -108,6 +116,7 @@ def __init__(self, super(MessageHandler, self).__init__( callback, + autowire=autowire, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue, pass_user_data=pass_user_data, @@ -117,6 +126,9 @@ def __init__(self, self.channel_post_updates = channel_post_updates self.edited_updates = edited_updates + if self.autowire: + self.set_autowired_flags() + # We put this up here instead of with the rest of checking code # in check_update since we don't wanna spam a ton if isinstance(self.filters, list): @@ -164,6 +176,7 @@ def handle_update(self, update, dispatcher): dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that originated the Update. """ + positional_args = self.collect_bot_update_args(dispatcher, update) optional_args = self.collect_optional_args(dispatcher, update) - return self.callback(dispatcher.bot, update, **optional_args) + return self.callback(*positional_args, **optional_args) diff --git a/telegram/ext/precheckoutqueryhandler.py b/telegram/ext/precheckoutqueryhandler.py index 5cccfdd54eb..20cc6f349e0 100644 --- a/telegram/ext/precheckoutqueryhandler.py +++ b/telegram/ext/precheckoutqueryhandler.py @@ -27,6 +27,8 @@ class PreCheckoutQueryHandler(Handler): Attributes: callback (:obj:`callable`): The callback function for this handler. + autowire (:obj:`bool`): Optional. Determines whether objects will be passed to the + callback function automatically. pass_update_queue (:obj:`bool`): Optional. Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Optional. Determines whether ``job_queue`` will be passed to @@ -46,6 +48,11 @@ class PreCheckoutQueryHandler(Handler): callback (:obj:`callable`): A function that takes ``bot, update`` as positional arguments. It will be called when the :attr:`check_update` has determined that an update should be processed by this handler. + autowire (:obj:`bool`, optional): If set to ``True``, your callback handler will be + inspected for positional arguments and be passed objects whose names match any of the + ``pass_*`` flags of this Handler. Using any ``pass_*`` argument in conjunction with + ``autowire`` will yield + a warning. pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher` @@ -63,16 +70,20 @@ class PreCheckoutQueryHandler(Handler): def __init__(self, callback, + autowire=False, pass_update_queue=False, pass_job_queue=False, pass_user_data=False, pass_chat_data=False): super(PreCheckoutQueryHandler, self).__init__( callback, + autowire=autowire, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue, pass_user_data=pass_user_data, pass_chat_data=pass_chat_data) + if self.autowire: + self.set_autowired_flags() def check_update(self, update): """Determines whether an update should be passed to this handlers :attr:`callback`. @@ -94,5 +105,7 @@ def handle_update(self, update, dispatcher): dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that originated the Update. """ + positional_args = self.collect_bot_update_args(dispatcher, update) optional_args = self.collect_optional_args(dispatcher, update) - return self.callback(dispatcher.bot, update, **optional_args) + + return self.callback(*positional_args, **optional_args) diff --git a/telegram/ext/regexhandler.py b/telegram/ext/regexhandler.py index 9397ff3720e..c7dfd2a713d 100644 --- a/telegram/ext/regexhandler.py +++ b/telegram/ext/regexhandler.py @@ -38,6 +38,8 @@ class RegexHandler(Handler): Attributes: pattern (:obj:`str` | :obj:`Pattern`): The regex pattern. callback (:obj:`callable`): The callback function for this handler. + autowire (:obj:`bool`): Optional. Determines whether objects will be passed to the + callback function automatically. pass_groups (:obj:`bool`): Optional. Determines whether ``groups`` will be passed to the callback function. pass_groupdict (:obj:`bool`): Optional. Determines whether ``groupdict``. will be passed to @@ -62,6 +64,11 @@ class RegexHandler(Handler): callback (:obj:`callable`): A function that takes ``bot, update`` as positional arguments. It will be called when the :attr:`check_update` has determined that an update should be processed by this handler. + autowire (:obj:`bool`, optional): If set to ``True``, your callback handler will be + inspected for positional arguments and be passed objects whose names match any of the + ``pass_*`` flags of this Handler. Using any ``pass_*`` argument in conjunction with + ``autowire`` will yield + a warning. pass_groups (:obj:`bool`, optional): If the callback should be passed the result of ``re.match(pattern, data).groups()`` as a keyword argument called ``groups``. Default is ``False`` @@ -97,6 +104,7 @@ class RegexHandler(Handler): def __init__(self, pattern, callback, + autowire=False, pass_groups=False, pass_groupdict=False, pass_update_queue=False, @@ -117,6 +125,7 @@ def __init__(self, super(RegexHandler, self).__init__( callback, + autowire=autowire, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue, pass_user_data=pass_user_data, @@ -128,6 +137,9 @@ def __init__(self, self.pattern = pattern self.pass_groups = pass_groups self.pass_groupdict = pass_groupdict + if self.autowire: + self.set_autowired_flags( + {'groups', 'groupdict', 'update_queue', 'job_queue', 'user_data', 'chat_data'}) self.allow_edited = allow_edited self.message_updates = message_updates self.channel_post_updates = channel_post_updates @@ -162,6 +174,7 @@ def handle_update(self, update, dispatcher): """ + positional_args = self.collect_bot_update_args(dispatcher, update) optional_args = self.collect_optional_args(dispatcher, update) match = re.match(self.pattern, update.effective_message.text) @@ -170,4 +183,4 @@ def handle_update(self, update, dispatcher): if self.pass_groupdict: optional_args['groupdict'] = match.groupdict() - return self.callback(dispatcher.bot, update, **optional_args) + return self.callback(*positional_args, **optional_args) diff --git a/telegram/ext/shippingqueryhandler.py b/telegram/ext/shippingqueryhandler.py index ab21197d306..3b173dce295 100644 --- a/telegram/ext/shippingqueryhandler.py +++ b/telegram/ext/shippingqueryhandler.py @@ -27,6 +27,8 @@ class ShippingQueryHandler(Handler): Attributes: callback (:obj:`callable`): The callback function for this handler. + autowire (:obj:`bool`): Optional. Determines whether objects will be passed to the + callback function automatically. pass_update_queue (:obj:`bool`): Optional. Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Optional. Determines whether ``job_queue`` will be passed to @@ -46,6 +48,11 @@ class ShippingQueryHandler(Handler): callback (:obj:`callable`): A function that takes ``bot, update`` as positional arguments. It will be called when the :attr:`check_update` has determined that an update should be processed by this handler. + autowire (:obj:`bool`, optional): If set to ``True``, your callback handler will be + inspected for positional arguments and be passed objects whose names match any of the + ``pass_*`` flags of this Handler. Using any ``pass_*`` argument in conjunction with + ``autowire`` will yield + a warning. pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher` @@ -63,16 +70,20 @@ class ShippingQueryHandler(Handler): def __init__(self, callback, + autowire=False, pass_update_queue=False, pass_job_queue=False, pass_user_data=False, pass_chat_data=False): super(ShippingQueryHandler, self).__init__( callback, + autowire=autowire, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue, pass_user_data=pass_user_data, pass_chat_data=pass_chat_data) + if self.autowire: + self.set_autowired_flags() def check_update(self, update): """Determines whether an update should be passed to this handlers :attr:`callback`. @@ -94,5 +105,6 @@ def handle_update(self, update, dispatcher): dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that originated the Update. """ + positional_args = self.collect_bot_update_args(dispatcher, update) optional_args = self.collect_optional_args(dispatcher, update) - return self.callback(dispatcher.bot, update, **optional_args) + return self.callback(*positional_args, **optional_args) diff --git a/telegram/ext/stringcommandhandler.py b/telegram/ext/stringcommandhandler.py index ca77d1b4fd7..22633b9f701 100644 --- a/telegram/ext/stringcommandhandler.py +++ b/telegram/ext/stringcommandhandler.py @@ -33,6 +33,8 @@ class StringCommandHandler(Handler): Attributes: command (:obj:`str`): The command this handler should listen for. callback (:obj:`callable`): The callback function for this handler. + autowire (:obj:`bool`): Optional. Determines whether objects will be passed to the + callback function automatically. pass_args (:obj:`bool`): Optional. Determines whether the handler should be passed ``args``. pass_update_queue (:obj:`bool`): Optional. Determines whether ``update_queue`` will be @@ -46,6 +48,10 @@ class StringCommandHandler(Handler): callback (:obj:`callable`): A function that takes ``bot, update`` as positional arguments. It will be called when the :attr:`check_update` has determined that a command should be processed by this handler. + autowire (:obj:`bool`, optional): If set to ``True``, your callback handler will be + inspected for positional arguments and be passed objects whose names match any of the + ``pass_*`` flags of this Handler. Using any ``pass_*`` argument in conjunction with + ``autowire`` will yield a warning. pass_args (:obj:`bool`, optional): Determines whether the handler should be passed the arguments passed to the command as a keyword argument called ``args``. It will contain a list of strings, which is the text following the command split on single or @@ -64,13 +70,19 @@ class StringCommandHandler(Handler): def __init__(self, command, callback, + autowire=False, pass_args=False, pass_update_queue=False, pass_job_queue=False): super(StringCommandHandler, self).__init__( - callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue) + callback, + autowire=autowire, + pass_update_queue=pass_update_queue, + pass_job_queue=pass_job_queue) self.command = command self.pass_args = pass_args + if self.autowire: + self.set_autowired_flags(passable={'update_queue', 'job_queue', 'args'}) def check_update(self, update): """Determines whether an update should be passed to this handlers :attr:`callback`. @@ -94,10 +106,10 @@ def handle_update(self, update, dispatcher): dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that originated the command. """ - + positional_args = self.collect_bot_update_args(dispatcher, update) optional_args = self.collect_optional_args(dispatcher) if self.pass_args: optional_args['args'] = update.split()[1:] - return self.callback(dispatcher.bot, update, **optional_args) + return self.callback(*positional_args, **optional_args) diff --git a/telegram/ext/stringregexhandler.py b/telegram/ext/stringregexhandler.py index 18e17bc1715..d666d755f7d 100644 --- a/telegram/ext/stringregexhandler.py +++ b/telegram/ext/stringregexhandler.py @@ -72,12 +72,16 @@ class StringRegexHandler(Handler): def __init__(self, pattern, callback, + autowire=False, pass_groups=False, pass_groupdict=False, pass_update_queue=False, pass_job_queue=False): super(StringRegexHandler, self).__init__( - callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue) + callback, + autowire=autowire, + pass_update_queue=pass_update_queue, + pass_job_queue=pass_job_queue) if isinstance(pattern, string_types): pattern = re.compile(pattern) @@ -85,6 +89,8 @@ def __init__(self, self.pattern = pattern self.pass_groups = pass_groups self.pass_groupdict = pass_groupdict + if self.autowire: + self.set_autowired_flags({'groups', 'groupdict', 'update_queue', 'job_queue'}) def check_update(self, update): """Determines whether an update should be passed to this handlers :attr:`callback`. @@ -106,6 +112,7 @@ def handle_update(self, update, dispatcher): dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that originated the command. """ + positional_args = self.collect_bot_update_args(dispatcher, update) optional_args = self.collect_optional_args(dispatcher) match = re.match(self.pattern, update) @@ -114,4 +121,4 @@ def handle_update(self, update, dispatcher): if self.pass_groupdict: optional_args['groupdict'] = match.groupdict() - return self.callback(dispatcher.bot, update, **optional_args) + return self.callback(*positional_args, **optional_args) diff --git a/telegram/ext/typehandler.py b/telegram/ext/typehandler.py index 5b4d1d33acb..051fd7c632f 100644 --- a/telegram/ext/typehandler.py +++ b/telegram/ext/typehandler.py @@ -29,6 +29,8 @@ class TypeHandler(Handler): callback (:obj:`callable`): The callback function for this handler. strict (:obj:`bool`): Optional. Use ``type`` instead of ``isinstance``. Default is ``False`` + autowire (:obj:`bool`): Optional. Determines whether objects will be passed to the + callback function automatically. pass_update_queue (:obj:`bool`): Optional. Determines whether ``update_queue`` will be passed to the callback function. pass_job_queue (:obj:`bool`): Optional. Determines whether ``job_queue`` will be passed to @@ -42,6 +44,10 @@ class TypeHandler(Handler): processed by this handler. strict (:obj:`bool`, optional): Use ``type`` instead of ``isinstance``. Default is ``False`` + autowire (:obj:`bool`, optional): If set to ``True``, your callback handler will be + inspected for positional arguments and be passed objects whose names match any of the + ``pass_*`` flags of this Handler. Using any ``pass_*`` argument in conjunction with + ``autowire`` will yield a warning. pass_update_queue (:obj:`bool`, optional): If set to ``True``, a keyword argument called ``update_queue`` will be passed to the callback function. It will be the ``Queue`` instance used by the :class:`telegram.ext.Updater` and :class:`telegram.ext.Dispatcher` @@ -53,12 +59,22 @@ class TypeHandler(Handler): """ - def __init__(self, type, callback, strict=False, pass_update_queue=False, + def __init__(self, + type, + callback, + strict=False, + autowire=False, + pass_update_queue=False, pass_job_queue=False): super(TypeHandler, self).__init__( - callback, pass_update_queue=pass_update_queue, pass_job_queue=pass_job_queue) + callback, + autowire=autowire, + pass_update_queue=pass_update_queue, + pass_job_queue=pass_job_queue) self.type = type self.strict = strict + if self.autowire: + self.set_autowired_flags({'update_queue', 'job_queue'}) def check_update(self, update): """Determines whether an update should be passed to this handlers :attr:`callback`. @@ -84,6 +100,7 @@ def handle_update(self, update, dispatcher): dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that originated the Update. """ + positional_args = self.collect_bot_update_args(dispatcher, update) optional_args = self.collect_optional_args(dispatcher) - return self.callback(dispatcher.bot, update, **optional_args) + return self.callback(*positional_args, **optional_args) diff --git a/telegram/utils/inspection.py b/telegram/utils/inspection.py new file mode 100644 index 00000000000..090ab7e1467 --- /dev/null +++ b/telegram/utils/inspection.py @@ -0,0 +1,13 @@ +import inspect + +""" +Reflects on a function or method to retrieve all positional and keyword arguments available. +""" +try: + def inspect_arguments(func): + args, _, _, _ = inspect.getargspec(func) + return args +except Warning: # `getargspec()` is deprecated in Python3 + def inspect_arguments(func): + args, _, _, _, _, _, _ = inspect.getfullargspec(func) + return args diff --git a/tests/test_callbackqueryhandler.py b/tests/test_callbackqueryhandler.py index 8160bcf9345..202127afb7e 100644 --- a/tests/test_callbackqueryhandler.py +++ b/tests/test_callbackqueryhandler.py @@ -74,6 +74,13 @@ def callback_queue_1(self, bot, update, job_queue=None, update_queue=None): def callback_queue_2(self, bot, update, job_queue=None, update_queue=None): self.test_flag = (job_queue is not None) and (update_queue is not None) + def autowire_callback_1(self, update, job_queue, update_queue, chat_data, user_data): + self.test_flag = all(x is not None for x in (update, job_queue, + update_queue, chat_data, user_data)) + + def autowire_callback_2(self, bot, update, job_queue): + self.test_flag = all(x is not None for x in (bot, update, job_queue)) + def callback_group(self, bot, update, groups=None, groupdict=None): if groups is not None: self.test_flag = groups == ('t', ' data') @@ -164,6 +171,30 @@ def test_pass_job_or_update_queue(self, dp, callback_query): dp.process_update(callback_query) assert self.test_flag + def test_autowire(self, dp, callback_query): + handler = CallbackQueryHandler(self.autowire_callback_1, autowire=True) + dp.add_handler(handler) + + dp.process_update(callback_query) + assert self.test_flag + + dp.remove_handler(handler) + handler = CallbackQueryHandler(self.autowire_callback_2, autowire=True) + dp.add_handler(handler) + + self.test_flag = False + dp.process_update(callback_query) + assert self.test_flag + + dp.remove_handler(handler) + handler = CallbackQueryHandler(self.callback_group, + pattern='(?P.*)est(?P.*)', + pass_groups=True) + dp.add_handler(handler) + + dp.process_update(callback_query) + assert self.test_flag + def test_other_update_types(self, false_update): handler = CallbackQueryHandler(self.callback_basic) assert not handler.check_update(false_update) diff --git a/tests/test_choseninlineresulthandler.py b/tests/test_choseninlineresulthandler.py index 2606c536e7e..54f1d73e43c 100644 --- a/tests/test_choseninlineresulthandler.py +++ b/tests/test_choseninlineresulthandler.py @@ -78,6 +78,10 @@ def callback_queue_1(self, bot, update, job_queue=None, update_queue=None): def callback_queue_2(self, bot, update, job_queue=None, update_queue=None): self.test_flag = (job_queue is not None) and (update_queue is not None) + def autowire_callback(self, update, job_queue, update_queue, chat_data, user_data): + self.test_flag = all(x is not None for x in (update, job_queue, + update_queue, chat_data, user_data)) + def test_basic(self, dp, chosen_inline_result): handler = ChosenInlineResultHandler(self.callback_basic) dp.add_handler(handler) @@ -134,6 +138,13 @@ def test_pass_job_or_update_queue(self, dp, chosen_inline_result): dp.process_update(chosen_inline_result) assert self.test_flag + def test_autowire(self, dp, chosen_inline_result): + handler = ChosenInlineResultHandler(self.autowire_callback, autowire=True) + dp.add_handler(handler) + + dp.process_update(chosen_inline_result) + assert self.test_flag + def test_other_update_types(self, false_update): handler = ChosenInlineResultHandler(self.callback_basic) assert not handler.check_update(false_update) diff --git a/tests/test_commandhandler.py b/tests/test_commandhandler.py index fb1aafa1fcf..a08f5e5ce27 100644 --- a/tests/test_commandhandler.py +++ b/tests/test_commandhandler.py @@ -75,6 +75,10 @@ def callback_queue_1(self, bot, update, job_queue=None, update_queue=None): def callback_queue_2(self, bot, update, job_queue=None, update_queue=None): self.test_flag = (job_queue is not None) and (update_queue is not None) + def autowire_callback(self, update, args, job_queue, update_queue, chat_data, user_data): + self.test_flag = all(x is not None for x in (update, args, job_queue, + update_queue, chat_data, user_data)) + def ch_callback_args(self, bot, update, args): if update.message.text == '/test': self.test_flag = len(args) == 0 @@ -215,6 +219,14 @@ def test_pass_job_or_update_queue(self, dp, message): dp.process_update(Update(0, message=message)) assert self.test_flag + def test_autowire(self, dp, message): + handler = CommandHandler('test', self.autowire_callback, autowire=True) + dp.add_handler(handler) + + message.text = '/test abc' + dp.process_update(Update(0, message=message)) + assert self.test_flag + def test_other_update_types(self, false_update): handler = CommandHandler('test', self.callback_basic) assert not handler.check_update(false_update) diff --git a/tests/test_handler.py b/tests/test_handler.py new file mode 100644 index 00000000000..61ab05fa120 --- /dev/null +++ b/tests/test_handler.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2017 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. + +import pytest + +from telegram.ext import Handler + + +class TestHandler(object): + test_flag = False + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = False + + def callback_basic(self, bot, update): + pass + + def callback_some_passable(self, bot, update, update_queue, chat_data): + pass + + def callback_all_passable(self, bot, update, update_queue, job_queue, chat_data, user_data): + pass + + def test_set_autowired_flags_all(self): + handler = Handler(self.callback_all_passable, autowire=True) + assert handler._autowire_initialized is False + assert handler.pass_update_queue is False + assert handler.pass_job_queue is False + assert handler.pass_chat_data is False + assert handler.pass_user_data is False + + handler.set_autowired_flags() + + assert handler._autowire_initialized is True + assert handler.pass_update_queue is True + assert handler.pass_job_queue is True + assert handler.pass_chat_data is True + assert handler.pass_user_data is True + + def test_set_autowired_flags_some(self): + handler = Handler(self.callback_some_passable, autowire=True) + assert handler.pass_update_queue is False + assert handler.pass_chat_data is False + + handler.set_autowired_flags() + + assert handler._autowire_initialized is True + assert handler.pass_update_queue is True + assert handler.pass_chat_data is True + + def test_set_autowired_flags_wrong(self): + handler = Handler(self.callback_all_passable, autowire=True) + with pytest.raises(UserWarning): + handler.set_autowired_flags({'kektus'}) + with pytest.raises(UserWarning): + handler.set_autowired_flags({'chat_data', 'kektus'}) + with pytest.raises(UserWarning): + handler.set_autowired_flags({'bot', 'update'}) + + def test_autowire_and_pass(self): + handler = Handler(self.callback_all_passable, autowire=True, pass_chat_data=True) + with pytest.raises(UserWarning): + handler.set_autowired_flags() + + def test_not_autowired_set_flags(self): + handler = Handler(self.callback_all_passable, autowire=False) + with pytest.raises(ValueError): + handler.set_autowired_flags() + + def test_autowire_reinitialize(self): + handler = Handler(self.callback_all_passable, autowire=True) + assert handler._autowire_initialized is False + assert handler.pass_update_queue is False + assert handler.pass_job_queue is False + assert handler.pass_chat_data is False + assert handler.pass_user_data is False + + handler.set_autowired_flags() + + assert handler._autowire_initialized is True + assert handler.pass_update_queue is True + assert handler.pass_job_queue is True + assert handler.pass_chat_data is True + assert handler.pass_user_data is True + + handler.callback = self.callback_some_passable + handler.set_autowired_flags() + + assert handler._autowire_initialized is True + assert handler.pass_update_queue is True + assert handler.pass_job_queue is False + assert handler.pass_chat_data is True + assert handler.pass_user_data is False + + def test_get_available_pass_flags(self): + handler = Handler(self.callback_all_passable, autowire=True) + assert handler.pass_update_queue is False + assert handler.pass_job_queue is False + assert handler.pass_chat_data is False + assert handler.pass_user_data is False + + handler.set_autowired_flags() + + assert set(handler._get_available_pass_flags()) == {'pass_update_queue', 'pass_job_queue', + 'pass_chat_data', + 'pass_user_data'} diff --git a/tests/test_inlinequeryhandler.py b/tests/test_inlinequeryhandler.py index 3a370845bcb..ae8d706343a 100644 --- a/tests/test_inlinequeryhandler.py +++ b/tests/test_inlinequeryhandler.py @@ -78,6 +78,10 @@ def callback_queue_1(self, bot, update, job_queue=None, update_queue=None): def callback_queue_2(self, bot, update, job_queue=None, update_queue=None): self.test_flag = (job_queue is not None) and (update_queue is not None) + def callback_autowire(self, update, job_queue, update_queue, chat_data, user_data): + self.test_flag = all(x is not None for x in (update, job_queue, + update_queue, chat_data, user_data)) + def callback_group(self, bot, update, groups=None, groupdict=None): if groups is not None: self.test_flag = groups == ('t', ' query') @@ -168,6 +172,32 @@ def test_pass_job_or_update_queue(self, dp, inline_query): dp.process_update(inline_query) assert self.test_flag + def test_autowire(self, dp, inline_query): + handler = InlineQueryHandler(self.callback_autowire, autowire=True) + dp.add_handler(handler) + + dp.process_update(inline_query) + assert self.test_flag + + def test_autowire_group(self, dp, inline_query): + handler = InlineQueryHandler(self.callback_group, + pattern='(?P.*)est(?P.*)', + autowire=True) + dp.add_handler(handler) + + dp.process_update(inline_query) + assert self.test_flag + + dp.remove_handler(handler) + handler = InlineQueryHandler(self.callback_group, + pattern='(?P.*)est(?P.*)', + autowire=True) + dp.add_handler(handler) + + self.test_flag = False + dp.process_update(inline_query) + assert self.test_flag + def test_other_update_types(self, false_update): handler = InlineQueryHandler(self.callback_basic) assert not handler.check_update(false_update) diff --git a/tests/test_messagehandler.py b/tests/test_messagehandler.py index 114d03ed6ea..193d6ddec13 100644 --- a/tests/test_messagehandler.py +++ b/tests/test_messagehandler.py @@ -72,6 +72,10 @@ def callback_queue_1(self, bot, update, job_queue=None, update_queue=None): def callback_queue_2(self, bot, update, job_queue=None, update_queue=None): self.test_flag = (job_queue is not None) and (update_queue is not None) + def callback_autowire(self, update, job_queue, update_queue, chat_data, user_data): + self.test_flag = all(x is not None for x in (update, job_queue, + update_queue, chat_data, user_data)) + def test_basic(self, dp, message): handler = MessageHandler(None, self.callback_basic) dp.add_handler(handler) @@ -179,6 +183,13 @@ def test_pass_job_or_update_queue(self, dp, message): dp.process_update(Update(0, message=message)) assert self.test_flag + def test_autowire(self, dp, message): + handler = MessageHandler(None, self.callback_autowire, autowire=True) + dp.add_handler(handler) + + dp.process_update(Update(0, message=message)) + assert self.test_flag + def test_other_update_types(self, false_update): handler = MessageHandler(None, self.callback_basic, edited_updates=True) assert not handler.check_update(false_update) diff --git a/tests/test_precheckoutqueryhandler.py b/tests/test_precheckoutqueryhandler.py index b6a9c3a60b8..500572ea683 100644 --- a/tests/test_precheckoutqueryhandler.py +++ b/tests/test_precheckoutqueryhandler.py @@ -49,7 +49,8 @@ def false_update(request): @pytest.fixture(scope='class') def pre_checkout_query(): - return Update(1, pre_checkout_query=PreCheckoutQuery('id', User(1, 'test user', False), 'EUR', 223, + return Update(1, pre_checkout_query=PreCheckoutQuery('id', User(1, 'test user', False), 'EUR', + 223, 'invoice_payload')) @@ -77,6 +78,10 @@ def callback_queue_1(self, bot, update, job_queue=None, update_queue=None): def callback_queue_2(self, bot, update, job_queue=None, update_queue=None): self.test_flag = (job_queue is not None) and (update_queue is not None) + def callback_autowire(self, update, job_queue, update_queue, chat_data, user_data): + self.test_flag = all(x is not None for x in (update, job_queue, + update_queue, chat_data, user_data)) + def test_basic(self, dp, pre_checkout_query): handler = PreCheckoutQueryHandler(self.callback_basic) dp.add_handler(handler) @@ -133,6 +138,13 @@ def test_pass_job_or_update_queue(self, dp, pre_checkout_query): dp.process_update(pre_checkout_query) assert self.test_flag + def test_autowire(self, dp, pre_checkout_query): + handler = PreCheckoutQueryHandler(self.callback_autowire, autowire=True) + dp.add_handler(handler) + + dp.process_update(pre_checkout_query) + assert self.test_flag + def test_other_update_types(self, false_update): handler = PreCheckoutQueryHandler(self.callback_basic) assert not handler.check_update(false_update) diff --git a/tests/test_regexhandler.py b/tests/test_regexhandler.py index bd87bc705e4..9e018fbb1ad 100644 --- a/tests/test_regexhandler.py +++ b/tests/test_regexhandler.py @@ -72,6 +72,10 @@ def callback_queue_1(self, bot, update, job_queue=None, update_queue=None): def callback_queue_2(self, bot, update, job_queue=None, update_queue=None): self.test_flag = (job_queue is not None) and (update_queue is not None) + def callback_autowire(self, update, job_queue, update_queue, chat_data, user_data): + self.test_flag = all(x is not None for x in (update, job_queue, + update_queue, chat_data, user_data)) + def callback_group(self, bot, update, groups=None, groupdict=None): if groups is not None: self.test_flag = groups == ('t', ' message') @@ -201,6 +205,30 @@ def test_pass_job_or_update_queue(self, dp, message): dp.process_update(Update(0, message=message)) assert self.test_flag + def test_autowire(self, dp, message): + handler = RegexHandler('.*', self.callback_autowire, autowire=True) + dp.add_handler(handler) + + dp.process_update(Update(0, message=message)) + assert self.test_flag + + def test_autowire_group_dict(self, dp, message): + handler = RegexHandler('(?P.*)est(?P.*)', self.callback_group, + autowire=True) + dp.add_handler(handler) + + dp.process_update(Update(0, message)) + assert self.test_flag + + dp.remove_handler(handler) + handler = RegexHandler('(?P.*)est(?P.*)', self.callback_group, + autowire=True) + dp.add_handler(handler) + + self.test_flag = False + dp.process_update(Update(0, message)) + assert self.test_flag + def test_other_update_types(self, false_update): handler = RegexHandler('.*', self.callback_basic, edited_updates=True) assert not handler.check_update(false_update) diff --git a/tests/test_shippingqueryhandler.py b/tests/test_shippingqueryhandler.py index 47e6975b1be..edd57762c41 100644 --- a/tests/test_shippingqueryhandler.py +++ b/tests/test_shippingqueryhandler.py @@ -49,9 +49,10 @@ def false_update(request): @pytest.fixture(scope='class') def shiping_query(): - return Update(1, shipping_query=ShippingQuery(42, User(1, 'test user', False), 'invoice_payload', - ShippingAddress('EN', 'my_state', 'my_city', - 'steer_1', '', 'post_code'))) + return Update(1, + shipping_query=ShippingQuery(42, User(1, 'test user', False), 'invoice_payload', + ShippingAddress('EN', 'my_state', 'my_city', + 'steer_1', '', 'post_code'))) class TestShippingQueryHandler(object): @@ -78,6 +79,10 @@ def callback_queue_1(self, bot, update, job_queue=None, update_queue=None): def callback_queue_2(self, bot, update, job_queue=None, update_queue=None): self.test_flag = (job_queue is not None) and (update_queue is not None) + def callback_autowire(self, update, job_queue, update_queue, chat_data, user_data): + self.test_flag = all(x is not None for x in (update, job_queue, + update_queue, chat_data, user_data)) + def test_basic(self, dp, shiping_query): handler = ShippingQueryHandler(self.callback_basic) dp.add_handler(handler) @@ -134,6 +139,13 @@ def test_pass_job_or_update_queue(self, dp, shiping_query): dp.process_update(shiping_query) assert self.test_flag + def test_autowire(self, dp, shiping_query): + handler = ShippingQueryHandler(self.callback_autowire, autowire=True) + dp.add_handler(handler) + + dp.process_update(shiping_query) + assert self.test_flag + def test_other_update_types(self, false_update): handler = ShippingQueryHandler(self.callback_basic) assert not handler.check_update(false_update) diff --git a/tests/test_stringcommandhandler.py b/tests/test_stringcommandhandler.py index 60fa786cb23..d36d065ca34 100644 --- a/tests/test_stringcommandhandler.py +++ b/tests/test_stringcommandhandler.py @@ -118,6 +118,25 @@ def test_pass_job_or_update_queue(self, dp): dp.process_update('/test') assert self.test_flag + def test_autowire(self, dp): + handler = StringCommandHandler('test', self.callback_queue_2, autowire=True) + dp.add_handler(handler) + + self.test_flag = False + dp.process_update('/test') + assert self.test_flag + + dp.remove_handler(handler) + handler = StringCommandHandler('test', self.sch_callback_args, autowire=True) + dp.add_handler(handler) + + dp.process_update('/test') + assert self.test_flag + + self.test_flag = False + dp.process_update('/test one two') + assert self.test_flag + def test_other_update_types(self, false_update): handler = StringCommandHandler('test', self.callback_basic) assert not handler.check_update(false_update) diff --git a/tests/test_stringregexhandler.py b/tests/test_stringregexhandler.py index 9d0ed7829f8..45d7367ab2f 100644 --- a/tests/test_stringregexhandler.py +++ b/tests/test_stringregexhandler.py @@ -65,6 +65,9 @@ def callback_queue_1(self, bot, update, job_queue=None, update_queue=None): def callback_queue_2(self, bot, update, job_queue=None, update_queue=None): self.test_flag = (job_queue is not None) and (update_queue is not None) + def callback_autowire(self, update, job_queue, update_queue): + self.test_flag = all(x is not None for x in (update, job_queue, update_queue)) + def callback_group(self, bot, update, groups=None, groupdict=None): if groups is not None: self.test_flag = groups == ('t', ' message') @@ -122,6 +125,30 @@ def test_pass_job_or_update_queue(self, dp): dp.process_update('test') assert self.test_flag + def test_autowire(self, dp): + handler = StringRegexHandler('test', self.callback_autowire, autowire=True) + dp.add_handler(handler) + + dp.process_update('test') + assert self.test_flag + + def test_autowire_groups_and_groupdict(self, dp): + handler = StringRegexHandler('(?P.*)est(?P.*)', self.callback_group, + autowire=True) + dp.add_handler(handler) + + dp.process_update('test message') + assert self.test_flag + + dp.remove_handler(handler) + handler = StringRegexHandler('(?P.*)est(?P.*)', self.callback_group, + autowire=True) + dp.add_handler(handler) + + self.test_flag = False + dp.process_update('test message') + assert self.test_flag + def test_other_update_types(self, false_update): handler = StringRegexHandler('test', self.callback_basic) assert not handler.check_update(false_update) diff --git a/tests/test_typehandler.py b/tests/test_typehandler.py index 43119123628..4f34ea84dcd 100644 --- a/tests/test_typehandler.py +++ b/tests/test_typehandler.py @@ -80,3 +80,10 @@ def test_pass_job_or_update_queue(self, dp): self.test_flag = False dp.process_update({'a': 1, 'b': 2}) assert self.test_flag + + def autowire_job_update(self, dp): + handler = TypeHandler(dict, self.callback_queue_2, autowire=True) + dp.add_handler(handler) + + dp.process_update({'a': 1, 'b': 2}) + assert self.test_flag