|
| 1 | +from collections.abc import Awaitable |
| 2 | +import re |
| 3 | +from typing import ( |
| 4 | + TYPE_CHECKING, |
| 5 | + Any, |
| 6 | + Callable, |
| 7 | + Generic, |
| 8 | + Optional, |
| 9 | + TypeVar, |
| 10 | + Union, |
| 11 | + cast, |
| 12 | + overload, |
| 13 | +) |
| 14 | +from typing_extensions import ParamSpec, Self |
| 15 | + |
| 16 | +import httpx |
| 17 | + |
| 18 | +from githubkit.response import Response |
| 19 | +from githubkit.utils import is_async |
| 20 | + |
| 21 | +if TYPE_CHECKING: |
| 22 | + from githubkit.versions import RestVersionSwitcher |
| 23 | + |
# Parameters (positional and keyword) accepted by the wrapped request callable.
CP = ParamSpec("CP")
# Parsed-data type of the response returned by the request callable,
# i.e. what ``map_func`` receives wrapped in a ``Response``.
CT = TypeVar("CT")
# Type of the items produced by the paginator.
RT = TypeVar("RT")
# Item type bound through the ``self: "Paginator[RTS]"`` annotation of the
# ``Paginator.__init__`` overloads.
RTS = TypeVar("RTS")

# A request callable taking ``CP`` parameters and returning a ``Response[RT]``,
# either directly (sync) or through an awaitable (async).
R = Union[
    Callable[CP, Response[RT]],
    Callable[CP, Awaitable[Response[RT]]],
]
| 33 | + |
# https://github.com/octokit/plugin-paginate-rest.js/blob/1f44b5469b31ddec9621000e6e1aee63c71ea8bf/src/iterator.ts#L40
# Regular expression locating the ``rel="next"`` entry of a ``link`` header
# value; capture group 1 is the URL between the angle brackets.
NEXT_LINK_PATTERN = r'<([^<>]+)>;\s*rel="next"'
| 36 | + |
| 37 | + |
| 38 | +# https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api |
| 39 | +# https://github.com/octokit/plugin-paginate-rest.js/blob/1f44b5469b31ddec9621000e6e1aee63c71ea8bf/src/iterator.ts |
class Paginator(Generic[RT]):
    """Paginate through the responses of the rest api request.

    Iterating the paginator yields the items of every page, page by page.
    The first page is obtained by calling ``request(*args, **kwargs)``. Each
    following page is requested from the URL found in the ``rel="next"``
    entry of the previous response's ``link`` header, re-using the HTTP
    method of the first request, the ``headers`` keyword argument (if one
    was given) and the response model recorded from the first response.

    Use ``for`` / ``next()`` when ``request`` is a sync callable and
    ``async for`` when it is an async one; ``__iter__`` and ``__aiter__``
    raise ``TypeError`` on a mismatch.

    Args:
        rest: Version switcher whose ``_github`` client performs the
            requests for the pages after the first one.
        request: Callable performing the first request.
        map_func: Optional function turning a response into the list of
            items of that page. When omitted, the response's
            ``parsed_data`` is used and must itself be a list.
        *args: Positional arguments forwarded to ``request``.
        **kwargs: Keyword arguments forwarded to ``request``.
    """

    @overload
    def __init__(
        self: "Paginator[RTS]",
        rest: "RestVersionSwitcher",
        request: R[CP, list[RTS]],
        map_func: None = None,
        *args: CP.args,
        **kwargs: CP.kwargs,
    ): ...

    @overload
    def __init__(
        self: "Paginator[RTS]",
        rest: "RestVersionSwitcher",
        request: R[CP, CT],
        map_func: Callable[[Response[CT]], list[RTS]],
        *args: CP.args,
        **kwargs: CP.kwargs,
    ): ...

    def __init__(
        self,
        rest: "RestVersionSwitcher",
        request: R[CP, CT],
        map_func: Optional[Callable[[Response[CT]], list[RT]]] = None,
        *args: CP.args,
        **kwargs: CP.kwargs,
    ):
        self.rest = rest

        self.request = request
        self.args = args
        self.kwargs = kwargs

        self.map_func = map_func

        # Facts learned from the first response, needed to request the
        # following pages.
        self._initialized: bool = False
        self._request_method: Optional[str] = None
        self._response_model: Optional[Any] = None
        self._next_link: Optional[httpx.URL] = None

        # Items of the current page and the position of the next item to yield.
        self._index: int = 0
        self._cached_data: list[RT] = []

    @property
    def finalized(self) -> bool:
        """Whether the paginator is finalized or not.

        True once at least one page has been fetched and the latest
        response carried no ``rel="next"`` link. Items of that last page may
        still be waiting in the cache.
        """
        return self._initialized and self._next_link is None

    def reset(self) -> None:
        """Reset the paginator to the initial state.

        The next iteration step performs the first request again.
        """
        self._initialized = False
        self._request_method = None
        self._response_model = None
        self._next_link = None
        self._index = 0
        self._cached_data = []

    def __next__(self) -> RT:
        """Return the next item, fetching further pages when needed.

        Raises:
            StopIteration: When the cached page is exhausted and there is no
                further page to fetch.
        """
        while self._index >= len(self._cached_data):
            # Check *before* fetching: when the page just fetched is the last
            # one, its items must still be yielded before iteration stops.
            if self.finalized:
                raise StopIteration
            self._get_next_page()
        return self._take_cached_item()

    def __iter__(self: Self) -> Self:
        """Return the paginator itself; only valid for sync request callables.

        Raises:
            TypeError: If ``request`` is an async callable.
        """
        if is_async(self.request):
            raise TypeError(f"Request method {self.request} is not an sync function")
        return self

    async def __anext__(self) -> RT:
        """Return the next item, awaiting further pages when needed.

        Raises:
            StopAsyncIteration: When the cached page is exhausted and there
                is no further page to fetch.
        """
        while self._index >= len(self._cached_data):
            # Same ordering as in ``__next__``: never drop the last page.
            if self.finalized:
                raise StopAsyncIteration
            await self._aget_next_page()
        return self._take_cached_item()

    def __aiter__(self: Self) -> Self:
        """Return the paginator itself; only valid for async request callables.

        Raises:
            TypeError: If ``request`` is not an async callable.
        """
        if not is_async(self.request):
            raise TypeError(f"Request method {self.request} is not an async function")
        return self

    def _take_cached_item(self) -> RT:
        """Return the cached item at the current index and advance the index."""
        current = self._cached_data[self._index]
        self._index += 1
        return current

    def _find_next_link(self, response: Response[Any]) -> Optional[httpx.URL]:
        """Find the next link in the response headers.

        Returns:
            The URL of the ``rel="next"`` entry of the ``link`` header, or
            ``None`` when the header is missing or has no such entry.
        """
        if links := response.headers.get("link"):
            if match := re.search(NEXT_LINK_PATTERN, links):
                return httpx.URL(match.group(1))
        return None

    def _apply_map_func(self, response: Response[Any]) -> list[RT]:
        """Extract the list of items of one page from its response.

        Raises:
            TypeError: If ``map_func`` (when given) or the response's
                ``parsed_data`` (otherwise) is not a list.
        """
        if self.map_func is not None:
            result = self.map_func(response)
            if not isinstance(result, list):
                raise TypeError(f"Map function must return a list, got {type(result)}")
            return result

        result = cast(Response[list[RT]], response).parsed_data
        if not isinstance(result, list):
            raise TypeError(f"Response is not a list, got {type(result)}")
        return result

    def _fill_cache_data(self, data: list[RT]) -> None:
        """Fill the cache with the data and restart reading from its start."""
        self._cached_data = data
        self._index = 0

    def _record_first_response(self, response: Response[Any]) -> None:
        """Remember what is needed from the first response to request more pages."""
        self._initialized = True
        self._request_method = response.raw_request.method
        # Without a recorded model every follow-up request is rejected by
        # ``_next_request_params``.
        # NOTE(review): assumes ``Response`` keeps the model it was parsed
        # with in ``_data_model`` - confirm against ``githubkit.response``.
        self._response_model = getattr(response, "_data_model", None)

    def _next_request_params(self) -> tuple[str, httpx.URL, Any]:
        """Return ``(method, url, response_model)`` for the next page request.

        Raises:
            RuntimeError: If there is no next link, or the request method or
                response model of the first request was not recorded.
        """
        if self._next_link is None:
            raise RuntimeError("Paginator is finalized, no more pages to fetch.")
        if self._request_method is None:
            raise RuntimeError("Request method is not set, this should not happen.")
        if self._response_model is None:
            raise RuntimeError("Response model is not set, this should not happen.")
        return self._request_method, self._next_link, self._response_model

    def _store_page(self, response: Response[Any]) -> None:
        """Record the next link of ``response`` and cache its items."""
        self._next_link = self._find_next_link(response)
        self._fill_cache_data(self._apply_map_func(response))

    def _get_next_page(self) -> None:
        """Fetch the next page synchronously and cache its items.

        Raises:
            RuntimeError: If the paginator is finalized or the information
                recorded from the first request is missing.
            TypeError: If the page cannot be turned into a list of items.
        """
        if not self._initialized:
            # First request
            response = cast(
                Response[Any],
                self.request(*self.args, **self.kwargs),
            )
            self._record_first_response(response)
        else:
            # Next request
            method, url, response_model = self._next_request_params()

            # we request the next page with the same method and response model
            response = cast(
                Response[Any],
                self.rest._github.request(
                    method,
                    url,
                    headers=self.kwargs.get("headers"),  # type: ignore
                    response_model=response_model,  # type: ignore
                ),
            )

        self._store_page(response)

    async def _aget_next_page(self) -> None:
        """Fetch the next page asynchronously and cache its items.

        Raises:
            RuntimeError: If the paginator is finalized or the information
                recorded from the first request is missing.
            TypeError: If the page cannot be turned into a list of items.
        """
        if not self._initialized:
            # First request
            response = cast(
                Response[Any],
                await self.request(*self.args, **self.kwargs),  # type: ignore
            )
            self._record_first_response(response)
        else:
            # Next request
            method, url, response_model = self._next_request_params()

            # NOTE(review): the sync path calls ``request`` without awaiting
            # it, so the async path uses the client's ``arequest``
            # counterpart here - confirm against the client API.
            response = cast(
                Response[Any],
                await self.rest._github.arequest(
                    method,
                    url,
                    headers=self.kwargs.get("headers"),  # type: ignore
                    response_model=response_model,  # type: ignore
                ),
            )

        self._store_page(response)
0 commit comments