|
18 | 18 | import json
|
19 | 19 | import json as json_utils
|
20 | 20 | import mimetypes
|
| 21 | +import re |
21 | 22 | from collections import defaultdict
|
22 | 23 | from pathlib import Path
|
23 | 24 | from types import SimpleNamespace
|
|
46 | 47 | )
|
47 | 48 | from playwright._impl._connection import (
|
48 | 49 | ChannelOwner,
|
| 50 | + Connection, |
49 | 51 | from_channel,
|
50 | 52 | from_nullable_channel,
|
51 | 53 | )
|
52 | 54 | from playwright._impl._errors import Error
|
53 | 55 | from playwright._impl._event_context_manager import EventContextManagerImpl
|
54 |
| -from playwright._impl._helper import async_readfile, locals_to_params |
| 56 | +from playwright._impl._helper import ( |
| 57 | + URLMatcher, |
| 58 | + WebSocketRouteHandlerCallback, |
| 59 | + async_readfile, |
| 60 | + locals_to_params, |
| 61 | +) |
| 62 | +from playwright._impl._str_utils import escape_regex_flags |
55 | 63 | from playwright._impl._waiter import Waiter
|
56 | 64 |
|
57 | 65 | if TYPE_CHECKING: # pragma: no cover
|
@@ -548,6 +556,214 @@ async def _race_with_page_close(self, future: Coroutine) -> None:
|
548 | 556 | await asyncio.gather(fut, return_exceptions=True)
|
549 | 557 |
|
550 | 558 |
|
| 559 | +class ServerWebSocketRoute: |
| 560 | + def __init__(self, ws: "WebSocketRoute"): |
| 561 | + self._ws = ws |
| 562 | + |
| 563 | + def on_message(self, handler: Callable[[Union[str, bytes]], Any]) -> None: |
| 564 | + self._ws._on_server_message = handler |
| 565 | + |
| 566 | + def on_close(self, handler: Callable[[Optional[int], Optional[str]], Any]) -> None: |
| 567 | + self._ws._on_server_close = handler |
| 568 | + |
| 569 | + def connect_to_server(self) -> None: |
| 570 | + raise NotImplementedError( |
| 571 | + "connectToServer must be called on the page-side WebSocketRoute" |
| 572 | + ) |
| 573 | + |
| 574 | + @property |
| 575 | + def url(self) -> str: |
| 576 | + return self._ws._initializer["url"] |
| 577 | + |
| 578 | + def close(self, code: int = None, reason: str = None) -> None: |
| 579 | + try: |
| 580 | + asyncio.create_task( |
| 581 | + self._ws._channel.send( |
| 582 | + "close", |
| 583 | + { |
| 584 | + "code": code, |
| 585 | + "reason": reason, |
| 586 | + }, |
| 587 | + ) |
| 588 | + ) |
| 589 | + except: |
| 590 | + pass |
| 591 | + |
| 592 | + def send(self, message: Union[str, bytes]) -> None: |
| 593 | + if isinstance(message, str): |
| 594 | + asyncio.create_task( |
| 595 | + self._ws._channel.send( |
| 596 | + "sendToServer", {"message": message, "isBase64": False} |
| 597 | + ) |
| 598 | + ) |
| 599 | + else: |
| 600 | + asyncio.create_task( |
| 601 | + self._ws._channel.send( |
| 602 | + "sendToServer", |
| 603 | + {"message": base64.b64encode(message).decode(), "isBase64": True}, |
| 604 | + ) |
| 605 | + ) |
| 606 | + |
| 607 | + |
| 608 | +class WebSocketRoute(ChannelOwner): |
| 609 | + def __init__( |
| 610 | + self, parent: ChannelOwner, type: str, guid: str, initializer: Dict |
| 611 | + ) -> None: |
| 612 | + super().__init__(parent, type, guid, initializer) |
| 613 | + self._on_page_message: Optional[Callable[[Union[str, bytes]], Any]] = None |
| 614 | + self._on_page_close: Optional[ |
| 615 | + Callable[[Optional[int], Optional[str]], Any] |
| 616 | + ] = None |
| 617 | + self._on_server_message: Optional[Callable[[Union[str, bytes]], Any]] = None |
| 618 | + self._on_server_close: Optional[ |
| 619 | + Callable[[Optional[int], Optional[str]], Any] |
| 620 | + ] = None |
| 621 | + self._server = ServerWebSocketRoute(self) |
| 622 | + self._connected = False |
| 623 | + |
| 624 | + self._channel.on("messageFromPage", self._channel_message_from_page) |
| 625 | + self._channel.on("messageFromServer", self._channel_message_from_server) |
| 626 | + self._channel.on("closePage", self._channel_close_page) |
| 627 | + self._channel.on("closeServer", self._channel_close_server) |
| 628 | + |
| 629 | + def _channel_message_from_page(self, event: Dict) -> None: |
| 630 | + if self._on_page_message: |
| 631 | + self._on_page_message( |
| 632 | + base64.b64decode(event["message"]) |
| 633 | + if event["isBase64"] |
| 634 | + else event["message"] |
| 635 | + ) |
| 636 | + elif self._connected: |
| 637 | + try: |
| 638 | + asyncio.create_task(self._channel.send("sendToServer", event)) |
| 639 | + except: |
| 640 | + pass |
| 641 | + |
| 642 | + def _channel_message_from_server(self, event: Dict) -> None: |
| 643 | + if self._on_server_message: |
| 644 | + self._on_server_message( |
| 645 | + base64.b64decode(event["message"]) |
| 646 | + if event["isBase64"] |
| 647 | + else event["message"] |
| 648 | + ) |
| 649 | + else: |
| 650 | + try: |
| 651 | + asyncio.create_task(self._channel.send("sendToPage", event)) |
| 652 | + except: |
| 653 | + pass |
| 654 | + |
| 655 | + def _channel_close_page(self, event: Dict) -> None: |
| 656 | + if self._on_page_close: |
| 657 | + self._on_page_close(event["code"], event["reason"]) |
| 658 | + else: |
| 659 | + try: |
| 660 | + asyncio.create_task(self._channel.send("closeServer", event)) |
| 661 | + except: |
| 662 | + pass |
| 663 | + |
| 664 | + def _channel_close_server(self, event: Dict) -> None: |
| 665 | + if self._on_server_close: |
| 666 | + self._on_server_close(event["code"], event["reason"]) |
| 667 | + else: |
| 668 | + try: |
| 669 | + asyncio.create_task(self._channel.send("closePage", event)) |
| 670 | + except: |
| 671 | + pass |
| 672 | + |
| 673 | + @property |
| 674 | + def url(self) -> str: |
| 675 | + return self._initializer["url"] |
| 676 | + |
| 677 | + async def close(self, code: int = None, reason: str = None) -> None: |
| 678 | + try: |
| 679 | + await self._channel.send( |
| 680 | + "closePage", {"code": code, "reason": reason, "wasClean": True} |
| 681 | + ) |
| 682 | + except: |
| 683 | + pass |
| 684 | + |
| 685 | + def connect_to_server(self) -> "WebSocketRoute": |
| 686 | + if self._connected: |
| 687 | + raise Error("Already connected to the server") |
| 688 | + self._connected = True |
| 689 | + asyncio.create_task(self._channel.send("connect")) |
| 690 | + return cast("WebSocketRoute", self._server) |
| 691 | + |
| 692 | + def send(self, message: Union[str, bytes]) -> None: |
| 693 | + if isinstance(message, str): |
| 694 | + try: |
| 695 | + asyncio.create_task( |
| 696 | + self._channel.send( |
| 697 | + "sendToPage", {"message": message, "isBase64": False} |
| 698 | + ) |
| 699 | + ) |
| 700 | + except: |
| 701 | + pass |
| 702 | + else: |
| 703 | + try: |
| 704 | + asyncio.create_task( |
| 705 | + self._channel.send( |
| 706 | + "sendToPage", |
| 707 | + { |
| 708 | + "message": base64.b64encode(message).decode(), |
| 709 | + "isBase64": True, |
| 710 | + }, |
| 711 | + ) |
| 712 | + ) |
| 713 | + except: |
| 714 | + pass |
| 715 | + |
| 716 | + def on_message(self, handler: Callable[[Union[str, bytes]], Any]) -> None: |
| 717 | + self._on_page_message = handler |
| 718 | + |
| 719 | + def on_close(self, handler: Callable[[Optional[int], Optional[str]], Any]) -> None: |
| 720 | + self._on_page_close = handler |
| 721 | + |
| 722 | + async def _after_handle(self) -> None: |
| 723 | + if self._connected: |
| 724 | + return |
| 725 | + # Ensure that websocket is "open" and can send messages without an actual server connection. |
| 726 | + await self._channel.send("ensureOpened") |
| 727 | + |
| 728 | + |
| 729 | +class WebSocketRouteHandler: |
| 730 | + def __init__(self, matcher: URLMatcher, handler: WebSocketRouteHandlerCallback): |
| 731 | + self.matcher = matcher |
| 732 | + self.handler = handler |
| 733 | + |
| 734 | + @staticmethod |
| 735 | + def prepare_interception_patterns( |
| 736 | + handlers: List["WebSocketRouteHandler"], |
| 737 | + ) -> List[dict]: |
| 738 | + patterns = [] |
| 739 | + all_urls = False |
| 740 | + for handler in handlers: |
| 741 | + if isinstance(handler.matcher.match, str): |
| 742 | + patterns.append({"glob": handler.matcher.match}) |
| 743 | + elif isinstance(handler.matcher._regex_obj, re.Pattern): |
| 744 | + patterns.append( |
| 745 | + { |
| 746 | + "regexSource": handler.matcher._regex_obj.pattern, |
| 747 | + "regexFlags": escape_regex_flags(handler.matcher._regex_obj), |
| 748 | + } |
| 749 | + ) |
| 750 | + else: |
| 751 | + all_urls = True |
| 752 | + |
| 753 | + if all_urls: |
| 754 | + return [{"glob": "**/*"}] |
| 755 | + return patterns |
| 756 | + |
| 757 | + def matches(self, ws_url: str) -> bool: |
| 758 | + return self.matcher.matches(ws_url) |
| 759 | + |
| 760 | + async def handle(self, websocket_route: "WebSocketRoute") -> None: |
| 761 | + maybe_future = self.handler(websocket_route) |
| 762 | + if maybe_future: |
| 763 | + breakpoint() |
| 764 | + await websocket_route._after_handle() |
| 765 | + |
| 766 | + |
551 | 767 | class Response(ChannelOwner):
|
552 | 768 | def __init__(
|
553 | 769 | self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
|
|
0 commit comments