Skip to content

Fix Application.create_task type hinting #3543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ The following wonderful people contributed directly or indirectly to this projec
- `Riko Naka <https://github.com/rikonaka>`_
- `Rizlas <https://github.com/rizlas>`_
- `Sahil Sharma <https://github.com/sahilsharma811>`_
- `Sam Mosleh <https://github.com/sam-mosleh>`_
- `Sascha <https://github.com/saschalalala>`_
- `Shelomentsev D <https://github.com/shelomentsevd>`_
- `Shivam Saini <https://github.com/shivamsn97>`_
Expand Down
31 changes: 22 additions & 9 deletions telegram/ext/_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@
TYPE_CHECKING,
Any,
AsyncContextManager,
Awaitable,
Callable,
Coroutine,
DefaultDict,
Dict,
Generator,
Generic,
List,
Mapping,
Expand Down Expand Up @@ -71,7 +73,6 @@
DEFAULT_GROUP: int = 0

_AppType = TypeVar("_AppType", bound="Application") # pylint: disable=invalid-name
_RT = TypeVar("_RT")
_STOP_SIGNAL = object()

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -934,7 +935,9 @@ def __run(
loop.close()

def create_task(
self, coroutine: Coroutine[Any, Any, RT], update: object = None
self,
coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]],
update: object = None,
) -> "asyncio.Task[RT]":
"""Thin wrapper around :func:`asyncio.create_task` that handles exceptions raised by
the :paramref:`coroutine` with :meth:`process_error`.
Expand All @@ -948,7 +951,10 @@ def create_task(
.. seealso:: :wiki:`Concurrency`

Args:
coroutine (:term:`coroutine function`): The coroutine to run as task.
coroutine (:term:`awaitable`): The awaitable to run as task.

.. versionchanged:: 20.2
Accepts :class:`asyncio.Future` and generator-based coroutine functions.
update (:obj:`object`, optional): If set, will be passed to :meth:`process_error`
as additional information for the error handlers. Moreover, the corresponding
:attr:`chat_data` and :attr:`user_data` entries will be updated in the next run of
Expand All @@ -960,13 +966,16 @@ def create_task(
return self.__create_task(coroutine=coroutine, update=update)

def __create_task(
self, coroutine: Coroutine, update: object = None, is_error_handler: bool = False
) -> asyncio.Task:
self,
coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]],
update: object = None,
is_error_handler: bool = False,
) -> "asyncio.Task[RT]":
# Unfortunately, we can't know if `coroutine` runs one of the error handler functions
# but by passing `is_error_handler=True` from `process_error`, we can make sure that we
# get at most one recursion of the user calls `create_task` manually with an error handler
# function
task = asyncio.create_task(
task: "asyncio.Task[RT]" = asyncio.create_task(
self.__create_task_callback(
coroutine=coroutine, update=update, is_error_handler=is_error_handler
)
Expand Down Expand Up @@ -995,11 +1004,13 @@ def __create_task_done_callback(self, task: asyncio.Task) -> None:

async def __create_task_callback(
self,
coroutine: Coroutine[Any, Any, _RT],
coroutine: Union[Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]],
update: object = None,
is_error_handler: bool = False,
) -> _RT:
) -> RT:
try:
if isinstance(coroutine, Generator):
return await asyncio.create_task(coroutine)
return await coroutine
except asyncio.CancelledError as cancel:
# TODO: in py3.8+, CancelledError is a subclass of BaseException, so we can drop this
Expand Down Expand Up @@ -1562,7 +1573,9 @@ async def process_error(
update: Optional[object],
error: Exception,
job: "Job[CCT]" = None,
coroutine: Coroutine[Any, Any, Any] = None,
coroutine: Union[
Generator[Optional["asyncio.Future[object]"], None, RT], Awaitable[RT]
] = None,
) -> bool:
"""Processes an error by passing it to all error handlers registered with
:meth:`add_error_handler`. If one of the error handlers raises
Expand Down
20 changes: 13 additions & 7 deletions telegram/ext/_callbackcontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@
from typing import (
TYPE_CHECKING,
Any,
Coroutine,
Awaitable,
Dict,
Generator,
Generic,
List,
Match,
NoReturn,
Optional,
Type,
Union,
)

from telegram._callbackquery import CallbackQuery
Expand All @@ -37,7 +39,7 @@
from telegram.ext._utils.types import BD, BT, CD, UD

if TYPE_CHECKING:
from asyncio import Queue
from asyncio import Future, Queue

from telegram.ext import Application, Job, JobQueue # noqa: F401
from telegram.ext._utils.types import CCT
Expand Down Expand Up @@ -96,8 +98,8 @@ class CallbackContext(Generic[BT, UD, CD, BD]):

.. versionadded:: 20.0
Attributes:
coroutine (:term:`coroutine function`): Optional. Only present in error handlers if the
error was caused by a coroutine run with :meth:`Application.create_task` or a handler
coroutine (:term:`awaitable`): Optional. Only present in error handlers if the
error was caused by an awaitable run with :meth:`Application.create_task` or a handler
callback with :attr:`block=False <BaseHandler.block>`.
matches (List[:meth:`re.Match <re.Match.expand>`]): Optional. If the associated update
originated from a :class:`filters.Regex`, this will contain a list of match objects for
Expand Down Expand Up @@ -143,7 +145,9 @@ def __init__(
self.matches: Optional[List[Match[str]]] = None
self.error: Optional[Exception] = None
self.job: Optional["Job[CCT]"] = None
self.coroutine: Optional[Coroutine[Any, Any, Any]] = None
self.coroutine: Optional[
Union[Generator[Optional["Future[object]"], None, Any], Awaitable[Any]]
] = None

@property
def application(self) -> "Application[BT, CCT, UD, CD, BD, Any]":
Expand Down Expand Up @@ -275,7 +279,7 @@ def from_error(
error: Exception,
application: "Application[BT, CCT, UD, CD, BD, Any]",
job: "Job[Any]" = None,
coroutine: Coroutine[Any, Any, Any] = None,
coroutine: Union[Generator[Optional["Future[object]"], None, Any], Awaitable[Any]] = None,
) -> "CCT":
"""
Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the error
Expand All @@ -295,13 +299,15 @@ def from_error(
job (:class:`telegram.ext.Job`, optional): The job associated with the error.

.. versionadded:: 20.0
coroutine (:term:`coroutine function`, optional): The coroutine function associated
coroutine (:term:`awaitable`, optional): The awaitable associated
with this error if the error was caused by a coroutine run with
:meth:`Application.create_task` or a handler callback with
:attr:`block=False <BaseHandler.block>`.

.. versionadded:: 20.0

.. versionchanged:: 20.2
Accepts :class:`asyncio.Future` and generator-based coroutine functions.
Returns:
:class:`telegram.ext.CallbackContext`
"""
Expand Down
20 changes: 20 additions & 0 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,6 +1287,26 @@ async def callback_2():
await asyncio.sleep(0.05)
assert stop_task.done()

async def test_create_task_awaiting_future(self, app):
async def callback():
await asyncio.sleep(0.01)
return 42

# `asyncio.gather` returns an `asyncio.Future` and not an
# `asyncio.Task`
out = await app.create_task(asyncio.gather(callback()))
assert out == [42]

async def test_create_task_awaiting_generator(self, app):
event = asyncio.Event()

def gen():
yield
event.set()

await app.create_task(gen())
assert event.is_set()

async def test_no_concurrent_updates(self, app):
queue = asyncio.Queue()
event_1 = asyncio.Event()
Expand Down