From 63a09f3d9b9844b2c38b03909f17a7c939cc8b43 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 31 Dec 2024 14:26:01 +0100 Subject: [PATCH 1/4] Support `allow_paid_broadcast` in `AIORateLimiter` --- telegram/ext/_aioratelimiter.py | 60 +++++++++++++++++--------- tests/ext/test_ratelimiter.py | 75 +++++++++++++++++++++++++++++++-- 2 files changed, 112 insertions(+), 23 deletions(-) diff --git a/telegram/ext/_aioratelimiter.py b/telegram/ext/_aioratelimiter.py index e619819eac8..ac249635015 100644 --- a/telegram/ext/_aioratelimiter.py +++ b/telegram/ext/_aioratelimiter.py @@ -32,6 +32,7 @@ except ImportError: AIO_LIMITER_AVAILABLE = False +from telegram import constants from telegram._utils.logging import get_logger from telegram._utils.types import JSONDict from telegram.error import RetryAfter @@ -86,7 +87,8 @@ class AIORateLimiter(BaseRateLimiter[int]): * A :exc:`~telegram.error.RetryAfter` exception will halt *all* requests for :attr:`~telegram.error.RetryAfter.retry_after` + 0.1 seconds. This may be stricter than necessary in some cases, e.g. the bot may hit a rate limit in one group but might still - be allowed to send messages in another group. + be allowed to send messages in another group or with + :paramref:`~telegram.Bot.send_message.allow_paid_broadcast` set to ``True``. Tip: With `Bot API 7.1 `_ @@ -96,10 +98,10 @@ class AIORateLimiter(BaseRateLimiter[int]): :tg-const:`telegram.constants.FloodLimit.PAID_MESSAGES_PER_SECOND` messages per second by paying a fee in Telegram Stars. - .. caution:: - This class currently doesn't take the - :paramref:`~telegram.Bot.send_message.allow_paid_broadcast` parameter into account. - This means that the rate limiting is applied just like for any other message. + .. versionchanged:: NEXT.VERSION + This class automatically takes the + :paramref:`~telegram.Bot.send_message.allow_paid_broadcast` parameter into account and + throttles the requests accordingly. Note: This class is to be understood as minimal effort reference implementation. @@ -114,16 +116,17 @@ class AIORateLimiter(BaseRateLimiter[int]): Args: overall_max_rate (:obj:`float`): The maximum number of requests allowed for the entire bot per :paramref:`overall_time_period`. When set to 0, no rate limiting will be applied. - Defaults to ``30``. + Defaults to :tg-const:`telegram.constants.FloodLimit.MESSAGES_PER_SECOND`. overall_time_period (:obj:`float`): The time period (in seconds) during which the :paramref:`overall_max_rate` is enforced. When set to 0, no rate limiting will be - applied. Defaults to 1. + applied. Defaults to ``1``. group_max_rate (:obj:`float`): The maximum number of requests allowed for requests related to groups and channels per :paramref:`group_time_period`. When set to 0, no rate - limiting will be applied. Defaults to 20. + limiting will be applied. Defaults to + :tg-const:`telegram.constants.FloodLimit.MESSAGES_PER_MINUTE_PER_GROUP`. group_time_period (:obj:`float`): The time period (in seconds) during which the :paramref:`group_max_rate` is enforced. When set to 0, no rate limiting will be - applied. Defaults to 60. + applied. Defaults to ``60``. max_retries (:obj:`int`): The maximum number of retries to be made in case of a :exc:`~telegram.error.RetryAfter` exception. If set to 0, no retries will be made. Defaults to ``0``. @@ -131,6 +134,7 @@ class AIORateLimiter(BaseRateLimiter[int]): """ __slots__ = ( + "_apb_limiter", "_base_limiter", "_group_limiters", "_group_max_rate", @@ -141,9 +145,9 @@ class AIORateLimiter(BaseRateLimiter[int]): def __init__( self, - overall_max_rate: float = 30, + overall_max_rate: float = constants.FloodLimit.MESSAGES_PER_SECOND, overall_time_period: float = 1, - group_max_rate: float = 20, + group_max_rate: float = constants.FloodLimit.MESSAGES_PER_MINUTE_PER_GROUP, group_time_period: float = 60, max_retries: int = 0, ) -> None: @@ -167,6 +171,9 @@ def __init__( self._group_time_period = 0 self._group_limiters: dict[Union[str, int], AsyncLimiter] = {} + self._apb_limiter: AsyncLimiter = AsyncLimiter( + max_rate=constants.FloodLimit.PAID_MESSAGES_PER_SECOND, time_period=1 + ) self._max_retries: int = max_retries self._retry_after_event = asyncio.Event() self._retry_after_event.set() @@ -201,21 +208,30 @@ async def _run_request( self, chat: bool, group: Union[str, int, bool], + allow_paid_broadcast: bool, callback: Callable[..., Coroutine[Any, Any, Union[bool, JSONDict, list[JSONDict]]]], args: Any, kwargs: dict[str, Any], ) -> Union[bool, JSONDict, list[JSONDict]]: - base_context = self._base_limiter if (chat and self._base_limiter) else null_context() - group_context = ( - self._get_group_limiter(group) if group and self._group_max_rate else null_context() - ) - - async with group_context, base_context: + async def inner() -> Union[bool, JSONDict, list[JSONDict]]: # In case a retry_after was hit, we wait with processing the request await self._retry_after_event.wait() - return await callback(*args, **kwargs) + if allow_paid_broadcast: + async with self._apb_limiter: + return await inner() + else: + base_context = self._base_limiter if (chat and self._base_limiter) else null_context() + group_context = ( + self._get_group_limiter(group) + if group and self._group_max_rate + else null_context() + ) + + async with group_context, base_context: + return await inner() + # mypy doesn't understand that the last run of the for loop raises an exception async def process_request( self, @@ -242,6 +258,7 @@ async def process_request( group: Union[int, str, bool] = False chat: bool = False chat_id = data.get("chat_id") + allow_paid_broadcast = data.get("allow_paid_broadcast", False) if chat_id is not None: chat = True @@ -257,7 +274,12 @@ async def process_request( for i in range(max_retries + 1): try: return await self._run_request( - chat=chat, group=group, callback=callback, args=args, kwargs=kwargs + chat=chat, + group=group, + allow_paid_broadcast=allow_paid_broadcast, + callback=callback, + args=args, + kwargs=kwargs, ) except RetryAfter as exc: if i == max_retries: diff --git a/tests/ext/test_ratelimiter.py b/tests/ext/test_ratelimiter.py index 8af1e541118..9ac9e4925da 100644 --- a/tests/ext/test_ratelimiter.py +++ b/tests/ext/test_ratelimiter.py @@ -148,7 +148,9 @@ async def do_request(self, *args, **kwargs): @pytest.mark.flaky(10, 1) # Timings aren't quite perfect class TestAIORateLimiter: count = 0 + apb_count = 0 call_times = [] + apb_call_times = [] class CountRequest(BaseRequest): def __init__(self, retry_after=None): @@ -161,8 +163,16 @@ async def shutdown(self) -> None: pass async def do_request(self, *args, **kwargs): - TestAIORateLimiter.count += 1 - TestAIORateLimiter.call_times.append(time.time()) + request_data = kwargs.get("request_data") + allow_paid_broadcast = request_data.parameters.get("allow_paid_broadcast", False) + + if allow_paid_broadcast: + TestAIORateLimiter.apb_count += 1 + TestAIORateLimiter.apb_call_times.append(time.time()) + else: + TestAIORateLimiter.count += 1 + TestAIORateLimiter.call_times.append(time.time()) + if self.retry_after: raise RetryAfter(retry_after=1) @@ -190,10 +200,14 @@ async def do_request(self, *args, **kwargs): @pytest.fixture(autouse=True) def _reset(self): - self.count = 0 + # self.count = 0 + # self.apb_count = 0 TestAIORateLimiter.count = 0 - self.call_times = [] TestAIORateLimiter.call_times = [] + # self.call_times = [] + # self.apb_call_times = [] + TestAIORateLimiter.call_times = [] + TestAIORateLimiter.apb_call_times = [] @pytest.mark.parametrize("max_retries", [0, 1, 4]) async def test_max_retries(self, bot, max_retries): @@ -358,3 +372,56 @@ async def test_group_caching(self, bot, intermediate): finally: TestAIORateLimiter.count = 0 TestAIORateLimiter.call_times = [] + + async def test_allow_paid_broadcast(self, bot): + try: + rl_bot = ExtBot( + token=bot.token, + request=self.CountRequest(retry_after=None), + rate_limiter=AIORateLimiter(), + ) + + async with rl_bot: + apb_tasks = {} + non_apb_tasks = {} + for i in range(3000): + apb_tasks[i] = asyncio.create_task( + rl_bot.send_message(chat_id=-1, text="test", allow_paid_broadcast=True) + ) + + number = 2 + for i in range(number): + non_apb_tasks[i] = asyncio.create_task( + rl_bot.send_message(chat_id=-1, text="test") + ) + non_apb_tasks[i + number] = asyncio.create_task( + rl_bot.send_message(chat_id=-1, text="test", allow_paid_broadcast=False) + ) + + await asyncio.sleep(0.85) + # We expect 5 non-apb requests: + # 1: `get_me` from `async with rl_bot` + # 2: `send_message` at time 0.00 + # 3: `send_message` at time 0.25 + # 4: `send_message` at time 0.50 + # 5: `send_message` at time 0.75 + # We expect + assert TestAIORateLimiter.count == 5 + assert sum(1 for task in non_apb_tasks.values() if task.done()) == 4 + + # 1 second after start + await asyncio.sleep(1 - 0.85) + # We expect ~2000 apb requests after the first second + # 2000 (>>1000), since we have a floating window logic such that an initial + # burst is allowed that is hard to measure in the tests + assert TestAIORateLimiter.apb_count < 3000 + assert sum(1 for task in apb_tasks.values() if task.done()) < 3000 + + # 2 seconds after start + await asyncio.sleep(2.1 - 1) + assert TestAIORateLimiter.apb_count == 3000 + assert all(task.done() for task in apb_tasks.values()) + + finally: + # cleanup + await asyncio.gather(*apb_tasks.values(), *non_apb_tasks.values()) From 7d73c6cb04a0fab904b0def547bc7d253a8e56da Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 31 Dec 2024 14:40:20 +0100 Subject: [PATCH 2/4] try fixing tests --- tests/ext/test_ratelimiter.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/ext/test_ratelimiter.py b/tests/ext/test_ratelimiter.py index 9ac9e4925da..e233670ca40 100644 --- a/tests/ext/test_ratelimiter.py +++ b/tests/ext/test_ratelimiter.py @@ -200,13 +200,9 @@ async def do_request(self, *args, **kwargs): @pytest.fixture(autouse=True) def _reset(self): - # self.count = 0 - # self.apb_count = 0 TestAIORateLimiter.count = 0 TestAIORateLimiter.call_times = [] - # self.call_times = [] - # self.apb_call_times = [] - TestAIORateLimiter.call_times = [] + TestAIORateLimiter.apb_count = 0 TestAIORateLimiter.apb_call_times = [] @pytest.mark.parametrize("max_retries", [0, 1, 4]) From 2aad73ea09aeecb5ae12cea7e0e548d7a7589b6b Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 31 Dec 2024 16:27:00 +0100 Subject: [PATCH 3/4] try fixing tests --- tests/ext/test_ratelimiter.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/ext/test_ratelimiter.py b/tests/ext/test_ratelimiter.py index e233670ca40..5e08a0c8283 100644 --- a/tests/ext/test_ratelimiter.py +++ b/tests/ext/test_ratelimiter.py @@ -26,6 +26,7 @@ import json import platform import time +from collections import defaultdict from http import HTTPStatus import pytest @@ -405,18 +406,26 @@ async def test_allow_paid_broadcast(self, bot): assert TestAIORateLimiter.count == 5 assert sum(1 for task in non_apb_tasks.values() if task.done()) == 4 - # 1 second after start - await asyncio.sleep(1 - 0.85) + # ~2 second after start + # We do the checks once all apb_tasks are done as apparently getting the timings + # right to check after 1 second is hard + await asyncio.sleep(2.1 - 0.85) + assert all(task.done() for task in apb_tasks.values()) + + apb_call_times = [ + ct - TestAIORateLimiter.apb_call_times[0] + for ct in TestAIORateLimiter.apb_call_times + ] + apb_call_times_dict = defaultdict(int) + for ct in apb_call_times: + apb_call_times_dict[int(ct)] += 1 + # We expect ~2000 apb requests after the first second # 2000 (>>1000), since we have a floating window logic such that an initial # burst is allowed that is hard to measure in the tests - assert TestAIORateLimiter.apb_count < 3000 - assert sum(1 for task in apb_tasks.values() if task.done()) < 3000 - - # 2 seconds after start - await asyncio.sleep(2.1 - 1) - assert TestAIORateLimiter.apb_count == 3000 - assert all(task.done() for task in apb_tasks.values()) + assert apb_call_times_dict[0] <= 2000 + assert apb_call_times_dict[0] + apb_call_times_dict[1] < 3000 + assert sum(apb_call_times_dict.values()) == 3000 finally: # cleanup From e31cf45127f0ae172d66e51e6182d44b90caf074 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 22 Jan 2025 20:24:12 +0100 Subject: [PATCH 4/4] review --- telegram/ext/_aioratelimiter.py | 2 +- tests/ext/test_ratelimiter.py | 16 +++++----------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/telegram/ext/_aioratelimiter.py b/telegram/ext/_aioratelimiter.py index 5f3edf8933c..7471652dd73 100644 --- a/telegram/ext/_aioratelimiter.py +++ b/telegram/ext/_aioratelimiter.py @@ -88,7 +88,7 @@ class AIORateLimiter(BaseRateLimiter[int]): :attr:`~telegram.error.RetryAfter.retry_after` + 0.1 seconds. This may be stricter than necessary in some cases, e.g. the bot may hit a rate limit in one group but might still be allowed to send messages in another group or with - :paramref:`~telegram.Bot.send_message.allow_paid_broadcast` set to ``True``. + :paramref:`~telegram.Bot.send_message.allow_paid_broadcast` set to :obj:`True`. Tip: With `Bot API 7.1 `_ diff --git a/tests/ext/test_ratelimiter.py b/tests/ext/test_ratelimiter.py index 8663f487f65..b1c66b6009b 100644 --- a/tests/ext/test_ratelimiter.py +++ b/tests/ext/test_ratelimiter.py @@ -26,7 +26,7 @@ import json import platform import time -from collections import defaultdict +from collections import Counter from http import HTTPStatus import pytest @@ -395,30 +395,24 @@ async def test_allow_paid_broadcast(self, bot): rl_bot.send_message(chat_id=-1, text="test", allow_paid_broadcast=False) ) - await asyncio.sleep(0.85) + await asyncio.sleep(0.1) # We expect 5 non-apb requests: # 1: `get_me` from `async with rl_bot` - # 2: `send_message` at time 0.00 - # 3: `send_message` at time 0.25 - # 4: `send_message` at time 0.50 - # 5: `send_message` at time 0.75 - # We expect + # 2-5: `send_message` assert TestAIORateLimiter.count == 5 assert sum(1 for task in non_apb_tasks.values() if task.done()) == 4 # ~2 second after start # We do the checks once all apb_tasks are done as apparently getting the timings # right to check after 1 second is hard - await asyncio.sleep(2.1 - 0.85) + await asyncio.sleep(2.1 - 0.1) assert all(task.done() for task in apb_tasks.values()) apb_call_times = [ ct - TestAIORateLimiter.apb_call_times[0] for ct in TestAIORateLimiter.apb_call_times ] - apb_call_times_dict = defaultdict(int) - for ct in apb_call_times: - apb_call_times_dict[int(ct)] += 1 + apb_call_times_dict = Counter(map(int, apb_call_times)) # We expect ~2000 apb requests after the first second # 2000 (>>1000), since we have a floating window logic such that an initial