diff --git a/playwright/_impl/_connection.py b/playwright/_impl/_connection.py index 420ca35a5..de6962e16 100644 --- a/playwright/_impl/_connection.py +++ b/playwright/_impl/_connection.py @@ -124,6 +124,22 @@ class ProtocolCallback: def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self.stack_trace: traceback.StackSummary = traceback.StackSummary() self.future = loop.create_future() + # The outer task can get cancelled by the user, this forwards the cancellation to the inner task. + current_task = asyncio.current_task() + + def cb(task: asyncio.Task) -> None: + if current_task: + current_task.remove_done_callback(cb) + if task.cancelled(): + self.future.cancel() + + if current_task: + current_task.add_done_callback(cb) + self.future.add_done_callback( + lambda _: current_task.remove_done_callback(cb) + if current_task + else None + ) class RootChannelOwner(ChannelOwner): diff --git a/tests/async/test_asyncio.py b/tests/async/test_asyncio.py new file mode 100644 index 000000000..26d376c8c --- /dev/null +++ b/tests/async/test_asyncio.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# +# Licensed under the Apache License, Version 2.0 (the "License") +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import gc +from typing import Dict + +import pytest + +from playwright.async_api import async_playwright + + +async def test_should_cancel_underlying_protocol_calls( + browser_name: str, launch_arguments: Dict +): + handler_exception = None + + def exception_handlerdler(loop, context) -> None: + nonlocal handler_exception + handler_exception = context["exception"] + + asyncio.get_running_loop().set_exception_handler(exception_handlerdler) + + async with async_playwright() as p: + browser = await p[browser_name].launch(**launch_arguments) + page = await browser.new_page() + task = asyncio.create_task(page.wait_for_selector("will-never-find")) + # make sure that the wait_for_selector message was sent to the server (driver) + await asyncio.sleep(0.1) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + await browser.close() + + # The actual 'Future exception was never retrieved' is logged inside the Future destructor (__del__). + gc.collect() + + assert handler_exception is None + + asyncio.get_running_loop().set_exception_handler(None)