Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions playwright/_impl/_async_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
# limitations under the License.

import asyncio
import traceback
from types import TracebackType
from typing import Any, Awaitable, Callable, Generic, Type, TypeVar
from typing import Any, Callable, Generic, Type, TypeVar

from playwright._impl._impl_to_api_mapping import ImplToApiMapping, ImplWrapper

Expand Down Expand Up @@ -62,12 +61,6 @@ def __init__(self, impl_obj: Any) -> None:
def __str__(self) -> str:
return self._impl_obj.__str__()

def _async(self, api_name: str, coro: Awaitable) -> Any:
task = asyncio.current_task()
setattr(task, "__pw_api_name__", api_name)
setattr(task, "__pw_stack_trace__", traceback.extract_stack())
return coro

def _wrap_handler(self, handler: Any) -> Callable[..., None]:
if callable(handler):
return mapping.wrap_handler(handler)
Expand Down
109 changes: 82 additions & 27 deletions playwright/_impl/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@
# limitations under the License.

import asyncio
import contextvars
import inspect
import sys
import traceback
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast

from greenlet import greenlet
from pyee import AsyncIOEventEmitter, EventEmitter

import playwright
from playwright._impl._helper import ParsedMessagePayload, parse_error
from playwright._impl._transport import Transport

Expand All @@ -36,10 +39,21 @@ def __init__(self, connection: "Connection", guid: str) -> None:
self._object: Optional[ChannelOwner] = None

async def send(self, method: str, params: Dict = None) -> Any:
return await self.inner_send(method, params, False)
return await self._connection.wrap_api_call(
lambda: self.inner_send(method, params, False)
)

async def send_return_as_dict(self, method: str, params: Dict = None) -> Any:
return await self.inner_send(method, params, True)
return await self._connection.wrap_api_call(
lambda: self.inner_send(method, params, True)
)

def send_no_reply(self, method: str, params: Dict = None) -> None:
self._connection.wrap_api_call(
lambda: self._connection._send_message_to_server(
self._guid, method, {} if params is None else params
)
)

async def inner_send(
self, method: str, params: Optional[Dict], return_as_dict: bool
Expand Down Expand Up @@ -74,11 +88,6 @@ async def inner_send(
key = next(iter(result))
return result[key]

def send_no_reply(self, method: str, params: Dict = None) -> None:
if params is None:
params = {}
self._connection._send_message_to_server(self._guid, method, params)


class ChannelOwner(AsyncIOEventEmitter):
def __init__(
Expand Down Expand Up @@ -122,7 +131,7 @@ def _dispose(self) -> None:

class ProtocolCallback:
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self.stack_trace: traceback.StackSummary = traceback.StackSummary()
self.stack_trace: traceback.StackSummary
self.future = loop.create_future()
# The outer task can get cancelled by the user, this forwards the cancellation to the inner task.
current_task = asyncio.current_task()
Expand Down Expand Up @@ -181,6 +190,9 @@ def __init__(
self._error: Optional[BaseException] = None
self.is_remote = False
self._init_task: Optional[asyncio.Task] = None
self._api_zone: contextvars.ContextVar[Optional[Dict]] = contextvars.ContextVar(
"ApiZone", default=None
)

def mark_as_remote(self) -> None:
self.is_remote = True
Expand Down Expand Up @@ -230,22 +242,17 @@ def _send_message_to_server(
id = self._last_id
callback = ProtocolCallback(self._loop)
task = asyncio.current_task(self._loop)
stack_trace: Optional[traceback.StackSummary] = getattr(
task, "__pw_stack_trace__", None
callback.stack_trace = cast(
traceback.StackSummary,
getattr(task, "__pw_stack_trace__", traceback.extract_stack()),
)
callback.stack_trace = stack_trace or traceback.extract_stack()
self._callbacks[id] = callback
metadata = {"stack": serialize_call_stack(callback.stack_trace)}
api_name = getattr(task, "__pw_api_name__", None)
if api_name:
metadata["apiName"] = api_name

message = {
"id": id,
"guid": guid,
"method": method,
"params": self._replace_channels_with_guids(params),
"metadata": metadata,
"metadata": self._api_zone.get(),
}
self._transport.send(message)
self._callbacks[id] = callback
Expand Down Expand Up @@ -337,6 +344,27 @@ def _replace_guids_with_channels(self, payload: Any) -> Any:
return result
return payload

def wrap_api_call(self, cb: Callable[[], Any], is_internal: bool = False) -> Any:
if self._api_zone.get():
return cb()
task = asyncio.current_task(self._loop)
st: List[inspect.FrameInfo] = getattr(task, "__pw_stack__", inspect.stack())
metadata = _extract_metadata_from_stack(st, is_internal)
if metadata:
self._api_zone.set(metadata)
result = cb()

async def _() -> None:
try:
return await result
finally:
self._api_zone.set(None)

if asyncio.iscoroutine(result):
return _()
self._api_zone.set(None)
return result


def from_channel(channel: Channel) -> Any:
return channel._object
Expand All @@ -346,13 +374,40 @@ def from_nullable_channel(channel: Optional[Channel]) -> Optional[Any]:
return channel._object if channel else None


def serialize_call_stack(stack_trace: traceback.StackSummary) -> List[Dict]:
def _extract_metadata_from_stack(
st: List[inspect.FrameInfo], is_internal: bool
) -> Optional[Dict]:
playwright_module_path = str(Path(playwright.__file__).parents[0])
last_internal_api_name = ""
api_name = ""
stack: List[Dict] = []
for frame in stack_trace:
if "_generated.py" in frame.filename:
break
stack.append(
{"file": frame.filename, "line": frame.lineno, "function": frame.name}
)
stack.reverse()
return stack
for frame in st:
is_playwright_internal = frame.filename.startswith(playwright_module_path)

method_name = ""
if "self" in frame[0].f_locals:
method_name = frame[0].f_locals["self"].__class__.__name__ + "."
method_name += frame[0].f_code.co_name

if not is_playwright_internal:
stack.append(
{
"file": frame.filename,
"line": frame.lineno,
"function": method_name,
}
)
if is_playwright_internal:
last_internal_api_name = method_name
elif last_internal_api_name:
api_name = last_internal_api_name
last_internal_api_name = ""
if not api_name:
api_name = last_internal_api_name
if api_name:
return {
"apiName": api_name,
"stack": stack,
"isInternal": is_internal,
}
return None
7 changes: 4 additions & 3 deletions playwright/_impl/_sync_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
import inspect
import traceback
from types import TracebackType
from typing import Any, Awaitable, Callable, Dict, Generic, List, Type, TypeVar, cast
Expand Down Expand Up @@ -74,11 +75,11 @@ def __init__(self, impl_obj: Any) -> None:
def __str__(self) -> str:
return self._impl_obj.__str__()

def _sync(self, api_name: str, coro: Awaitable) -> Any:
def _sync(self, coro: Awaitable) -> Any:
__tracebackhide__ = True
g_self = greenlet.getcurrent()
task = self._loop.create_task(coro)
setattr(task, "__pw_api_name__", api_name)
setattr(task, "__pw_stack__", inspect.stack())
setattr(task, "__pw_stack_trace__", traceback.extract_stack())

task.add_done_callback(lambda _: g_self.switch())
Expand Down Expand Up @@ -147,7 +148,7 @@ def __exit__(
self,
exc_type: Type[BaseException],
exc_val: BaseException,
traceback: TracebackType,
_traceback: TracebackType,
) -> None:
self.close()

Expand Down
41 changes: 21 additions & 20 deletions playwright/_impl/_wait_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,19 @@ def _wait_for_event_info_before(self, wait_id: str, event: str) -> None:
)

def _wait_for_event_info_after(self, wait_id: str, error: Exception = None) -> None:
try:
info = {
"waitId": wait_id,
"phase": "after",
}
if error:
info["error"] = str(error)
self._channel.send_no_reply(
self._channel._connection.wrap_api_call(
lambda: self._channel.send_no_reply(
"waitForEventInfo",
{
"info": info,
"info": {
"waitId": wait_id,
"phase": "after",
**({"error": str(error)} if error else {}),
},
},
)
except Exception:
pass
),
True,
)

def reject_on_event(
self,
Expand Down Expand Up @@ -129,15 +127,18 @@ def result(self) -> asyncio.Future:
def log(self, message: str) -> None:
self._logs.append(message)
try:
self._channel.send_no_reply(
"waitForEventInfo",
{
"info": {
"waitId": self._wait_id,
"phase": "log",
"message": message,
self._channel._connection.wrap_api_call(
lambda: self._channel.send_no_reply(
"waitForEventInfo",
{
"info": {
"waitId": self._wait_id,
"phase": "log",
"message": message,
},
},
},
),
True,
)
except Exception:
pass
Expand Down
Loading