From 8d2e112b40dd4916afd37534d57d46043db240c5 Mon Sep 17 00:00:00 2001 From: Kumar Aditya <59607654+kumaraditya303@users.noreply.github.com> Date: Sat, 3 Jul 2021 16:34:11 +0530 Subject: [PATCH] chore(types): improve typing --- playwright/_impl/_artifact.py | 4 ++-- playwright/_impl/_async_base.py | 30 +++++++++++++++++++---------- playwright/_impl/_browser_type.py | 13 ++++++------- playwright/_impl/_connection.py | 5 +++-- playwright/_impl/_frame.py | 11 ++++------- playwright/_impl/_network.py | 2 +- playwright/_impl/_object_factory.py | 4 ++-- playwright/_impl/_page.py | 6 ++++-- playwright/_impl/_sync_base.py | 30 ++++++++++++++++++----------- playwright/_impl/_tracing.py | 6 ++---- 10 files changed, 63 insertions(+), 48 deletions(-) diff --git a/playwright/_impl/_artifact.py b/playwright/_impl/_artifact.py index 9c3afc3b5..9bf442f50 100644 --- a/playwright/_impl/_artifact.py +++ b/playwright/_impl/_artifact.py @@ -14,7 +14,7 @@ import pathlib from pathlib import Path -from typing import Dict, Optional, Union, cast +from typing import Dict, Optional, Union from playwright._impl._connection import ChannelOwner, from_channel from playwright._impl._helper import Error, make_dirs_for_file, patch_error_message @@ -37,7 +37,7 @@ async def path_after_finished(self) -> Optional[pathlib.Path]: return pathlib.Path(await self._channel.send("pathAfterFinished")) async def save_as(self, path: Union[str, Path]) -> None: - stream = cast(Stream, from_channel(await self._channel.send("saveAsStream"))) + stream: Stream = from_channel(await self._channel.send("saveAsStream")) make_dirs_for_file(path) await stream.save_as(path) diff --git a/playwright/_impl/_async_base.py b/playwright/_impl/_async_base.py index 587d0d200..4d678f668 100644 --- a/playwright/_impl/_async_base.py +++ b/playwright/_impl/_async_base.py @@ -15,7 +15,7 @@ import asyncio import traceback from types import TracebackType -from typing import Any, Awaitable, Callable, Generic, Type, TypeVar +from typing import Any, Awaitable, Callable, Generic, Type, TypeVar, Union from playwright._impl._impl_to_api_mapping import ImplToApiMapping, ImplWrapper @@ -23,11 +23,11 @@ T = TypeVar("T") -Self = TypeVar("Self", bound="AsyncBase") +Self = TypeVar("Self", bound="AsyncContextManager") class AsyncEventInfo(Generic[T]): - def __init__(self, future: asyncio.Future) -> None: + def __init__(self, future: "asyncio.Future[T]") -> None: self._future = future @property @@ -39,13 +39,18 @@ def is_done(self) -> bool: class AsyncEventContextManager(Generic[T]): - def __init__(self, future: asyncio.Future) -> None: - self._event: AsyncEventInfo = AsyncEventInfo(future) + def __init__(self, future: "asyncio.Future[T]") -> None: + self._event = AsyncEventInfo[T](future) async def __aenter__(self) -> AsyncEventInfo[T]: return self._event - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + async def __aexit__( + self, + exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: await self._event.value @@ -68,17 +73,19 @@ def _wrap_handler(self, handler: Any) -> Callable[..., None]: return mapping.wrap_handler(handler) return handler - def on(self, event: str, f: Any) -> None: + def on(self, event: str, f: Callable[..., Union[Awaitable[None], None]]) -> None: """Registers the function ``f`` to the event name ``event``.""" self._impl_obj.on(event, self._wrap_handler(f)) - def once(self, event: str, f: Any) -> None: + def once(self, event: str, f: Callable[..., Union[Awaitable[None], None]]) -> None: """The same as ``self.on``, except that the listener is automatically removed after being called. """ self._impl_obj.once(event, self._wrap_handler(f)) - def remove_listener(self, event: str, f: Any) -> None: + def remove_listener( + self, event: str, f: Callable[..., Union[Awaitable[None], None]] + ) -> None: """Removes the function ``f`` from ``event``.""" self._impl_obj.remove_listener(event, self._wrap_handler(f)) @@ -93,4 +100,7 @@ async def __aexit__( exc_val: BaseException, traceback: TracebackType, ) -> None: - await self.close() # type: ignore + await self.close() + + async def close(self: Self) -> None: + ... diff --git a/playwright/_impl/_browser_type.py b/playwright/_impl/_browser_type.py index 07e0ad724..08d0c13c2 100644 --- a/playwright/_impl/_browser_type.py +++ b/playwright/_impl/_browser_type.py @@ -15,7 +15,7 @@ import asyncio import pathlib from pathlib import Path -from typing import Dict, List, Optional, Union, cast +from typing import Dict, List, Optional, Union from playwright._impl._api_structures import ( Geolocation, @@ -138,7 +138,7 @@ async def launch_persistent_context( await normalize_context_params(self._connection._is_sync, params) normalize_launch_params(params) try: - context = from_channel( + context: BrowserContext = from_channel( await self._channel.send("launchPersistentContext", params) ) context._options = params @@ -160,12 +160,11 @@ async def connect_over_cdp( "python" if self._connection._is_sync else "python-async" ) response = await self._channel.send_return_as_dict("connectOverCDP", params) - browser = cast(Browser, from_channel(response["browser"])) + browser: Browser = from_channel(response["browser"]) browser._is_remote = True - default_context = cast( - Optional[BrowserContext], - from_nullable_channel(response.get("defaultContext")), + default_context: Optional[BrowserContext] = from_nullable_channel( + response.get("defaultContext") ) if default_context: browser._contexts.append(default_context) @@ -209,7 +208,7 @@ async def connect( self._connection._child_ws_connections.append(connection) pre_launched_browser = playwright._initializer.get("preLaunchedBrowser") assert pre_launched_browser - browser = cast(Browser, from_channel(pre_launched_browser)) + browser: Browser = from_channel(pre_launched_browser) browser._is_remote = True browser._is_connected_over_websocket = True diff --git a/playwright/_impl/_connection.py b/playwright/_impl/_connection.py index 3ccba55f8..1f73ea107 100644 --- a/playwright/_impl/_connection.py +++ b/playwright/_impl/_connection.py @@ -319,11 +319,12 @@ def _replace_guids_with_channels(self, payload: Any) -> Any: return payload -def from_channel(channel: Channel) -> Any: +def from_channel(channel: Channel) -> ChannelOwner: + assert channel._object return channel._object -def from_nullable_channel(channel: Optional[Channel]) -> Optional[Any]: +def from_nullable_channel(channel: Optional[Channel]) -> Optional[ChannelOwner]: return channel._object if channel else None diff --git a/playwright/_impl/_frame.py b/playwright/_impl/_frame.py index 34825b03c..853c550d0 100644 --- a/playwright/_impl/_frame.py +++ b/playwright/_impl/_frame.py @@ -15,7 +15,7 @@ import asyncio import sys from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union from pyee import EventEmitter @@ -112,11 +112,8 @@ async def goto( waitUntil: DocumentLoadState = None, referer: str = None, ) -> Optional[Response]: - return cast( - Optional[Response], - from_nullable_channel( - await self._channel.send("goto", locals_to_params(locals())) - ), + return from_nullable_channel( + await self._channel.send("goto", locals_to_params(locals())) ) def _setup_navigation_wait_helper( @@ -247,7 +244,7 @@ async def query_selector(self, selector: str) -> Optional[ElementHandle]: async def query_selector_all(self, selector: str) -> List[ElementHandle]: return list( map( - cast(ElementHandle, from_channel), + from_channel, await self._channel.send("querySelectorAll", dict(selector=selector)), ) ) diff --git a/playwright/_impl/_network.py b/playwright/_impl/_network.py index 59738b609..2ad56dd2c 100644 --- a/playwright/_impl/_network.py +++ b/playwright/_impl/_network.py @@ -204,7 +204,7 @@ async def continue_( overrides["postData"] = base64.b64encode(postData.encode()).decode() elif isinstance(postData, bytes): overrides["postData"] = base64.b64encode(postData).decode() - await self._channel.send("continue", cast(Any, overrides)) + await self._channel.send("continue", cast(Dict, overrides)) class Response(ChannelOwner): diff --git a/playwright/_impl/_object_factory.py b/playwright/_impl/_object_factory.py index a9c3b9fbb..36f4b8da3 100644 --- a/playwright/_impl/_object_factory.py +++ b/playwright/_impl/_object_factory.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, cast +from typing import Dict from playwright._impl._artifact import Artifact from playwright._impl._browser import Browser @@ -47,7 +47,7 @@ def create_remote_object( if type == "BindingCall": return BindingCall(parent, type, guid, initializer) if type == "Browser": - return Browser(cast(BrowserType, parent), type, guid, initializer) + return Browser(parent, type, guid, initializer) if type == "BrowserType": return BrowserType(parent, type, guid, initializer) if type == "BrowserContext": diff --git a/playwright/_impl/_page.py b/playwright/_impl/_page.py index 10c8c4237..138e2bb63 100644 --- a/playwright/_impl/_page.py +++ b/playwright/_impl/_page.py @@ -134,7 +134,9 @@ def __init__( self._browser_context._timeout_settings ) self._video: Optional[Video] = None - self._opener = cast("Page", from_nullable_channel(initializer.get("opener"))) + self._opener: Optional["Page"] = from_nullable_channel( + initializer.get("opener") + ) self._channel.on( "bindingCall", @@ -248,7 +250,7 @@ def _on_dialog(self, params: Any) -> None: def _on_download(self, params: Any) -> None: url = params["url"] suggested_filename = params["suggestedFilename"] - artifact = cast(Artifact, from_channel(params["artifact"])) + artifact: Artifact = from_channel(params["artifact"]) if self._browser_context._browser: artifact._is_remote = self._browser_context._browser._is_remote self.emit( diff --git a/playwright/_impl/_sync_base.py b/playwright/_impl/_sync_base.py index 005161a51..4e0272420 100644 --- a/playwright/_impl/_sync_base.py +++ b/playwright/_impl/_sync_base.py @@ -36,18 +36,18 @@ T = TypeVar("T") -Self = TypeVar("Self") +Self = TypeVar("Self", bound="SyncContextManager") class EventInfo(Generic[T]): - def __init__(self, sync_base: "SyncBase", future: asyncio.Future) -> None: + def __init__(self, sync_base: "SyncBase", future: "asyncio.Future[T]") -> None: self._sync_base = sync_base self._value: Optional[T] = None - self._exception = None + self._exception: Optional[Exception] = None self._future = future g_self = greenlet.getcurrent() - def done_callback(task: Any) -> None: + def done_callback(task: "asyncio.Future[T]") -> None: try: self._value = mapping.from_maybe_impl(self._future.result()) except Exception as e: @@ -71,13 +71,18 @@ def is_done(self) -> bool: class EventContextManager(Generic[T]): - def __init__(self, sync_base: "SyncBase", future: asyncio.Future) -> None: - self._event: EventInfo = EventInfo(sync_base, future) + def __init__(self, sync_base: "SyncBase", future: "asyncio.Future[T]") -> None: + self._event = EventInfo[T](sync_base, future) def __enter__(self) -> EventInfo[T]: return self._event - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + def __exit__( + self, + exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: self._event.value @@ -110,17 +115,17 @@ def _wrap_handler(self, handler: Any) -> Callable[..., None]: return mapping.wrap_handler(handler) return handler - def on(self, event: str, f: Any) -> None: + def on(self, event: str, f: Callable[..., None]) -> None: """Registers the function ``f`` to the event name ``event``.""" self._impl_obj.on(event, self._wrap_handler(f)) - def once(self, event: str, f: Any) -> None: + def once(self, event: str, f: Callable[..., None]) -> None: """The same as ``self.on``, except that the listener is automatically removed after being called. """ self._impl_obj.once(event, self._wrap_handler(f)) - def remove_listener(self, event: str, f: Any) -> None: + def remove_listener(self, event: str, f: Callable[..., None]) -> None: """Removes the function ``f`` from ``event``.""" self._impl_obj.remove_listener(event, self._wrap_handler(f)) @@ -167,4 +172,7 @@ def __exit__( exc_val: BaseException, traceback: TracebackType, ) -> None: - self.close() # type: ignore + self.close() + + def close(self: Self) -> None: + ... diff --git a/playwright/_impl/_tracing.py b/playwright/_impl/_tracing.py index 45b9bb0da..b5ae60cdd 100644 --- a/playwright/_impl/_tracing.py +++ b/playwright/_impl/_tracing.py @@ -13,7 +13,7 @@ # limitations under the License. import pathlib -from typing import TYPE_CHECKING, Union, cast +from typing import TYPE_CHECKING, Union from playwright._impl._artifact import Artifact from playwright._impl._connection import from_channel @@ -39,9 +39,7 @@ async def start( async def stop(self, path: Union[pathlib.Path, str] = None) -> None: await self._channel.send("tracingStop") if path: - artifact = cast( - Artifact, from_channel(await self._channel.send("tracingExport")) - ) + artifact: Artifact = from_channel(await self._channel.send("tracingExport")) if self._context._browser: artifact._is_remote = self._context._browser._is_remote await artifact.save_as(path)