diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 000000000..7e6cba08e --- /dev/null +++ b/.editorconfig @@ -0,0 +1,9 @@ +root = true + +[*] +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true + +[*.{py,pyi,pxd,pyx}] +ij_visual_guides = 80,88 diff --git a/docs/api/index.rst b/docs/api/index.rst index c6b4cfa85..7258f7de0 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -2,10 +2,11 @@ API Documentation ================= .. toctree:: - :maxdepth: 2 + :maxdepth: 2 top-level providers containers wiring errors + asgi-lifespan diff --git a/docs/conf.py b/docs/conf.py index 380da2da5..4de57da77 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -72,7 +72,7 @@ # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = None +language = "en" # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: diff --git a/docs/examples/django.rst b/docs/examples/django.rst index 08e6e7573..1a5b781fd 100644 --- a/docs/examples/django.rst +++ b/docs/examples/django.rst @@ -78,7 +78,7 @@ Container is wired to the ``views`` module in the app config ``web/apps.py``: .. literalinclude:: ../../examples/miniapps/django/web/apps.py :language: python - :emphasize-lines: 13 + :emphasize-lines: 12 Tests ----- diff --git a/docs/introduction/key_features.rst b/docs/introduction/key_features.rst index 0870ac11e..1975e8fcd 100644 --- a/docs/introduction/key_features.rst +++ b/docs/introduction/key_features.rst @@ -31,7 +31,7 @@ Key features of the ``Dependency Injector``: The framework stands on the `PEP20 (The Zen of Python) `_ principle: -.. code-block:: plain +.. code-block:: text Explicit is better than implicit diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index 06acc716e..05bb9e9ac 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -14,8 +14,8 @@ follows `Semantic versioning`_ with updated documentation and examples. See discussion: https://github.com/ets-labs/python-dependency-injector/pull/721#issuecomment-2025263718 -- Fix ``root`` property shadowing in ``ConfigurationOption`` (`#875 https://github.com/ets-labs/python-dependency-injector/pull/875`_) -- Fix incorrect monkeypatching during ``wire()`` that could violate MRO in some classes (`#886 https://github.com/ets-labs/python-dependency-injector/pull/886`_) +- Fix ``root`` property shadowing in ``ConfigurationOption`` (`#875 `_) +- Fix incorrect monkeypatching during ``wire()`` that could violate MRO in some classes (`#886 `_) - ABI3 wheels are now published for CPython. - Drop support of Python 3.7. @@ -371,8 +371,8 @@ Many thanks to `ZipFile `_ for both contributions. - Make refactoring of wiring module and tests. See PR # `#406 `_. Thanks to `@withshubh `_ for the contribution: - - Remove unused imports in tests. - - Use literal syntax to create data structure in tests. + - Remove unused imports in tests. + - Use literal syntax to create data structure in tests. - Add integration with a static analysis tool `DeepSource `_. 4.26.0 diff --git a/docs/providers/resource.rst b/docs/providers/resource.rst index 918dfa664..fda5b3d7a 100644 --- a/docs/providers/resource.rst +++ b/docs/providers/resource.rst @@ -61,11 +61,12 @@ When you call ``.shutdown()`` method on a resource provider, it will remove the if any, and switch to uninitialized state. Some of resource initializer types support specifying custom resource shutdown. -Resource provider supports 3 types of initializers: +Resource provider supports 4 types of initializers: - Function -- Generator -- Subclass of ``resources.Resource`` +- Context Manager +- Generator (legacy) +- Subclass of ``resources.Resource`` (legacy) Function initializer -------------------- @@ -103,8 +104,44 @@ you configure global resource: Function initializer does not provide a way to specify custom resource shutdown. -Generator initializer ---------------------- +Context Manager initializer +--------------------------- + +This is an extension to the Function initializer. Resource provider automatically detects if the initializer returns a +context manager and uses it to manage the resource lifecycle. + +.. code-block:: python + + from dependency_injector import containers, providers + + class DatabaseConnection: + def __init__(self, host, port, user, password): + self.host = host + self.port = port + self.user = user + self.password = password + + def __enter__(self): + print(f"Connecting to {self.host}:{self.port} as {self.user}") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + print("Closing connection") + + + class Container(containers.DeclarativeContainer): + + config = providers.Configuration() + db = providers.Resource( + DatabaseConnection, + host=config.db.host, + port=config.db.port, + user=config.db.user, + password=config.db.password, + ) + +Generator initializer (legacy) +------------------------------ Resource provider can use 2-step generators: @@ -154,8 +191,13 @@ object is not mandatory. You can leave ``yield`` statement empty: argument2=..., ) -Subclass initializer --------------------- +.. note:: + + Generator initializers are automatically wrapped with ``contextmanager`` or ``asynccontextmanager`` decorator when + provided to a ``Resource`` provider. + +Subclass initializer (legacy) +----------------------------- You can create resource initializer by implementing a subclass of the ``resources.Resource``: @@ -263,10 +305,11 @@ Asynchronous function initializer: argument2=..., ) -Asynchronous generator initializer: +Asynchronous Context Manager initializer: .. code-block:: python + @asynccontextmanager async def init_async_resource(argument1=..., argument2=...): connection = await connect() yield connection @@ -358,5 +401,54 @@ See also: - Wiring :ref:`async-injections-wiring` - :ref:`fastapi-redis-example` +ASGI Lifespan Protocol Support +------------------------------ + +The :mod:`dependency_injector.ext.starlette` module provides a :class:`~dependency_injector.ext.starlette.Lifespan` +class that integrates resource providers with ASGI applications using the `Lifespan Protocol`_. This allows resources to +be automatically initialized at application startup and properly shut down when the application stops. + +.. code-block:: python + + from contextlib import asynccontextmanager + from dependency_injector import containers, providers + from dependency_injector.wiring import Provide, inject + from dependency_injector.ext.starlette import Lifespan + from fastapi import FastAPI, Request, Depends, APIRouter + + class Connection: ... + + @asynccontextmanager + async def init_database(): + print("opening database connection") + yield Connection() + print("closing database connection") + + router = APIRouter() + + @router.get("/") + @inject + async def index(request: Request, db: Connection = Depends(Provide["db"])): + # use the database connection here + return "OK!" + + class Container(containers.DeclarativeContainer): + __self__ = providers.Self() + db = providers.Resource(init_database) + lifespan = providers.Singleton(Lifespan, __self__) + app = providers.Singleton(FastAPI, lifespan=lifespan) + _include_router = providers.Resource( + app.provided.include_router.call(), + router, + ) + + if __name__ == "__main__": + import uvicorn + + container = Container() + app = container.app() + uvicorn.run(app, host="localhost", port=8000) + +.. _Lifespan Protocol: https://asgi.readthedocs.io/en/latest/specs/lifespan.html .. disqus:: diff --git a/docs/tutorials/aiohttp.rst b/docs/tutorials/aiohttp.rst index 57b1c9599..10b33b495 100644 --- a/docs/tutorials/aiohttp.rst +++ b/docs/tutorials/aiohttp.rst @@ -257,7 +257,7 @@ Let's check that it works. Open another terminal session and use ``httpie``: You should see: -.. code-block:: json +.. code-block:: http HTTP/1.1 200 OK Content-Length: 844 @@ -596,7 +596,7 @@ and make a request to the API in the terminal: You should see: -.. code-block:: json +.. code-block:: http HTTP/1.1 200 OK Content-Length: 492 diff --git a/docs/tutorials/cli.rst b/docs/tutorials/cli.rst index 88014ff34..ea3c84675 100644 --- a/docs/tutorials/cli.rst +++ b/docs/tutorials/cli.rst @@ -84,7 +84,7 @@ Create next structure in the project root directory. All files are empty. That's Initial project layout: -.. code-block:: bash +.. code-block:: text ./ ├── movies/ @@ -109,7 +109,7 @@ Now it's time to install the project requirements. We will use next packages: Put next lines into the ``requirements.txt`` file: -.. code-block:: bash +.. code-block:: text dependency-injector pyyaml @@ -134,7 +134,7 @@ We will create a script that creates database files. First add the folder ``data/`` in the root of the project and then add the file ``fixtures.py`` inside of it: -.. code-block:: bash +.. code-block:: text :emphasize-lines: 2-3 ./ @@ -205,13 +205,13 @@ Now run in the terminal: You should see: -.. code-block:: bash +.. code-block:: text OK Check that files ``movies.csv`` and ``movies.db`` have appeared in the ``data/`` folder: -.. code-block:: bash +.. code-block:: text :emphasize-lines: 4-5 ./ @@ -289,7 +289,7 @@ After each step we will add the provider to the container. Create the ``entities.py`` in the ``movies`` package: -.. code-block:: bash +.. code-block:: text :emphasize-lines: 10 ./ @@ -356,7 +356,7 @@ Let's move on to the finders. Create the ``finders.py`` in the ``movies`` package: -.. code-block:: bash +.. code-block:: text :emphasize-lines: 11 ./ @@ -465,7 +465,7 @@ The configuration file is ready. Move on to the lister. Create the ``listers.py`` in the ``movies`` package: -.. code-block:: bash +.. code-block:: text :emphasize-lines: 12 ./ @@ -613,7 +613,7 @@ Run in the terminal: You should see: -.. code-block:: plain +.. code-block:: text Francis Lawrence movies: - Movie(title='The Hunger Games: Mockingjay - Part 2', year=2015, director='Francis Lawrence') @@ -752,7 +752,7 @@ Run in the terminal: You should see: -.. code-block:: plain +.. code-block:: text Francis Lawrence movies: - Movie(title='The Hunger Games: Mockingjay - Part 2', year=2015, director='Francis Lawrence') @@ -868,7 +868,7 @@ Run in the terminal line by line: The output should be similar for each command: -.. code-block:: plain +.. code-block:: text Francis Lawrence movies: - Movie(title='The Hunger Games: Mockingjay - Part 2', year=2015, director='Francis Lawrence') @@ -888,7 +888,7 @@ We will use `pytest `_ and Create ``tests.py`` in the ``movies`` package: -.. code-block:: bash +.. code-block:: text :emphasize-lines: 13 ./ @@ -977,7 +977,7 @@ Run in the terminal: You should see: -.. code-block:: +.. code-block:: text platform darwin -- Python 3.10.0, pytest-6.2.5, py-1.10.0, pluggy-1.0.0 plugins: cov-3.0.0 diff --git a/docs/tutorials/flask.rst b/docs/tutorials/flask.rst index 8c22aa5a0..b8f81adf3 100644 --- a/docs/tutorials/flask.rst +++ b/docs/tutorials/flask.rst @@ -280,7 +280,7 @@ Now let's fill in the layout. Put next into the ``base.html``: -.. code-block:: html +.. code-block:: jinja @@ -313,7 +313,7 @@ And put something to the index page. Put next into the ``index.html``: -.. code-block:: html +.. code-block:: jinja {% extends "base.html" %} diff --git a/docs/wiring.rst b/docs/wiring.rst index 02f64c60d..912b320e4 100644 --- a/docs/wiring.rst +++ b/docs/wiring.rst @@ -127,6 +127,7 @@ To inject the provider itself use ``Provide[foo.provider]``: def foo(bar_provider: Factory[Bar] = Provide[Container.bar.provider]): bar = bar_provider(argument="baz") ... + You can also use ``Provider[foo]`` for injecting the provider itself: .. code-block:: python @@ -631,6 +632,36 @@ or with a single container ``register_loader_containers(container)`` multiple ti To unregister a container use ``unregister_loader_containers(container)``. Wiring module will uninstall the import hook when unregister last container. +Few notes on performance +------------------------ + +``.wire()`` utilize caching to speed up the wiring process. At the end it clears the cache to avoid memory leaks. +But this may not always be desirable, when you want to keep the cache for the next wiring +(e.g. due to usage of multiple containers or during unit tests). + +To keep the cache after wiring, you can set flag ``keep_cache=True`` (works with ``WiringConfiguration`` too): + +.. code-block:: python + + container1.wire( + modules=["yourapp.module1", "yourapp.module2"], + keep_cache=True, + ) + container2.wire( + modules=["yourapp.module2", "yourapp.module3"], + keep_cache=True, + ) + ... + +and then clear it manually when you need it: + +.. code-block:: python + + from dependency_injector.wiring import clear_cache + + clear_cache() + + Integration with other frameworks --------------------------------- diff --git a/examples/miniapps/movie-lister/data/fixtures.py b/examples/miniapps/movie-lister/data/fixtures.py index aa1691d5a..0153e0cf7 100644 --- a/examples/miniapps/movie-lister/data/fixtures.py +++ b/examples/miniapps/movie-lister/data/fixtures.py @@ -18,10 +18,9 @@ def create_csv(movies_data, path): - with open(path, "w") as opened_file: + with open(path, "w", newline="") as opened_file: writer = csv.writer(opened_file) - for row in movies_data: - writer.writerow(row) + writer.writerows(movies_data) def create_sqlite(movies_data, path): diff --git a/examples/miniapps/movie-lister/movies/finders.py b/examples/miniapps/movie-lister/movies/finders.py index 52b8ed555..5e6d2c9c0 100644 --- a/examples/miniapps/movie-lister/movies/finders.py +++ b/examples/miniapps/movie-lister/movies/finders.py @@ -29,7 +29,7 @@ def __init__( super().__init__(movie_factory) def find_all(self) -> List[Movie]: - with open(self._csv_file_path) as csv_file: + with open(self._csv_file_path, newline="") as csv_file: csv_reader = csv.reader(csv_file, delimiter=self._delimiter) return [self._movie_factory(*row) for row in csv_reader] diff --git a/examples/providers/resource.py b/examples/providers/resource.py index 2079a929c..c712468a8 100644 --- a/examples/providers/resource.py +++ b/examples/providers/resource.py @@ -3,10 +3,12 @@ import sys import logging from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager from dependency_injector import containers, providers +@contextmanager def init_thread_pool(max_workers: int): thread_pool = ThreadPoolExecutor(max_workers=max_workers) yield thread_pool diff --git a/pyproject.toml b/pyproject.toml index 7512cb94c..885531786 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ show_missing = true [tool.isort] profile = "black" +combine_as_imports = true [tool.pylint.main] ignore = ["tests"] diff --git a/src/dependency_injector/__init__.pyi b/src/dependency_injector/__init__.pyi deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/dependency_injector/_cwiring.pyi b/src/dependency_injector/_cwiring.pyi index e7ff12f4e..c779b8c45 100644 --- a/src/dependency_injector/_cwiring.pyi +++ b/src/dependency_injector/_cwiring.pyi @@ -1,23 +1,18 @@ -from typing import Any, Awaitable, Callable, Dict, Tuple, TypeVar +from typing import Any, Dict from .providers import Provider -T = TypeVar("T") +class DependencyResolver: + def __init__( + self, + kwargs: Dict[str, Any], + injections: Dict[str, Provider[Any]], + closings: Dict[str, Provider[Any]], + /, + ) -> None: ... + def __enter__(self) -> Dict[str, Any]: ... + def __exit__(self, *exc_info: Any) -> None: ... + async def __aenter__(self) -> Dict[str, Any]: ... + async def __aexit__(self, *exc_info: Any) -> None: ... -def _sync_inject( - fn: Callable[..., T], - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - injections: Dict[str, Provider[Any]], - closings: Dict[str, Provider[Any]], - /, -) -> T: ... -async def _async_inject( - fn: Callable[..., Awaitable[T]], - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - injections: Dict[str, Provider[Any]], - closings: Dict[str, Provider[Any]], - /, -) -> T: ... def _isawaitable(instance: Any) -> bool: ... diff --git a/src/dependency_injector/_cwiring.pyx b/src/dependency_injector/_cwiring.pyx index 84a5485f8..3e2775c7c 100644 --- a/src/dependency_injector/_cwiring.pyx +++ b/src/dependency_injector/_cwiring.pyx @@ -1,83 +1,110 @@ """Wiring optimizations module.""" -import asyncio -import collections.abc -import inspect -import types +from asyncio import gather +from collections.abc import Awaitable +from inspect import CO_ITERABLE_COROUTINE +from types import CoroutineType, GeneratorType +from .providers cimport Provider, Resource, NULL_AWAITABLE from .wiring import _Marker -from .providers cimport Provider, Resource +cimport cython -def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): - cdef object result +@cython.internal +@cython.no_gc +cdef class KWPair: + cdef str name + cdef object value + + def __cinit__(self, str name, object value, /): + self.name = name + self.value = value + + +cdef inline bint _is_injectable(dict kwargs, str name): + return name not in kwargs or isinstance(kwargs[name], _Marker) + + +cdef class DependencyResolver: + cdef dict kwargs cdef dict to_inject - cdef object arg_key - cdef Provider provider + cdef dict injections + cdef dict closings - to_inject = kwargs.copy() - for arg_key, provider in injections.items(): - if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): - to_inject[arg_key] = provider() + def __init__(self, dict kwargs, dict injections, dict closings, /): + self.kwargs = kwargs + self.to_inject = kwargs.copy() + self.injections = injections + self.closings = closings - result = fn(*args, **to_inject) + async def _await_injection(self, kw_pair: KWPair, /) -> None: + self.to_inject[kw_pair.name] = await kw_pair.value - if closings: - for arg_key, provider in closings.items(): - if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker): - continue - if not isinstance(provider, Resource): - continue - provider.shutdown() + cdef object _await_injections(self, to_await: list): + return gather(*map(self._await_injection, to_await)) - return result + cdef void _handle_injections_sync(self): + cdef Provider provider + for name, provider in self.injections.items(): + if _is_injectable(self.kwargs, name): + self.to_inject[name] = provider() -async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): - cdef object result - cdef dict to_inject - cdef list to_inject_await = [] - cdef list to_close_await = [] - cdef object arg_key - cdef Provider provider - - to_inject = kwargs.copy() - for arg_key, provider in injections.items(): - if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): - provide = provider() - if provider.is_async_mode_enabled(): - to_inject_await.append((arg_key, provide)) - elif _isawaitable(provide): - to_inject_await.append((arg_key, provide)) - else: - to_inject[arg_key] = provide - - if to_inject_await: - async_to_inject = await asyncio.gather(*(provide for _, provide in to_inject_await)) - for provide, (injection, _) in zip(async_to_inject, to_inject_await): - to_inject[injection] = provide - - result = await fn(*args, **to_inject) - - if closings: - for arg_key, provider in closings.items(): - if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker): - continue - if not isinstance(provider, Resource): - continue - shutdown = provider.shutdown() - if _isawaitable(shutdown): - to_close_await.append(shutdown) - - await asyncio.gather(*to_close_await) - - return result + cdef list _handle_injections_async(self): + cdef list to_await = [] + cdef Provider provider + + for name, provider in self.injections.items(): + if _is_injectable(self.kwargs, name): + provide = provider() + + if provider.is_async_mode_enabled() or _isawaitable(provide): + to_await.append(KWPair(name, provide)) + else: + self.to_inject[name] = provide + + return to_await + + cdef void _handle_closings_sync(self): + cdef Provider provider + + for name, provider in self.closings.items(): + if _is_injectable(self.kwargs, name) and isinstance(provider, Resource): + provider.shutdown() + + cdef list _handle_closings_async(self): + cdef list to_await = [] + cdef Provider provider + + for name, provider in self.closings.items(): + if _is_injectable(self.kwargs, name) and isinstance(provider, Resource): + if _isawaitable(shutdown := provider.shutdown()): + to_await.append(shutdown) + + return to_await + + def __enter__(self): + self._handle_injections_sync() + return self.to_inject + + def __exit__(self, *_): + self._handle_closings_sync() + + async def __aenter__(self): + if to_await := self._handle_injections_async(): + await self._await_injections(to_await) + return self.to_inject + + def __aexit__(self, *_): + if to_await := self._handle_closings_async(): + return gather(*to_await) + return NULL_AWAITABLE cdef bint _isawaitable(object instance): """Return true if object can be passed to an ``await`` expression.""" - return (isinstance(instance, types.CoroutineType) or - isinstance(instance, types.GeneratorType) and - bool(instance.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) or - isinstance(instance, collections.abc.Awaitable)) + return (isinstance(instance, CoroutineType) or + isinstance(instance, GeneratorType) and + bool(instance.gi_code.co_flags & CO_ITERABLE_COROUTINE) or + isinstance(instance, Awaitable)) diff --git a/src/dependency_injector/containers.pyi b/src/dependency_injector/containers.pyi index ec41ea8e6..ca608f28d 100644 --- a/src/dependency_injector/containers.pyi +++ b/src/dependency_injector/containers.pyi @@ -1,23 +1,28 @@ from pathlib import Path from typing import ( - Generic, - Type, - Dict, - List, - Tuple, - Optional, Any, - Union, - ClassVar, + Awaitable, Callable as _Callable, + ClassVar, + Dict, + Generic, Iterable, Iterator, + List, + Optional, + Tuple, + Type, TypeVar, - Awaitable, + Union, overload, ) -from .providers import Provider, Self, ProviderParent +try: + from typing import Self as _Self +except ImportError: + from typing_extensions import Self as _Self + +from .providers import Provider, ProviderParent, Self C_Base = TypeVar("C_Base", bound="Container") C = TypeVar("C", bound="DeclarativeContainer") @@ -30,32 +35,34 @@ class WiringConfiguration: packages: List[Any] from_package: Optional[str] auto_wire: bool + keep_cache: bool def __init__( self, modules: Optional[Iterable[Any]] = None, packages: Optional[Iterable[Any]] = None, from_package: Optional[str] = None, auto_wire: bool = True, + keep_cache: bool = False, ) -> None: ... class Container: - provider_type: Type[Provider] = Provider - providers: Dict[str, Provider] + provider_type: Type[Provider[Any]] = Provider + providers: Dict[str, Provider[Any]] dependencies: Dict[str, Provider[Any]] - overridden: Tuple[Provider] + overridden: Tuple[Provider[Any], ...] wiring_config: WiringConfiguration auto_load_config: bool = True __self__: Self def __init__(self) -> None: ... - def __deepcopy__(self, memo: Optional[Dict[str, Any]]) -> Provider: ... - def __setattr__(self, name: str, value: Union[Provider, Any]) -> None: ... - def __getattr__(self, name: str) -> Provider: ... + def __deepcopy__(self, memo: Optional[Dict[str, Any]]) -> _Self: ... + def __setattr__(self, name: str, value: Union[Provider[Any], Any]) -> None: ... + def __getattr__(self, name: str) -> Provider[Any]: ... def __delattr__(self, name: str) -> None: ... - def set_providers(self, **providers: Provider): ... - def set_provider(self, name: str, provider: Provider) -> None: ... + def set_providers(self, **providers: Provider[Any]) -> None: ... + def set_provider(self, name: str, provider: Provider[Any]) -> None: ... def override(self, overriding: Union[Container, Type[Container]]) -> None: ... def override_providers( - self, **overriding_providers: Union[Provider, Any] + self, **overriding_providers: Union[Provider[Any], Any] ) -> ProvidersOverridingContext[C_Base]: ... def reset_last_overriding(self) -> None: ... def reset_override(self) -> None: ... @@ -67,8 +74,8 @@ class Container: from_package: Optional[str] = None, ) -> None: ... def unwire(self) -> None: ... - def init_resources(self) -> Optional[Awaitable]: ... - def shutdown_resources(self) -> Optional[Awaitable]: ... + def init_resources(self) -> Optional[Awaitable[None]]: ... + def shutdown_resources(self) -> Optional[Awaitable[None]]: ... def load_config(self) -> None: ... def apply_container_providers_overridings(self) -> None: ... def reset_singletons(self) -> SingletonResetContext[C_Base]: ... @@ -79,10 +86,10 @@ class Container: ) -> None: ... def from_json_schema(self, filepath: Union[Path, str]) -> None: ... @overload - def resolve_provider_name(self, provider: Provider) -> str: ... + def resolve_provider_name(self, provider: Provider[Any]) -> str: ... @classmethod @overload - def resolve_provider_name(cls, provider: Provider) -> str: ... + def resolve_provider_name(cls, provider: Provider[Any]) -> str: ... @property def parent(self) -> Optional[ProviderParent]: ... @property @@ -97,14 +104,14 @@ class Container: class DynamicContainer(Container): ... class DeclarativeContainer(Container): - cls_providers: ClassVar[Dict[str, Provider]] - inherited_providers: ClassVar[Dict[str, Provider]] - def __init__(self, **overriding_providers: Union[Provider, Any]) -> None: ... + cls_providers: ClassVar[Dict[str, Provider[Any]]] + inherited_providers: ClassVar[Dict[str, Provider[Any]]] + def __init__(self, **overriding_providers: Union[Provider[Any], Any]) -> None: ... @classmethod def override(cls, overriding: Union[Container, Type[Container]]) -> None: ... @classmethod def override_providers( - cls, **overriding_providers: Union[Provider, Any] + cls, **overriding_providers: Union[Provider[Any], Any] ) -> ProvidersOverridingContext[C_Base]: ... @classmethod def reset_last_overriding(cls) -> None: ... @@ -113,7 +120,7 @@ class DeclarativeContainer(Container): class ProvidersOverridingContext(Generic[T]): def __init__( - self, container: T, overridden_providers: Iterable[Union[Provider, Any]] + self, container: T, overridden_providers: Iterable[Union[Provider[Any], Any]] ) -> None: ... def __enter__(self) -> T: ... def __exit__(self, *_: Any) -> None: ... diff --git a/src/dependency_injector/containers.pyx b/src/dependency_injector/containers.pyx index 2f4c4af58..bd0a4821b 100644 --- a/src/dependency_injector/containers.pyx +++ b/src/dependency_injector/containers.pyx @@ -20,14 +20,15 @@ from .wiring import wire, unwire class WiringConfiguration: """Container wiring configuration.""" - def __init__(self, modules=None, packages=None, from_package=None, auto_wire=True): + def __init__(self, modules=None, packages=None, from_package=None, auto_wire=True, keep_cache=False): self.modules = [*modules] if modules else [] self.packages = [*packages] if packages else [] self.from_package = from_package self.auto_wire = auto_wire + self.keep_cache = keep_cache def __deepcopy__(self, memo=None): - return self.__class__(self.modules, self.packages, self.from_package, self.auto_wire) + return self.__class__(self.modules, self.packages, self.from_package, self.auto_wire, self.keep_cache) class Container: @@ -258,7 +259,7 @@ class DynamicContainer(Container): """Check if auto wiring is needed.""" return self.wiring_config.auto_wire is True - def wire(self, modules=None, packages=None, from_package=None): + def wire(self, modules=None, packages=None, from_package=None, keep_cache=None): """Wire container providers with provided packages and modules. :rtype: None @@ -289,10 +290,14 @@ class DynamicContainer(Container): if not modules and not packages: return + if keep_cache is None: + keep_cache = self.wiring_config.keep_cache + wire( container=self, modules=modules, packages=packages, + keep_cache=keep_cache, ) if modules: diff --git a/src/dependency_injector/ext/aiohttp.py b/src/dependency_injector/ext/aiohttp.py index 976089c32..43990a7d3 100644 --- a/src/dependency_injector/ext/aiohttp.py +++ b/src/dependency_injector/ext/aiohttp.py @@ -7,7 +7,6 @@ from dependency_injector import providers - warnings.warn( 'Module "dependency_injector.ext.aiohttp" is deprecated since ' 'version 4.0.0. Use "dependency_injector.wiring" module instead.', diff --git a/src/dependency_injector/ext/aiohttp.pyi b/src/dependency_injector/ext/aiohttp.pyi index 370cc9b00..c524712c5 100644 --- a/src/dependency_injector/ext/aiohttp.pyi +++ b/src/dependency_injector/ext/aiohttp.pyi @@ -1,14 +1,16 @@ -from typing import Awaitable as _Awaitable +from typing import Any, Awaitable as _Awaitable, TypeVar from dependency_injector import providers -class Application(providers.Singleton): ... -class Extension(providers.Singleton): ... -class Middleware(providers.DelegatedCallable): ... -class MiddlewareFactory(providers.Factory): ... +T = TypeVar("T") -class View(providers.Callable): - def as_view(self) -> _Awaitable: ... +class Application(providers.Singleton[T]): ... +class Extension(providers.Singleton[T]): ... +class Middleware(providers.DelegatedCallable[T]): ... +class MiddlewareFactory(providers.Factory[T]): ... -class ClassBasedView(providers.Factory): - def as_view(self) -> _Awaitable: ... +class View(providers.Callable[T]): + def as_view(self) -> _Awaitable[T]: ... + +class ClassBasedView(providers.Factory[T]): + def as_view(self) -> _Awaitable[T]: ... diff --git a/src/dependency_injector/ext/flask.py b/src/dependency_injector/ext/flask.py index 498a9eee4..15b9df0aa 100644 --- a/src/dependency_injector/ext/flask.py +++ b/src/dependency_injector/ext/flask.py @@ -1,12 +1,12 @@ """Flask extension module.""" from __future__ import absolute_import + import warnings from flask import request as flask_request -from dependency_injector import providers, errors - +from dependency_injector import errors, providers warnings.warn( 'Module "dependency_injector.ext.flask" is deprecated since ' diff --git a/src/dependency_injector/ext/flask.pyi b/src/dependency_injector/ext/flask.pyi index 9b180c895..1c791b88d 100644 --- a/src/dependency_injector/ext/flask.pyi +++ b/src/dependency_injector/ext/flask.pyi @@ -1,19 +1,21 @@ -from typing import Union, Optional, Callable as _Callable, Any +from typing import Any, Callable as _Callable, Optional, TypeVar, Union + +from flask.wrappers import Request -from flask import request as flask_request from dependency_injector import providers -request: providers.Object[flask_request] +request: providers.Object[Request] +T = TypeVar("T") -class Application(providers.Singleton): ... -class Extension(providers.Singleton): ... +class Application(providers.Singleton[T]): ... +class Extension(providers.Singleton[T]): ... -class View(providers.Callable): - def as_view(self) -> _Callable[..., Any]: ... +class View(providers.Callable[T]): + def as_view(self) -> _Callable[..., T]: ... -class ClassBasedView(providers.Factory): - def as_view(self, name: str) -> _Callable[..., Any]: ... +class ClassBasedView(providers.Factory[T]): + def as_view(self, name: str) -> _Callable[..., T]: ... def as_view( - provider: Union[View, ClassBasedView], name: Optional[str] = None -) -> _Callable[..., Any]: ... + provider: Union[View[T], ClassBasedView[T]], name: Optional[str] = None +) -> _Callable[..., T]: ... diff --git a/src/dependency_injector/providers.pxd b/src/dependency_injector/providers.pxd index b4eb471de..21ed7f229 100644 --- a/src/dependency_injector/providers.pxd +++ b/src/dependency_injector/providers.pxd @@ -697,3 +697,10 @@ cdef inline object __future_result(object instance): future_result = asyncio.Future() future_result.set_result(instance) return future_result + + +cdef class NullAwaitable: + pass + + +cdef NullAwaitable NULL_AWAITABLE diff --git a/src/dependency_injector/providers.pyx b/src/dependency_injector/providers.pyx index d276903b1..43e49d7e3 100644 --- a/src/dependency_injector/providers.pyx +++ b/src/dependency_injector/providers.pyx @@ -15,8 +15,11 @@ import re import sys import threading import warnings +from asyncio import ensure_future from configparser import ConfigParser as IniConfigParser +from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar +from inspect import isasyncgenfunction, isgeneratorfunction try: from inspect import _is_coroutine_mark as _is_coroutine_marker @@ -3598,6 +3601,17 @@ cdef class Dict(Provider): return __provide_keyword_args(kwargs, self._kwargs, self._kwargs_len, self._async_mode) +@cython.no_gc +cdef class NullAwaitable: + def __next__(self): + raise StopIteration from None + + def __await__(self): + return self + + +cdef NullAwaitable NULL_AWAITABLE = NullAwaitable() + cdef class Resource(Provider): """Resource provider provides a component with initialization and shutdown.""" @@ -3653,6 +3667,12 @@ cdef class Resource(Provider): def set_provides(self, provides): """Set provider provides.""" provides = _resolve_string_import(provides) + + if isasyncgenfunction(provides): + provides = asynccontextmanager(provides) + elif isgeneratorfunction(provides): + provides = contextmanager(provides) + self._provides = provides return self @@ -3753,28 +3773,21 @@ cdef class Resource(Provider): """Shutdown resource.""" if not self._initialized: if self._async_mode == ASYNC_MODE_ENABLED: - result = asyncio.Future() - result.set_result(None) - return result + return NULL_AWAITABLE return if self._shutdowner: - try: - shutdown = self._shutdowner(self._resource) - except StopIteration: - pass - else: - if inspect.isawaitable(shutdown): - return self._create_shutdown_future(shutdown) + future = self._shutdowner(None, None, None) + + if __is_future_or_coroutine(future): + return ensure_future(self._shutdown_async(future)) self._resource = None self._initialized = False self._shutdowner = None if self._async_mode == ASYNC_MODE_ENABLED: - result = asyncio.Future() - result.set_result(None) - return result + return NULL_AWAITABLE @property def related(self): @@ -3784,165 +3797,75 @@ cdef class Resource(Provider): yield from filter(is_provider, self.kwargs.values()) yield from super().related + async def _shutdown_async(self, future) -> None: + try: + await future + finally: + self._resource = None + self._initialized = False + self._shutdowner = None + + async def _handle_async_cm(self, obj) -> None: + try: + self._resource = resource = await obj.__aenter__() + self._shutdowner = obj.__aexit__ + return resource + except: + self._initialized = False + raise + + async def _provide_async(self, future) -> None: + try: + obj = await future + + if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): + self._resource = await obj.__aenter__() + self._shutdowner = obj.__aexit__ + elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): + self._resource = obj.__enter__() + self._shutdowner = obj.__exit__ + else: + self._resource = obj + self._shutdowner = None + + return self._resource + except: + self._initialized = False + raise + cpdef object _provide(self, tuple args, dict kwargs): if self._initialized: return self._resource - if self._is_resource_subclass(self._provides): - initializer = self._provides() - self._resource = __call( - initializer.init, - args, - self._args, - self._args_len, - kwargs, - self._kwargs, - self._kwargs_len, - self._async_mode, - ) - self._shutdowner = initializer.shutdown - elif self._is_async_resource_subclass(self._provides): - initializer = self._provides() - async_init = __call( - initializer.init, - args, - self._args, - self._args_len, - kwargs, - self._kwargs, - self._kwargs_len, - self._async_mode, - ) - self._initialized = True - return self._create_init_future(async_init, initializer.shutdown) - elif inspect.isgeneratorfunction(self._provides): - initializer = __call( - self._provides, - args, - self._args, - self._args_len, - kwargs, - self._kwargs, - self._kwargs_len, - self._async_mode, - ) - self._resource = next(initializer) - self._shutdowner = initializer.send - elif iscoroutinefunction(self._provides): - initializer = __call( - self._provides, - args, - self._args, - self._args_len, - kwargs, - self._kwargs, - self._kwargs_len, - self._async_mode, - ) + obj = __call( + self._provides, + args, + self._args, + self._args_len, + kwargs, + self._kwargs, + self._kwargs_len, + self._async_mode, + ) + + if __is_future_or_coroutine(obj): self._initialized = True - return self._create_init_future(initializer) - elif isasyncgenfunction(self._provides): - initializer = __call( - self._provides, - args, - self._args, - self._args_len, - kwargs, - self._kwargs, - self._kwargs_len, - self._async_mode, - ) + self._resource = resource = ensure_future(self._provide_async(obj)) + return resource + elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): + self._resource = obj.__enter__() + self._shutdowner = obj.__exit__ + elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): self._initialized = True - return self._create_async_gen_init_future(initializer) - elif callable(self._provides): - self._resource = __call( - self._provides, - args, - self._args, - self._args_len, - kwargs, - self._kwargs, - self._kwargs_len, - self._async_mode, - ) + self._resource = resource = ensure_future(self._handle_async_cm(obj)) + return resource else: - raise Error("Unknown type of resource initializer") + self._resource = obj + self._shutdowner = None self._initialized = True return self._resource - def _create_init_future(self, future, shutdowner=None): - callback = self._async_init_callback - if shutdowner: - callback = functools.partial(callback, shutdowner=shutdowner) - - future = asyncio.ensure_future(future) - future.add_done_callback(callback) - self._resource = future - - return future - - def _create_async_gen_init_future(self, initializer): - if inspect.isasyncgen(initializer): - return self._create_init_future(initializer.__anext__(), initializer.asend) - - future = asyncio.Future() - - create_initializer = asyncio.ensure_future(initializer) - create_initializer.add_done_callback(functools.partial(self._async_create_gen_callback, future)) - self._resource = future - - return future - - def _async_init_callback(self, initializer, shutdowner=None): - try: - resource = initializer.result() - except Exception: - self._initialized = False - else: - self._resource = resource - self._shutdowner = shutdowner - - def _async_create_gen_callback(self, future, initializer_future): - initializer = initializer_future.result() - init_future = self._create_init_future(initializer.__anext__(), initializer.asend) - init_future.add_done_callback(functools.partial(self._async_trigger_result, future)) - - def _async_trigger_result(self, future, future_result): - future.set_result(future_result.result()) - - def _create_shutdown_future(self, shutdown_future): - future = asyncio.Future() - shutdown_future = asyncio.ensure_future(shutdown_future) - shutdown_future.add_done_callback(functools.partial(self._async_shutdown_callback, future)) - return future - - def _async_shutdown_callback(self, future_result, shutdowner): - try: - shutdowner.result() - except StopAsyncIteration: - pass - - self._resource = None - self._initialized = False - self._shutdowner = None - - future_result.set_result(None) - - @staticmethod - def _is_resource_subclass(instance): - if not isinstance(instance, type): - return - from . import resources - return issubclass(instance, resources.Resource) - - @staticmethod - def _is_async_resource_subclass(instance): - if not isinstance(instance, type): - return - from . import resources - return issubclass(instance, resources.AsyncResource) - cdef class Container(Provider): """Container provider provides an instance of declarative container. @@ -4993,14 +4916,6 @@ def iscoroutinefunction(obj): return False -def isasyncgenfunction(obj): - """Check if object is an asynchronous generator function.""" - try: - return inspect.isasyncgenfunction(obj) - except AttributeError: - return False - - def _resolve_string_import(provides): if provides is None: return provides diff --git a/src/dependency_injector/resources.py b/src/dependency_injector/resources.py index 7d71d4d82..8722af229 100644 --- a/src/dependency_injector/resources.py +++ b/src/dependency_injector/resources.py @@ -1,23 +1,54 @@ """Resources module.""" -import abc -from typing import TypeVar, Generic, Optional - +from abc import ABCMeta, abstractmethod +from typing import Any, ClassVar, Generic, Optional, Tuple, TypeVar T = TypeVar("T") -class Resource(Generic[T], metaclass=abc.ABCMeta): +class Resource(Generic[T], metaclass=ABCMeta): + __slots__: ClassVar[Tuple[str, ...]] = ("args", "kwargs", "obj") + + obj: Optional[T] + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.args = args + self.kwargs = kwargs + self.obj = None - @abc.abstractmethod - def init(self, *args, **kwargs) -> Optional[T]: ... + @abstractmethod + def init(self, *args: Any, **kwargs: Any) -> Optional[T]: ... def shutdown(self, resource: Optional[T]) -> None: ... + def __enter__(self) -> Optional[T]: + self.obj = obj = self.init(*self.args, **self.kwargs) + return obj + + def __exit__(self, *exc_info: Any) -> None: + self.shutdown(self.obj) + self.obj = None + -class AsyncResource(Generic[T], metaclass=abc.ABCMeta): +class AsyncResource(Generic[T], metaclass=ABCMeta): + __slots__: ClassVar[Tuple[str, ...]] = ("args", "kwargs", "obj") - @abc.abstractmethod - async def init(self, *args, **kwargs) -> Optional[T]: ... + obj: Optional[T] + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.args = args + self.kwargs = kwargs + self.obj = None + + @abstractmethod + async def init(self, *args: Any, **kwargs: Any) -> Optional[T]: ... async def shutdown(self, resource: Optional[T]) -> None: ... + + async def __aenter__(self) -> Optional[T]: + self.obj = obj = await self.init(*self.args, **self.kwargs) + return obj + + async def __aexit__(self, *exc_info: Any) -> None: + await self.shutdown(self.obj) + self.obj = None diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index b8534ee52..0477eed48 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -10,6 +10,7 @@ from typing import ( TYPE_CHECKING, Any, + AsyncIterator, Callable, Dict, Iterable, @@ -24,7 +25,17 @@ cast, ) -from typing_extensions import Self +try: + from typing import Self +except ImportError: + from typing_extensions import Self + +try: + from functools import cache +except ImportError: + from functools import lru_cache + + cache = lru_cache(maxsize=None) # Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362 if sys.version_info >= (3, 9): @@ -409,6 +420,7 @@ def wire( # noqa: C901 *, modules: Optional[Iterable[ModuleType]] = None, packages: Optional[Iterable[ModuleType]] = None, + keep_cache: bool = False, ) -> None: """Wire container providers with provided packages and modules.""" modules = [*modules] if modules else [] @@ -449,6 +461,9 @@ def wire( # noqa: C901 for patched in _patched_registry.get_callables_from_module(module): _bind_injections(patched, providers_map) + if not keep_cache: + clear_cache() + def unwire( # noqa: C901 *, @@ -604,6 +619,7 @@ def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]: return marker +@cache def _fetch_reference_injections( # noqa: C901 fn: Callable[..., Any], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -708,6 +724,8 @@ def _get_patched( if inspect.iscoroutinefunction(fn): patched = _get_async_patched(fn, patched_object) + elif inspect.isasyncgenfunction(fn): + patched = _get_async_gen_patched(fn, patched_object) else: patched = _get_sync_patched(fn, patched_object) @@ -1023,36 +1041,41 @@ def is_loader_installed() -> bool: _loader = AutoLoader() # Optimizations -from ._cwiring import _async_inject # noqa -from ._cwiring import _sync_inject # noqa +from ._cwiring import DependencyResolver # noqa: E402 # Wiring uses the following Python wrapper because there is # no possibility to compile a first-type citizen coroutine in Cython. def _get_async_patched(fn: F, patched: PatchedCallable) -> F: @functools.wraps(fn) - async def _patched(*args, **kwargs): - return await _async_inject( - fn, - args, - kwargs, - patched.injections, - patched.closing, - ) + async def _patched(*args: Any, **raw_kwargs: Any) -> Any: + resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing) + + async with resolver as kwargs: + return await fn(*args, **kwargs) + + return cast(F, _patched) + + +def _get_async_gen_patched(fn: F, patched: PatchedCallable) -> F: + @functools.wraps(fn) + async def _patched(*args: Any, **raw_kwargs: Any) -> AsyncIterator[Any]: + resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing) + + async with resolver as kwargs: + async for obj in fn(*args, **kwargs): + yield obj return cast(F, _patched) def _get_sync_patched(fn: F, patched: PatchedCallable) -> F: @functools.wraps(fn) - def _patched(*args, **kwargs): - return _sync_inject( - fn, - args, - kwargs, - patched.injections, - patched.closing, - ) + def _patched(*args: Any, **raw_kwargs: Any) -> Any: + resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing) + + with resolver as kwargs: + return fn(*args, **kwargs) return cast(F, _patched) @@ -1078,3 +1101,8 @@ def _get_members_and_annotated(obj: Any) -> Iterable[Tuple[str, Any]]: member = args[1] members.append((annotation_name, member)) return members + + +def clear_cache() -> None: + """Clear all caches used by :func:`wire`.""" + _fetch_reference_injections.cache_clear() diff --git a/tests/unit/providers/resource/test_async_resource_py35.py b/tests/unit/providers/resource/test_async_resource_py35.py index 1ca950a88..6458584dd 100644 --- a/tests/unit/providers/resource/test_async_resource_py35.py +++ b/tests/unit/providers/resource/test_async_resource_py35.py @@ -2,12 +2,13 @@ import asyncio import inspect -import sys +from contextlib import asynccontextmanager from typing import Any -from dependency_injector import containers, providers, resources from pytest import mark, raises +from dependency_injector import containers, providers, resources + @mark.asyncio async def test_init_async_function(): @@ -70,6 +71,46 @@ async def _init(): assert _init.shutdown_counter == 2 +@mark.asyncio +async def test_init_async_context_manager() -> None: + resource = object() + + init_counter = 0 + shutdown_counter = 0 + + @asynccontextmanager + async def _init(): + nonlocal init_counter, shutdown_counter + + await asyncio.sleep(0.001) + init_counter += 1 + + yield resource + + await asyncio.sleep(0.001) + shutdown_counter += 1 + + provider = providers.Resource(_init) + + result1 = await provider() + assert result1 is resource + assert init_counter == 1 + assert shutdown_counter == 0 + + await provider.shutdown() + assert init_counter == 1 + assert shutdown_counter == 1 + + result2 = await provider() + assert result2 is resource + assert init_counter == 2 + assert shutdown_counter == 1 + + await provider.shutdown() + assert init_counter == 2 + assert shutdown_counter == 2 + + @mark.asyncio async def test_init_async_class(): resource = object() diff --git a/tests/unit/providers/resource/test_resource_py35.py b/tests/unit/providers/resource/test_resource_py35.py index 9b906bd75..842d8ba6f 100644 --- a/tests/unit/providers/resource/test_resource_py35.py +++ b/tests/unit/providers/resource/test_resource_py35.py @@ -2,10 +2,12 @@ import decimal import sys +from contextlib import contextmanager from typing import Any -from dependency_injector import containers, providers, resources, errors -from pytest import raises, mark +from pytest import mark, raises + +from dependency_injector import containers, errors, providers, resources def init_fn(*args, **kwargs): @@ -123,6 +125,41 @@ def _init(): assert _init.shutdown_counter == 2 +def test_init_context_manager() -> None: + init_counter, shutdown_counter = 0, 0 + + @contextmanager + def _init(): + nonlocal init_counter, shutdown_counter + + init_counter += 1 + yield + shutdown_counter += 1 + + init_counter = 0 + shutdown_counter = 0 + + provider = providers.Resource(_init) + + result1 = provider() + assert result1 is None + assert init_counter == 1 + assert shutdown_counter == 0 + + provider.shutdown() + assert init_counter == 1 + assert shutdown_counter == 1 + + result2 = provider() + assert result2 is None + assert init_counter == 2 + assert shutdown_counter == 1 + + provider.shutdown() + assert init_counter == 2 + assert shutdown_counter == 2 + + def test_init_class(): class TestResource(resources.Resource): init_counter = 0 @@ -190,7 +227,7 @@ def init(self): def test_init_not_callable(): provider = providers.Resource(1) - with raises(errors.Error): + with raises(TypeError, match=r"object is not callable"): provider.init() diff --git a/tests/unit/samples/wiring/asyncinjections.py b/tests/unit/samples/wiring/asyncinjections.py index 204300e3a..e08610179 100644 --- a/tests/unit/samples/wiring/asyncinjections.py +++ b/tests/unit/samples/wiring/asyncinjections.py @@ -1,7 +1,9 @@ import asyncio +from typing_extensions import Annotated + from dependency_injector import containers, providers -from dependency_injector.wiring import inject, Provide, Closing +from dependency_injector.wiring import Closing, Provide, inject class TestResource: @@ -42,6 +44,15 @@ async def async_injection( return resource1, resource2 +@inject +async def async_generator_injection( + resource1: object = Provide[Container.resource1], + resource2: object = Closing[Provide[Container.resource2]], +): + yield resource1 + yield resource2 + + @inject async def async_injection_with_closing( resource1: object = Closing[Provide[Container.resource1]], diff --git a/tests/unit/wiring/provider_ids/test_async_injections_py36.py b/tests/unit/wiring/provider_ids/test_async_injections_py36.py index f17f19c77..70f9eb171 100644 --- a/tests/unit/wiring/provider_ids/test_async_injections_py36.py +++ b/tests/unit/wiring/provider_ids/test_async_injections_py36.py @@ -32,6 +32,23 @@ async def test_async_injections(): assert asyncinjections.resource2.shutdown_counter == 0 +@mark.asyncio +async def test_async_generator_injections() -> None: + resources = [] + + async for resource in asyncinjections.async_generator_injection(): + resources.append(resource) + + assert len(resources) == 2 + assert resources[0] is asyncinjections.resource1 + assert asyncinjections.resource1.init_counter == 1 + assert asyncinjections.resource1.shutdown_counter == 0 + + assert resources[1] is asyncinjections.resource2 + assert asyncinjections.resource2.init_counter == 1 + assert asyncinjections.resource2.shutdown_counter == 1 + + @mark.asyncio async def test_async_injections_with_closing(): resource1, resource2 = await asyncinjections.async_injection_with_closing() diff --git a/tests/unit/wiring/test_cache.py b/tests/unit/wiring/test_cache.py new file mode 100644 index 000000000..d6c1f45fd --- /dev/null +++ b/tests/unit/wiring/test_cache.py @@ -0,0 +1,46 @@ +"""Tests for string module and package names.""" + +from typing import Iterator, Optional + +from pytest import fixture, mark +from samples.wiring.container import Container + +from dependency_injector.wiring import _fetch_reference_injections + + +@fixture +def container() -> Iterator[Container]: + container = Container() + yield container + container.unwire() + + +@mark.parametrize( + ["arg_value", "wc_value", "empty_cache"], + [ + (None, False, True), + (False, True, True), + (True, False, False), + (None, True, False), + ], +) +def test_fetch_reference_injections_cache( + container: Container, + arg_value: Optional[bool], + wc_value: bool, + empty_cache: bool, +) -> None: + container.wiring_config.keep_cache = wc_value + container.wire( + modules=["samples.wiring.module"], + packages=["samples.wiring.package"], + keep_cache=arg_value, + ) + cache_info = _fetch_reference_injections.cache_info() + + if empty_cache: + assert cache_info == (0, 0, None, 0) + else: + assert cache_info.hits > 0 + assert cache_info.misses > 0 + assert cache_info.currsize > 0