diff --git a/playwright/accessibility.py b/playwright/accessibility.py index ceefb1c1e..dc7adaada 100644 --- a/playwright/accessibility.py +++ b/playwright/accessibility.py @@ -57,6 +57,7 @@ class Accessibility: def __init__(self, channel: Channel) -> None: self._channel = channel self._loop = channel._connection._loop + self._dispatcher_fiber = channel._connection._dispatcher_fiber async def snapshot( self, interestingOnly: bool = None, root: ElementHandle = None diff --git a/playwright/connection.py b/playwright/connection.py index ef2f454b5..cd461d233 100644 --- a/playwright/connection.py +++ b/playwright/connection.py @@ -15,13 +15,13 @@ import asyncio import sys import traceback +from pathlib import Path from typing import Any, Callable, Dict, Optional, Union from greenlet import greenlet from pyee import AsyncIOEventEmitter from playwright.helper import ParsedMessagePayload, parse_error -from playwright.sync_base import dispatcher_fiber from playwright.transport import Transport @@ -74,6 +74,7 @@ def __init__( ) -> None: super().__init__(loop=parent._loop) self._loop: asyncio.AbstractEventLoop = parent._loop + self._dispatcher_fiber: Any = parent._dispatcher_fiber self._type = type self._guid = guid self._connection: Connection = ( @@ -116,33 +117,30 @@ def __init__(self, connection: "Connection") -> None: class Connection: def __init__( - self, - input: asyncio.StreamReader, - output: asyncio.StreamWriter, - object_factory: Any, - loop: asyncio.AbstractEventLoop, + self, dispatcher_fiber: Any, object_factory: Any, driver_executable: Path ) -> None: - self._transport = Transport(input, output, loop) + self._dispatcher_fiber: Any = dispatcher_fiber + self._transport = Transport(driver_executable) self._transport.on_message = lambda msg: self._dispatch(msg) self._waiting_for_object: Dict[str, Any] = {} self._last_id = 0 - self._loop = loop self._objects: Dict[str, ChannelOwner] = {} self._callbacks: Dict[int, ProtocolCallback] = {} - self._root_object = RootChannelOwner(self) self._object_factory = object_factory self._is_sync = False - def run_sync(self) -> None: + async def run_as_sync(self) -> None: self._is_sync = True - self._transport.run_sync() + await self.run() - def run_async(self) -> None: - self._transport.run_async() + async def run(self) -> None: + self._loop = asyncio.get_running_loop() + self._root_object = RootChannelOwner(self) + await self._transport.run() def stop_sync(self) -> None: self._transport.stop() - dispatcher_fiber().switch() + self._dispatcher_fiber.switch() def stop_async(self) -> None: self._transport.stop() diff --git a/playwright/file_chooser.py b/playwright/file_chooser.py index d8a993b25..cebbcebd0 100644 --- a/playwright/file_chooser.py +++ b/playwright/file_chooser.py @@ -27,6 +27,7 @@ def __init__( ) -> None: self._page = page self._loop = page._loop + self._dispatcher_fiber = page._dispatcher_fiber self._element_handle = element_handle self._is_multiple = is_multiple diff --git a/playwright/input.py b/playwright/input.py index 515a69e20..e72e905f1 100644 --- a/playwright/input.py +++ b/playwright/input.py @@ -20,6 +20,7 @@ class Keyboard: def __init__(self, channel: Channel) -> None: self._channel = channel self._loop = channel._connection._loop + self._dispatcher_fiber = channel._connection._dispatcher_fiber async def down(self, key: str) -> None: await self._channel.send("keyboardDown", locals_to_params(locals())) @@ -41,6 +42,7 @@ class Mouse: def __init__(self, channel: Channel) -> None: self._channel = channel self._loop = channel._connection._loop + self._dispatcher_fiber = channel._connection._dispatcher_fiber async def move(self, x: float, y: float, steps: int = None) -> None: await self._channel.send("mouseMove", locals_to_params(locals())) @@ -83,6 +85,7 @@ class Touchscreen: def __init__(self, channel: Channel) -> None: self._channel = channel self._loop = channel._connection._loop + self._dispatcher_fiber = channel._connection._dispatcher_fiber async def tap(self, x: float, y: float) -> None: await self._channel.send("touchscreenTap", locals_to_params(locals())) diff --git a/playwright/main.py b/playwright/main.py index 197c62a51..d0e031996 100644 --- a/playwright/main.py +++ b/playwright/main.py @@ -28,7 +28,6 @@ from playwright.path_utils import get_file_dirname from playwright.playwright import Playwright from playwright.sync_api import Playwright as SyncPlaywright -from playwright.sync_base import dispatcher_fiber, set_dispatcher_fiber def compute_driver_executable() -> Path: @@ -39,38 +38,34 @@ def compute_driver_executable() -> Path: return package_path / "driver" / "playwright-cli" -async def run_driver_async() -> Connection: - driver_executable = compute_driver_executable() - - proc = await asyncio.create_subprocess_exec( - str(driver_executable), - "run-driver", - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=sys.stderr, - limit=32768, - ) - assert proc.stdout - assert proc.stdin - connection = Connection( - proc.stdout, proc.stdin, create_remote_object, asyncio.get_event_loop() - ) - return connection - - -def run_driver() -> Connection: - loop = asyncio.get_event_loop() - if loop.is_running(): - raise Error("Can only run one Playwright at a time.") - return loop.run_until_complete(run_driver_async()) - - class SyncPlaywrightContextManager: def __init__(self) -> None: - self._connection = run_driver() self._playwright: SyncPlaywright def __enter__(self) -> SyncPlaywright: + def greenlet_main() -> None: + loop = None + own_loop = None + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + own_loop = loop + + if loop.is_running(): + raise Error("Can only run one Playwright at a time.") + + loop.run_until_complete(self._connection.run_as_sync()) + + if own_loop: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() + + dispatcher_fiber = greenlet(greenlet_main) + self._connection = Connection( + dispatcher_fiber, create_remote_object, compute_driver_executable() + ) + g_self = greenlet.getcurrent() def callback_wrapper(playwright_impl: Playwright) -> None: @@ -78,8 +73,8 @@ def callback_wrapper(playwright_impl: Playwright) -> None: g_self.switch() self._connection.call_on_object_with_known_name("Playwright", callback_wrapper) - set_dispatcher_fiber(greenlet(lambda: self._connection.run_sync())) - dispatcher_fiber().switch() + + dispatcher_fiber.switch() playwright = self._playwright playwright.stop = self.__exit__ # type: ignore return playwright @@ -96,8 +91,12 @@ def __init__(self) -> None: self._connection: Connection async def __aenter__(self) -> AsyncPlaywright: - self._connection = await run_driver_async() - self._connection.run_async() + self._connection = Connection( + None, create_remote_object, compute_driver_executable() + ) + loop = asyncio.get_running_loop() + self._connection._loop = loop + loop.create_task(self._connection.run()) playwright = AsyncPlaywright( await self._connection.wait_for_object_with_known_name("Playwright") ) @@ -113,8 +112,7 @@ async def __aexit__(self, *args: Any) -> None: if sys.platform == "win32": # Use ProactorEventLoop in 3.7, which is default in 3.8 - loop = asyncio.ProactorEventLoop() - asyncio.set_event_loop(loop) + asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) def main() -> None: diff --git a/playwright/sync_api.py b/playwright/sync_api.py index 01ab91581..ccf88249a 100644 --- a/playwright/sync_api.py +++ b/playwright/sync_api.py @@ -577,7 +577,7 @@ def expect_event( page.setDefaultTimeout(timeout) methods. """ return EventContextManager( - self._loop, self._impl_obj.waitForEvent(event, predicate, timeout) + self, self._impl_obj.waitForEvent(event, predicate, timeout) ) def isClosed(self) -> bool: @@ -3198,7 +3198,7 @@ def expect_load_state( page.setDefaultTimeout(timeout) methods. """ return EventContextManager( - self._loop, self._impl_obj.waitForLoadState(state, timeout) + self, self._impl_obj.waitForLoadState(state, timeout) ) def expect_navigation( @@ -3227,7 +3227,7 @@ def expect_navigation( page.setDefaultTimeout(timeout) methods. """ return EventContextManager( - self._loop, self._impl_obj.waitForNavigation(url, waitUntil, timeout) + self, self._impl_obj.waitForNavigation(url, waitUntil, timeout) ) @@ -5561,7 +5561,7 @@ def expect_event( page.setDefaultTimeout(timeout) methods. """ return EventContextManager( - self._loop, self._impl_obj.waitForEvent(event, predicate, timeout) + self, self._impl_obj.waitForEvent(event, predicate, timeout) ) def expect_console_message( @@ -5590,7 +5590,7 @@ def expect_console_message( """ event = "console" return EventContextManager( - self._loop, self._impl_obj.waitForEvent(event, predicate, timeout) + self, self._impl_obj.waitForEvent(event, predicate, timeout) ) def expect_download( @@ -5619,7 +5619,7 @@ def expect_download( """ event = "download" return EventContextManager( - self._loop, self._impl_obj.waitForEvent(event, predicate, timeout) + self, self._impl_obj.waitForEvent(event, predicate, timeout) ) def expect_file_chooser( @@ -5648,7 +5648,7 @@ def expect_file_chooser( """ event = "filechooser" return EventContextManager( - self._loop, self._impl_obj.waitForEvent(event, predicate, timeout) + self, self._impl_obj.waitForEvent(event, predicate, timeout) ) def expect_load_state( @@ -5676,7 +5676,7 @@ def expect_load_state( page.setDefaultTimeout(timeout) methods. """ return EventContextManager( - self._loop, self._impl_obj.waitForLoadState(state, timeout) + self, self._impl_obj.waitForLoadState(state, timeout) ) def expect_navigation( @@ -5705,7 +5705,7 @@ def expect_navigation( page.setDefaultTimeout(timeout) methods. """ return EventContextManager( - self._loop, self._impl_obj.waitForNavigation(url, waitUntil, timeout) + self, self._impl_obj.waitForNavigation(url, waitUntil, timeout) ) def expect_popup( @@ -5734,7 +5734,7 @@ def expect_popup( """ event = "popup" return EventContextManager( - self._loop, self._impl_obj.waitForEvent(event, predicate, timeout) + self, self._impl_obj.waitForEvent(event, predicate, timeout) ) def expect_request( @@ -5763,7 +5763,7 @@ def expect_request( page.setDefaultTimeout(timeout) methods. """ return EventContextManager( - self._loop, self._impl_obj.waitForRequest(url, predicate, timeout) + self, self._impl_obj.waitForRequest(url, predicate, timeout) ) def expect_response( @@ -5792,7 +5792,7 @@ def expect_response( page.setDefaultTimeout(timeout) methods. """ return EventContextManager( - self._loop, self._impl_obj.waitForResponse(url, predicate, timeout) + self, self._impl_obj.waitForResponse(url, predicate, timeout) ) def expect_worker( @@ -5821,7 +5821,7 @@ def expect_worker( """ event = "worker" return EventContextManager( - self._loop, self._impl_obj.waitForEvent(event, predicate, timeout) + self, self._impl_obj.waitForEvent(event, predicate, timeout) ) @@ -6241,7 +6241,7 @@ def expect_event( page.setDefaultTimeout(timeout) methods. """ return EventContextManager( - self._loop, self._impl_obj.waitForEvent(event, predicate, timeout) + self, self._impl_obj.waitForEvent(event, predicate, timeout) ) def expect_page( @@ -6270,7 +6270,7 @@ def expect_page( """ event = "page" return EventContextManager( - self._loop, self._impl_obj.waitForEvent(event, predicate, timeout) + self, self._impl_obj.waitForEvent(event, predicate, timeout) ) diff --git a/playwright/sync_base.py b/playwright/sync_base.py index a532fb783..2e3e07355 100644 --- a/playwright/sync_base.py +++ b/playwright/sync_base.py @@ -34,24 +34,13 @@ T = TypeVar("T") -dispatcher_fiber_: greenlet - - -def set_dispatcher_fiber(fiber: greenlet) -> None: - global dispatcher_fiber_ - dispatcher_fiber_ = fiber - - -def dispatcher_fiber() -> greenlet: - return dispatcher_fiber_ - class EventInfo(Generic[T]): - def __init__(self, loop: asyncio.AbstractEventLoop, coroutine: Coroutine) -> None: - self._loop = loop + def __init__(self, sync_base: "SyncBase", coroutine: Coroutine) -> None: + self._sync_base = sync_base self._value: Optional[T] = None self._exception = None - self._future = loop.create_task(coroutine) + self._future = sync_base._loop.create_task(coroutine) g_self = greenlet.getcurrent() def done_callback(task: Any) -> None: @@ -67,16 +56,16 @@ def done_callback(task: Any) -> None: @property def value(self) -> T: while not self._future.done(): - dispatcher_fiber_.switch() - asyncio._set_running_loop(self._loop) + self._sync_base._dispatcher_fiber.switch() + asyncio._set_running_loop(self._sync_base._loop) if self._exception: raise self._exception return cast(T, self._value) class EventContextManager(Generic[T]): - def __init__(self, loop: asyncio.AbstractEventLoop, coroutine: Coroutine) -> None: - self._event: EventInfo = EventInfo(loop, coroutine) + def __init__(self, sync_base: "SyncBase", coroutine: Coroutine) -> None: + self._event: EventInfo = EventInfo(sync_base, coroutine) def __enter__(self) -> EventInfo[T]: return self._event @@ -89,6 +78,7 @@ class SyncBase(ImplWrapper): def __init__(self, impl_obj: Any) -> None: super().__init__(impl_obj) self._loop = impl_obj._loop + self._dispatcher_fiber = impl_obj._dispatcher_fiber def __str__(self) -> str: return self._impl_obj.__str__() @@ -102,7 +92,7 @@ def callback(result: Any) -> None: future.add_done_callback(callback) while not future.done(): - dispatcher_fiber_.switch() + self._dispatcher_fiber.switch() asyncio._set_running_loop(self._loop) return future.result() @@ -149,7 +139,7 @@ async def task() -> None: self._loop.create_task(task()) while len(results) < len(actions): - dispatcher_fiber_.switch() + self._dispatcher_fiber.switch() asyncio._set_running_loop(self._loop) if exceptions: diff --git a/playwright/transport.py b/playwright/transport.py index 0a2121bee..3be4c8217 100644 --- a/playwright/transport.py +++ b/playwright/transport.py @@ -15,42 +15,47 @@ import asyncio import json import os +import sys +from pathlib import Path from typing import Dict class Transport: - def __init__( - self, - input: asyncio.StreamReader, - output: asyncio.StreamWriter, - loop: asyncio.AbstractEventLoop, - ) -> None: + def __init__(self, driver_executable: Path) -> None: super().__init__() - self._input: asyncio.StreamReader = input - self._output: asyncio.StreamWriter = output - self.loop: asyncio.AbstractEventLoop = loop self.on_message = lambda _: None self._stopped = False - - def run_sync(self) -> None: - self.loop.run_until_complete(self._run()) - - def run_async(self) -> None: - self.loop.create_task(self._run()) + self._driver_executable = driver_executable + self._loop: asyncio.AbstractEventLoop def stop(self) -> None: self._stopped = True self._output.close() - async def _run(self) -> None: + async def run(self) -> None: + self._loop = asyncio.get_running_loop() + driver_executable = self._driver_executable + + proc = await asyncio.create_subprocess_exec( + str(driver_executable), + "run-driver", + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=sys.stderr, + limit=32768, + ) + assert proc.stdout + assert proc.stdin + self._output = proc.stdin + while not self._stopped: try: - buffer = await self._input.readexactly(4) + buffer = await proc.stdout.readexactly(4) length = int.from_bytes(buffer, byteorder="little", signed=False) buffer = bytes(0) while length: to_read = min(length, 32768) - data = await self._input.readexactly(to_read) + data = await proc.stdout.readexactly(to_read) length -= to_read if len(buffer): buffer = buffer + data diff --git a/playwright/video.py b/playwright/video.py index 6af5b0fa3..1b9b24f2b 100644 --- a/playwright/video.py +++ b/playwright/video.py @@ -22,6 +22,7 @@ class Video: def __init__(self, page: "Page") -> None: self._loop = page._loop + self._dispatcher_fiber = page._dispatcher_fiber self._page = page self._path_future = page._loop.create_future() diff --git a/scripts/generate_sync_api.py b/scripts/generate_sync_api.py index fb89fdb95..2b5387c93 100755 --- a/scripts/generate_sync_api.py +++ b/scripts/generate_sync_api.py @@ -145,7 +145,7 @@ def generate(t: Any) -> None: print(f' event = "{event_name}"') print( - f" return EventContextManager(self._loop, self._impl_obj.{wait_for_method})" + f" return EventContextManager(self, self._impl_obj.{wait_for_method})" ) print("") diff --git a/tests/server.py b/tests/server.py index e69a42f5f..76995ed3b 100644 --- a/tests/server.py +++ b/tests/server.py @@ -14,7 +14,6 @@ import abc import asyncio -import contextlib import gzip import mimetypes import socket @@ -22,14 +21,12 @@ from contextlib import closing from http import HTTPStatus -import greenlet from autobahn.twisted.websocket import WebSocketServerFactory, WebSocketServerProtocol from OpenSSL import crypto from twisted.internet import reactor, ssl from twisted.web import http from playwright.path_utils import get_file_dirname -from playwright.sync_base import dispatcher_fiber _dirname = get_file_dirname() @@ -140,30 +137,6 @@ async def wait_for_request(self, path): self.request_subscribers[path] = future return await future - @contextlib.contextmanager - def expect_request(self, path): - future = asyncio.create_task(self.wait_for_request(path)) - - class CallbackValue: - def __init__(self) -> None: - self._value = None - - @property - def value(self): - return self._value - - g_self = greenlet.getcurrent() - cb_wrapper = CallbackValue() - - def done_cb(task): - cb_wrapper._value = future.result() - g_self.switch() - - future.add_done_callback(done_cb) - yield cb_wrapper - while not future.done(): - dispatcher_fiber.switch() - def set_auth(self, path: str, username: str, password: str): self.auth[path] = (username, password)