|
19 | 19 | import types
|
20 | 20 | import warnings
|
21 | 21 | from dis import COMPILER_FLAG_NAMES
|
22 |
| -from typing import Any, Callable, ContextManager, Iterable |
| 22 | +from pathlib import Path |
| 23 | +from typing import Any, Callable, ContextManager, Iterable, Sequence |
23 | 24 |
|
24 | 25 | from prompt_toolkit.formatted_text import OneStyleAndTextTuple
|
25 | 26 | from prompt_toolkit.patch_stdout import patch_stdout as patch_stdout_context
|
@@ -64,7 +65,7 @@ def _has_coroutine_flag(code: types.CodeType) -> bool:
|
64 | 65 |
|
65 | 66 | class PythonRepl(PythonInput):
|
66 | 67 | def __init__(self, *a, **kw) -> None:
|
67 |
| - self._startup_paths = kw.pop("startup_paths", None) |
| 68 | + self._startup_paths: Sequence[str | Path] | None = kw.pop("startup_paths", None) |
68 | 69 | super().__init__(*a, **kw)
|
69 | 70 | self._load_start_paths()
|
70 | 71 |
|
@@ -348,7 +349,7 @@ def _store_eval_result(self, result: object) -> None:
|
348 | 349 | def get_compiler_flags(self) -> int:
|
349 | 350 | return super().get_compiler_flags() | PyCF_ALLOW_TOP_LEVEL_AWAIT
|
350 | 351 |
|
351 |
| - def _compile_with_flags(self, code: str, mode: str): |
| 352 | + def _compile_with_flags(self, code: str, mode: str) -> Any: |
352 | 353 | "Compile code with the right compiler flags."
|
353 | 354 | return compile(
|
354 | 355 | code,
|
@@ -459,13 +460,13 @@ def enter_to_continue() -> None:
|
459 | 460 |
|
460 | 461 |
|
461 | 462 | def embed(
|
462 |
| - globals=None, |
463 |
| - locals=None, |
| 463 | + globals: dict[str, Any] | None = None, |
| 464 | + locals: dict[str, Any] | None = None, |
464 | 465 | configure: Callable[[PythonRepl], None] | None = None,
|
465 | 466 | vi_mode: bool = False,
|
466 | 467 | history_filename: str | None = None,
|
467 | 468 | title: str | None = None,
|
468 |
| - startup_paths=None, |
| 469 | + startup_paths: Sequence[str | Path] | None = None, |
469 | 470 | patch_stdout: bool = False,
|
470 | 471 | return_asyncio_coroutine: bool = False,
|
471 | 472 | ) -> None:
|
@@ -494,10 +495,10 @@ def embed(
|
494 | 495 |
|
495 | 496 | locals = locals or globals
|
496 | 497 |
|
497 |
| - def get_globals(): |
| 498 | + def get_globals() -> dict[str, Any]: |
498 | 499 | return globals
|
499 | 500 |
|
500 |
| - def get_locals(): |
| 501 | + def get_locals() -> dict[str, Any]: |
501 | 502 | return locals
|
502 | 503 |
|
503 | 504 | # Create REPL.
|
|
0 commit comments