diff --git a/telegram/base.py b/telegram/base.py index 4564b5b12d6..3431396c95c 100644 --- a/telegram/base.py +++ b/telegram/base.py @@ -23,13 +23,10 @@ except ImportError: import json -from abc import ABCMeta - class TelegramObject(object): """Base class for most telegram objects.""" - __metaclass__ = ABCMeta _id_attrs = () def __str__(self): diff --git a/telegram/ext/basepersistence.py b/telegram/ext/basepersistence.py index 23c42453b68..b4004a7c33f 100644 --- a/telegram/ext/basepersistence.py +++ b/telegram/ext/basepersistence.py @@ -18,8 +18,10 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the BasePersistence class.""" +from abc import ABC, abstractmethod -class BasePersistence(object): + +class BasePersistence(ABC): """Interface class for adding persistence to your bot. Subclass this object for different implementations of a persistent bot. @@ -57,6 +59,7 @@ def __init__(self, store_user_data=True, store_chat_data=True, store_bot_data=Tr self.store_chat_data = store_chat_data self.store_bot_data = store_bot_data + @abstractmethod def get_user_data(self): """"Will be called by :class:`telegram.ext.Dispatcher` upon creation with a persistence object. It should return the user_data if stored, or an empty @@ -65,8 +68,8 @@ def get_user_data(self): Returns: :obj:`defaultdict`: The restored user data. """ - raise NotImplementedError + @abstractmethod def get_chat_data(self): """"Will be called by :class:`telegram.ext.Dispatcher` upon creation with a persistence object. It should return the chat_data if stored, or an empty @@ -75,8 +78,8 @@ def get_chat_data(self): Returns: :obj:`defaultdict`: The restored chat data. """ - raise NotImplementedError + @abstractmethod def get_bot_data(self): """"Will be called by :class:`telegram.ext.Dispatcher` upon creation with a persistence object. It should return the bot_data if stored, or an empty @@ -85,8 +88,8 @@ def get_bot_data(self): Returns: :obj:`defaultdict`: The restored bot data. """ - raise NotImplementedError + @abstractmethod def get_conversations(self, name): """"Will be called by :class:`telegram.ext.Dispatcher` when a :class:`telegram.ext.ConversationHandler` is added if @@ -99,8 +102,8 @@ def get_conversations(self, name): Returns: :obj:`dict`: The restored conversations for the handler. """ - raise NotImplementedError + @abstractmethod def update_conversation(self, name, key, new_state): """Will be called when a :attr:`telegram.ext.ConversationHandler.update_state` is called. this allows the storeage of the new state in the persistence. @@ -110,8 +113,8 @@ def update_conversation(self, name, key, new_state): key (:obj:`tuple`): The key the state is changed for. new_state (:obj:`tuple` | :obj:`any`): The new state for the given key. """ - raise NotImplementedError + @abstractmethod def update_user_data(self, user_id, data): """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has handled an update. @@ -120,8 +123,8 @@ def update_user_data(self, user_id, data): user_id (:obj:`int`): The user the data might have been changed for. data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data` [user_id]. """ - raise NotImplementedError + @abstractmethod def update_chat_data(self, chat_id, data): """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has handled an update. @@ -130,8 +133,8 @@ def update_chat_data(self, chat_id, data): chat_id (:obj:`int`): The chat the data might have been changed for. data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data` [chat_id]. """ - raise NotImplementedError + @abstractmethod def update_bot_data(self, data): """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has handled an update. @@ -139,7 +142,6 @@ def update_bot_data(self, data): Args: data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.bot_data` . """ - raise NotImplementedError def flush(self): """Will be called by :class:`telegram.ext.Updater` upon receiving a stop signal. Gives the diff --git a/telegram/ext/filters.py b/telegram/ext/filters.py index ea4e274226c..b91463de4aa 100644 --- a/telegram/ext/filters.py +++ b/telegram/ext/filters.py @@ -21,13 +21,14 @@ import re from future.utils import string_types +from abc import ABC, abstractmethod from telegram import Chat, Update, MessageEntity __all__ = ['Filters', 'BaseFilter', 'InvertedFilter', 'MergedFilter'] -class BaseFilter(object): +class BaseFilter(ABC): """Base class for all Message Filters. Subclassing from this class filters to be combined using bitwise operators: @@ -103,6 +104,7 @@ def __repr__(self): self.name = self.__class__.__name__ return self.name + @abstractmethod def filter(self, update): """This method must be overwritten. @@ -118,8 +120,6 @@ def filter(self, update): """ - raise NotImplementedError - class InvertedFilter(BaseFilter): """Represents a filter that has been inverted. diff --git a/telegram/ext/handler.py b/telegram/ext/handler.py index b01aa58b74e..c7706985fb0 100644 --- a/telegram/ext/handler.py +++ b/telegram/ext/handler.py @@ -18,8 +18,10 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the base class for handlers as used by the Dispatcher.""" +from abc import ABC, abstractmethod -class Handler(object): + +class Handler(ABC): """The base class for all update handlers. Create custom handlers by inheriting from it. Attributes: @@ -82,6 +84,7 @@ def __init__(self, self.pass_user_data = pass_user_data self.pass_chat_data = pass_chat_data + @abstractmethod def check_update(self, update): """ This method is called to determine if an update should be handled by @@ -96,7 +99,6 @@ def check_update(self, update): when the update gets handled. """ - raise NotImplementedError def handle_update(self, update, dispatcher, check_result, context=None): """ diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index b3e1c3eb32b..84ecd09c335 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -373,6 +373,12 @@ def get_user_data(self): def update_user_data(self, user_id, data): raise Exception + def get_conversations(self, name): + pass + + def update_conversation(self, name, key, new_state): + pass + def start1(b, u): pass @@ -470,6 +476,21 @@ def update_chat_data(self, chat_id, data): def update_user_data(self, user_id, data): self.update(data) + def get_chat_data(self): + pass + + def get_bot_data(self): + pass + + def get_user_data(self): + pass + + def get_conversations(self, name): + pass + + def update_conversation(self, name, key, new_state): + pass + def callback(update, context): pass @@ -513,6 +534,21 @@ def update_chat_data(self, chat_id, data): def update_user_data(self, user_id, data): self.test_flag_user_data = True + def update_conversation(self, name, key, new_state): + pass + + def get_conversations(self, name): + pass + + def get_user_data(self): + pass + + def get_bot_data(self): + pass + + def get_chat_data(self): + pass + def callback(update, context): pass diff --git a/tests/test_filters.py b/tests/test_filters.py index c8961589938..8986346d461 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -730,10 +730,8 @@ def test_faulty_custom_filter(self, update): class _CustomFilter(BaseFilter): pass - custom = _CustomFilter() - - with pytest.raises(NotImplementedError): - (custom & Filters.text)(update) + with pytest.raises(TypeError, match='Can\'t instantiate abstract class _CustomFilter'): + _CustomFilter() def test_custom_unnamed_filter(self, update): class Unnamed(BaseFilter): diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 20fe75d5783..8c307e51ecf 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -51,7 +51,33 @@ def change_directory(tmp_path): @pytest.fixture(scope="function") def base_persistence(): - return BasePersistence(store_chat_data=True, store_user_data=True, store_bot_data=True) + class OwnPersistence(BasePersistence): + + def get_bot_data(self): + raise NotImplementedError + + def get_chat_data(self): + raise NotImplementedError + + def get_user_data(self): + raise NotImplementedError + + def get_conversations(self, name): + raise NotImplementedError + + def update_bot_data(self, data): + raise NotImplementedError + + def update_chat_data(self, chat_id, data): + raise NotImplementedError + + def update_conversation(self, name, key, new_state): + raise NotImplementedError + + def update_user_data(self, user_id, data): + raise NotImplementedError + + return OwnPersistence(store_chat_data=True, store_user_data=True, store_bot_data=True) @pytest.fixture(scope="function") @@ -100,22 +126,13 @@ class TestBasePersistence(object): def test_creation(self, base_persistence): assert base_persistence.store_chat_data assert base_persistence.store_user_data - with pytest.raises(NotImplementedError): - base_persistence.get_bot_data() - with pytest.raises(NotImplementedError): - base_persistence.get_chat_data() - with pytest.raises(NotImplementedError): - base_persistence.get_user_data() - with pytest.raises(NotImplementedError): - base_persistence.get_conversations("test") - with pytest.raises(NotImplementedError): - base_persistence.update_bot_data(None) - with pytest.raises(NotImplementedError): - base_persistence.update_chat_data(None, None) - with pytest.raises(NotImplementedError): - base_persistence.update_user_data(None, None) - with pytest.raises(NotImplementedError): - base_persistence.update_conversation(None, None, None) + assert base_persistence.store_bot_data + + def test_abstract_methods(self): + with pytest.raises(TypeError, match=('get_bot_data, get_chat_data, get_conversations, ' + 'get_user_data, update_bot_data, update_chat_data, ' + 'update_conversation, update_user_data')): + BasePersistence() def test_implementation(self, updater, base_persistence): dp = updater.dispatcher @@ -127,8 +144,6 @@ def test_conversationhandler_addition(self, dp, base_persistence): with pytest.raises(ValueError, match="if dispatcher has no persistence"): dp.add_handler(ConversationHandler([], {}, [], persistent=True, name="My Handler")) dp.persistence = base_persistence - with pytest.raises(NotImplementedError): - dp.add_handler(ConversationHandler([], {}, [], persistent=True, name="My Handler")) def test_dispatcher_integration_init(self, bot, base_persistence, chat_data, user_data, bot_data): diff --git a/tests/test_updater.py b/tests/test_updater.py index 4d6af80f4db..59009cb5236 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -41,7 +41,7 @@ from telegram import TelegramError, Message, User, Chat, Update, Bot from telegram.error import Unauthorized, InvalidToken, TimedOut, RetryAfter -from telegram.ext import Updater, Dispatcher, BasePersistence +from telegram.ext import Updater, Dispatcher, DictPersistence signalskip = pytest.mark.skipif(sys.platform == 'win32', reason='Can\'t send signals without stopping ' @@ -467,7 +467,7 @@ def test_mutual_exclude_bot_dispatcher(self): def test_mutual_exclude_persistence_dispatcher(self): dispatcher = Dispatcher(None, None) - persistence = BasePersistence() + persistence = DictPersistence() with pytest.raises(ValueError): Updater(dispatcher=dispatcher, persistence=persistence)