diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 000000000..7e6cba08e
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,9 @@
+root = true
+
+[*]
+end_of_line = lf
+insert_final_newline = true
+trim_trailing_whitespace = true
+
+[*.{py,pyi,pxd,pyx}]
+ij_visual_guides = 80,88
diff --git a/docs/api/index.rst b/docs/api/index.rst
index c6b4cfa85..7258f7de0 100644
--- a/docs/api/index.rst
+++ b/docs/api/index.rst
@@ -2,10 +2,11 @@ API Documentation
=================
.. toctree::
- :maxdepth: 2
+ :maxdepth: 2
top-level
providers
containers
wiring
errors
+ asgi-lifespan
diff --git a/docs/conf.py b/docs/conf.py
index 380da2da5..4de57da77 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -72,7 +72,7 @@
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
-language = None
+language = "en"
# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
diff --git a/docs/examples/django.rst b/docs/examples/django.rst
index 08e6e7573..1a5b781fd 100644
--- a/docs/examples/django.rst
+++ b/docs/examples/django.rst
@@ -78,7 +78,7 @@ Container is wired to the ``views`` module in the app config ``web/apps.py``:
.. literalinclude:: ../../examples/miniapps/django/web/apps.py
:language: python
- :emphasize-lines: 13
+ :emphasize-lines: 12
Tests
-----
diff --git a/docs/introduction/key_features.rst b/docs/introduction/key_features.rst
index 0870ac11e..1975e8fcd 100644
--- a/docs/introduction/key_features.rst
+++ b/docs/introduction/key_features.rst
@@ -31,7 +31,7 @@ Key features of the ``Dependency Injector``:
The framework stands on the `PEP20 (The Zen of Python) `_ principle:
-.. code-block:: plain
+.. code-block:: text
Explicit is better than implicit
diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst
index 06acc716e..05bb9e9ac 100644
--- a/docs/main/changelog.rst
+++ b/docs/main/changelog.rst
@@ -14,8 +14,8 @@ follows `Semantic versioning`_
with updated documentation and examples.
See discussion:
https://github.com/ets-labs/python-dependency-injector/pull/721#issuecomment-2025263718
-- Fix ``root`` property shadowing in ``ConfigurationOption`` (`#875 https://github.com/ets-labs/python-dependency-injector/pull/875`_)
-- Fix incorrect monkeypatching during ``wire()`` that could violate MRO in some classes (`#886 https://github.com/ets-labs/python-dependency-injector/pull/886`_)
+- Fix ``root`` property shadowing in ``ConfigurationOption`` (`#875 `_)
+- Fix incorrect monkeypatching during ``wire()`` that could violate MRO in some classes (`#886 `_)
- ABI3 wheels are now published for CPython.
- Drop support of Python 3.7.
@@ -371,8 +371,8 @@ Many thanks to `ZipFile `_ for both contributions.
- Make refactoring of wiring module and tests.
See PR # `#406 `_.
Thanks to `@withshubh `_ for the contribution:
- - Remove unused imports in tests.
- - Use literal syntax to create data structure in tests.
+ - Remove unused imports in tests.
+ - Use literal syntax to create data structure in tests.
- Add integration with a static analysis tool `DeepSource `_.
4.26.0
diff --git a/docs/providers/resource.rst b/docs/providers/resource.rst
index 918dfa664..fda5b3d7a 100644
--- a/docs/providers/resource.rst
+++ b/docs/providers/resource.rst
@@ -61,11 +61,12 @@ When you call ``.shutdown()`` method on a resource provider, it will remove the
if any, and switch to uninitialized state. Some of resource initializer types support specifying custom
resource shutdown.
-Resource provider supports 3 types of initializers:
+Resource provider supports 4 types of initializers:
- Function
-- Generator
-- Subclass of ``resources.Resource``
+- Context Manager
+- Generator (legacy)
+- Subclass of ``resources.Resource`` (legacy)
Function initializer
--------------------
@@ -103,8 +104,44 @@ you configure global resource:
Function initializer does not provide a way to specify custom resource shutdown.
-Generator initializer
----------------------
+Context Manager initializer
+---------------------------
+
+This is an extension to the Function initializer. Resource provider automatically detects if the initializer returns a
+context manager and uses it to manage the resource lifecycle.
+
+.. code-block:: python
+
+ from dependency_injector import containers, providers
+
+ class DatabaseConnection:
+ def __init__(self, host, port, user, password):
+ self.host = host
+ self.port = port
+ self.user = user
+ self.password = password
+
+ def __enter__(self):
+ print(f"Connecting to {self.host}:{self.port} as {self.user}")
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ print("Closing connection")
+
+
+ class Container(containers.DeclarativeContainer):
+
+ config = providers.Configuration()
+ db = providers.Resource(
+ DatabaseConnection,
+ host=config.db.host,
+ port=config.db.port,
+ user=config.db.user,
+ password=config.db.password,
+ )
+
+Generator initializer (legacy)
+------------------------------
Resource provider can use 2-step generators:
@@ -154,8 +191,13 @@ object is not mandatory. You can leave ``yield`` statement empty:
argument2=...,
)
-Subclass initializer
---------------------
+.. note::
+
+ Generator initializers are automatically wrapped with ``contextmanager`` or ``asynccontextmanager`` decorator when
+ provided to a ``Resource`` provider.
+
+Subclass initializer (legacy)
+-----------------------------
You can create resource initializer by implementing a subclass of the ``resources.Resource``:
@@ -263,10 +305,11 @@ Asynchronous function initializer:
argument2=...,
)
-Asynchronous generator initializer:
+Asynchronous Context Manager initializer:
.. code-block:: python
+ @asynccontextmanager
async def init_async_resource(argument1=..., argument2=...):
connection = await connect()
yield connection
@@ -358,5 +401,54 @@ See also:
- Wiring :ref:`async-injections-wiring`
- :ref:`fastapi-redis-example`
+ASGI Lifespan Protocol Support
+------------------------------
+
+The :mod:`dependency_injector.ext.starlette` module provides a :class:`~dependency_injector.ext.starlette.Lifespan`
+class that integrates resource providers with ASGI applications using the `Lifespan Protocol`_. This allows resources to
+be automatically initialized at application startup and properly shut down when the application stops.
+
+.. code-block:: python
+
+ from contextlib import asynccontextmanager
+ from dependency_injector import containers, providers
+ from dependency_injector.wiring import Provide, inject
+ from dependency_injector.ext.starlette import Lifespan
+ from fastapi import FastAPI, Request, Depends, APIRouter
+
+ class Connection: ...
+
+ @asynccontextmanager
+ async def init_database():
+ print("opening database connection")
+ yield Connection()
+ print("closing database connection")
+
+ router = APIRouter()
+
+ @router.get("/")
+ @inject
+ async def index(request: Request, db: Connection = Depends(Provide["db"])):
+ # use the database connection here
+ return "OK!"
+
+ class Container(containers.DeclarativeContainer):
+ __self__ = providers.Self()
+ db = providers.Resource(init_database)
+ lifespan = providers.Singleton(Lifespan, __self__)
+ app = providers.Singleton(FastAPI, lifespan=lifespan)
+ _include_router = providers.Resource(
+ app.provided.include_router.call(),
+ router,
+ )
+
+ if __name__ == "__main__":
+ import uvicorn
+
+ container = Container()
+ app = container.app()
+ uvicorn.run(app, host="localhost", port=8000)
+
+.. _Lifespan Protocol: https://asgi.readthedocs.io/en/latest/specs/lifespan.html
.. disqus::
diff --git a/docs/tutorials/aiohttp.rst b/docs/tutorials/aiohttp.rst
index 57b1c9599..10b33b495 100644
--- a/docs/tutorials/aiohttp.rst
+++ b/docs/tutorials/aiohttp.rst
@@ -257,7 +257,7 @@ Let's check that it works. Open another terminal session and use ``httpie``:
You should see:
-.. code-block:: json
+.. code-block:: http
HTTP/1.1 200 OK
Content-Length: 844
@@ -596,7 +596,7 @@ and make a request to the API in the terminal:
You should see:
-.. code-block:: json
+.. code-block:: http
HTTP/1.1 200 OK
Content-Length: 492
diff --git a/docs/tutorials/cli.rst b/docs/tutorials/cli.rst
index 88014ff34..ea3c84675 100644
--- a/docs/tutorials/cli.rst
+++ b/docs/tutorials/cli.rst
@@ -84,7 +84,7 @@ Create next structure in the project root directory. All files are empty. That's
Initial project layout:
-.. code-block:: bash
+.. code-block:: text
./
├── movies/
@@ -109,7 +109,7 @@ Now it's time to install the project requirements. We will use next packages:
Put next lines into the ``requirements.txt`` file:
-.. code-block:: bash
+.. code-block:: text
dependency-injector
pyyaml
@@ -134,7 +134,7 @@ We will create a script that creates database files.
First add the folder ``data/`` in the root of the project and then add the file
``fixtures.py`` inside of it:
-.. code-block:: bash
+.. code-block:: text
:emphasize-lines: 2-3
./
@@ -205,13 +205,13 @@ Now run in the terminal:
You should see:
-.. code-block:: bash
+.. code-block:: text
OK
Check that files ``movies.csv`` and ``movies.db`` have appeared in the ``data/`` folder:
-.. code-block:: bash
+.. code-block:: text
:emphasize-lines: 4-5
./
@@ -289,7 +289,7 @@ After each step we will add the provider to the container.
Create the ``entities.py`` in the ``movies`` package:
-.. code-block:: bash
+.. code-block:: text
:emphasize-lines: 10
./
@@ -356,7 +356,7 @@ Let's move on to the finders.
Create the ``finders.py`` in the ``movies`` package:
-.. code-block:: bash
+.. code-block:: text
:emphasize-lines: 11
./
@@ -465,7 +465,7 @@ The configuration file is ready. Move on to the lister.
Create the ``listers.py`` in the ``movies`` package:
-.. code-block:: bash
+.. code-block:: text
:emphasize-lines: 12
./
@@ -613,7 +613,7 @@ Run in the terminal:
You should see:
-.. code-block:: plain
+.. code-block:: text
Francis Lawrence movies:
- Movie(title='The Hunger Games: Mockingjay - Part 2', year=2015, director='Francis Lawrence')
@@ -752,7 +752,7 @@ Run in the terminal:
You should see:
-.. code-block:: plain
+.. code-block:: text
Francis Lawrence movies:
- Movie(title='The Hunger Games: Mockingjay - Part 2', year=2015, director='Francis Lawrence')
@@ -868,7 +868,7 @@ Run in the terminal line by line:
The output should be similar for each command:
-.. code-block:: plain
+.. code-block:: text
Francis Lawrence movies:
- Movie(title='The Hunger Games: Mockingjay - Part 2', year=2015, director='Francis Lawrence')
@@ -888,7 +888,7 @@ We will use `pytest `_ and
Create ``tests.py`` in the ``movies`` package:
-.. code-block:: bash
+.. code-block:: text
:emphasize-lines: 13
./
@@ -977,7 +977,7 @@ Run in the terminal:
You should see:
-.. code-block::
+.. code-block:: text
platform darwin -- Python 3.10.0, pytest-6.2.5, py-1.10.0, pluggy-1.0.0
plugins: cov-3.0.0
diff --git a/docs/tutorials/flask.rst b/docs/tutorials/flask.rst
index 8c22aa5a0..b8f81adf3 100644
--- a/docs/tutorials/flask.rst
+++ b/docs/tutorials/flask.rst
@@ -280,7 +280,7 @@ Now let's fill in the layout.
Put next into the ``base.html``:
-.. code-block:: html
+.. code-block:: jinja
@@ -313,7 +313,7 @@ And put something to the index page.
Put next into the ``index.html``:
-.. code-block:: html
+.. code-block:: jinja
{% extends "base.html" %}
diff --git a/docs/wiring.rst b/docs/wiring.rst
index 02f64c60d..912b320e4 100644
--- a/docs/wiring.rst
+++ b/docs/wiring.rst
@@ -127,6 +127,7 @@ To inject the provider itself use ``Provide[foo.provider]``:
def foo(bar_provider: Factory[Bar] = Provide[Container.bar.provider]):
bar = bar_provider(argument="baz")
...
+
You can also use ``Provider[foo]`` for injecting the provider itself:
.. code-block:: python
@@ -631,6 +632,36 @@ or with a single container ``register_loader_containers(container)`` multiple ti
To unregister a container use ``unregister_loader_containers(container)``.
Wiring module will uninstall the import hook when unregister last container.
+Few notes on performance
+------------------------
+
+``.wire()`` utilize caching to speed up the wiring process. At the end it clears the cache to avoid memory leaks.
+But this may not always be desirable, when you want to keep the cache for the next wiring
+(e.g. due to usage of multiple containers or during unit tests).
+
+To keep the cache after wiring, you can set flag ``keep_cache=True`` (works with ``WiringConfiguration`` too):
+
+.. code-block:: python
+
+ container1.wire(
+ modules=["yourapp.module1", "yourapp.module2"],
+ keep_cache=True,
+ )
+ container2.wire(
+ modules=["yourapp.module2", "yourapp.module3"],
+ keep_cache=True,
+ )
+ ...
+
+and then clear it manually when you need it:
+
+.. code-block:: python
+
+ from dependency_injector.wiring import clear_cache
+
+ clear_cache()
+
+
Integration with other frameworks
---------------------------------
diff --git a/examples/miniapps/movie-lister/data/fixtures.py b/examples/miniapps/movie-lister/data/fixtures.py
index aa1691d5a..0153e0cf7 100644
--- a/examples/miniapps/movie-lister/data/fixtures.py
+++ b/examples/miniapps/movie-lister/data/fixtures.py
@@ -18,10 +18,9 @@
def create_csv(movies_data, path):
- with open(path, "w") as opened_file:
+ with open(path, "w", newline="") as opened_file:
writer = csv.writer(opened_file)
- for row in movies_data:
- writer.writerow(row)
+ writer.writerows(movies_data)
def create_sqlite(movies_data, path):
diff --git a/examples/miniapps/movie-lister/movies/finders.py b/examples/miniapps/movie-lister/movies/finders.py
index 52b8ed555..5e6d2c9c0 100644
--- a/examples/miniapps/movie-lister/movies/finders.py
+++ b/examples/miniapps/movie-lister/movies/finders.py
@@ -29,7 +29,7 @@ def __init__(
super().__init__(movie_factory)
def find_all(self) -> List[Movie]:
- with open(self._csv_file_path) as csv_file:
+ with open(self._csv_file_path, newline="") as csv_file:
csv_reader = csv.reader(csv_file, delimiter=self._delimiter)
return [self._movie_factory(*row) for row in csv_reader]
diff --git a/examples/providers/resource.py b/examples/providers/resource.py
index 2079a929c..c712468a8 100644
--- a/examples/providers/resource.py
+++ b/examples/providers/resource.py
@@ -3,10 +3,12 @@
import sys
import logging
from concurrent.futures import ThreadPoolExecutor
+from contextlib import contextmanager
from dependency_injector import containers, providers
+@contextmanager
def init_thread_pool(max_workers: int):
thread_pool = ThreadPoolExecutor(max_workers=max_workers)
yield thread_pool
diff --git a/pyproject.toml b/pyproject.toml
index 7512cb94c..885531786 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -91,6 +91,7 @@ show_missing = true
[tool.isort]
profile = "black"
+combine_as_imports = true
[tool.pylint.main]
ignore = ["tests"]
diff --git a/src/dependency_injector/__init__.pyi b/src/dependency_injector/__init__.pyi
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/dependency_injector/_cwiring.pyi b/src/dependency_injector/_cwiring.pyi
index e7ff12f4e..c779b8c45 100644
--- a/src/dependency_injector/_cwiring.pyi
+++ b/src/dependency_injector/_cwiring.pyi
@@ -1,23 +1,18 @@
-from typing import Any, Awaitable, Callable, Dict, Tuple, TypeVar
+from typing import Any, Dict
from .providers import Provider
-T = TypeVar("T")
+class DependencyResolver:
+ def __init__(
+ self,
+ kwargs: Dict[str, Any],
+ injections: Dict[str, Provider[Any]],
+ closings: Dict[str, Provider[Any]],
+ /,
+ ) -> None: ...
+ def __enter__(self) -> Dict[str, Any]: ...
+ def __exit__(self, *exc_info: Any) -> None: ...
+ async def __aenter__(self) -> Dict[str, Any]: ...
+ async def __aexit__(self, *exc_info: Any) -> None: ...
-def _sync_inject(
- fn: Callable[..., T],
- args: Tuple[Any, ...],
- kwargs: Dict[str, Any],
- injections: Dict[str, Provider[Any]],
- closings: Dict[str, Provider[Any]],
- /,
-) -> T: ...
-async def _async_inject(
- fn: Callable[..., Awaitable[T]],
- args: Tuple[Any, ...],
- kwargs: Dict[str, Any],
- injections: Dict[str, Provider[Any]],
- closings: Dict[str, Provider[Any]],
- /,
-) -> T: ...
def _isawaitable(instance: Any) -> bool: ...
diff --git a/src/dependency_injector/_cwiring.pyx b/src/dependency_injector/_cwiring.pyx
index 84a5485f8..3e2775c7c 100644
--- a/src/dependency_injector/_cwiring.pyx
+++ b/src/dependency_injector/_cwiring.pyx
@@ -1,83 +1,110 @@
"""Wiring optimizations module."""
-import asyncio
-import collections.abc
-import inspect
-import types
+from asyncio import gather
+from collections.abc import Awaitable
+from inspect import CO_ITERABLE_COROUTINE
+from types import CoroutineType, GeneratorType
+from .providers cimport Provider, Resource, NULL_AWAITABLE
from .wiring import _Marker
-from .providers cimport Provider, Resource
+cimport cython
-def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
- cdef object result
+@cython.internal
+@cython.no_gc
+cdef class KWPair:
+ cdef str name
+ cdef object value
+
+ def __cinit__(self, str name, object value, /):
+ self.name = name
+ self.value = value
+
+
+cdef inline bint _is_injectable(dict kwargs, str name):
+ return name not in kwargs or isinstance(kwargs[name], _Marker)
+
+
+cdef class DependencyResolver:
+ cdef dict kwargs
cdef dict to_inject
- cdef object arg_key
- cdef Provider provider
+ cdef dict injections
+ cdef dict closings
- to_inject = kwargs.copy()
- for arg_key, provider in injections.items():
- if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
- to_inject[arg_key] = provider()
+ def __init__(self, dict kwargs, dict injections, dict closings, /):
+ self.kwargs = kwargs
+ self.to_inject = kwargs.copy()
+ self.injections = injections
+ self.closings = closings
- result = fn(*args, **to_inject)
+ async def _await_injection(self, kw_pair: KWPair, /) -> None:
+ self.to_inject[kw_pair.name] = await kw_pair.value
- if closings:
- for arg_key, provider in closings.items():
- if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker):
- continue
- if not isinstance(provider, Resource):
- continue
- provider.shutdown()
+ cdef object _await_injections(self, to_await: list):
+ return gather(*map(self._await_injection, to_await))
- return result
+ cdef void _handle_injections_sync(self):
+ cdef Provider provider
+ for name, provider in self.injections.items():
+ if _is_injectable(self.kwargs, name):
+ self.to_inject[name] = provider()
-async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
- cdef object result
- cdef dict to_inject
- cdef list to_inject_await = []
- cdef list to_close_await = []
- cdef object arg_key
- cdef Provider provider
-
- to_inject = kwargs.copy()
- for arg_key, provider in injections.items():
- if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
- provide = provider()
- if provider.is_async_mode_enabled():
- to_inject_await.append((arg_key, provide))
- elif _isawaitable(provide):
- to_inject_await.append((arg_key, provide))
- else:
- to_inject[arg_key] = provide
-
- if to_inject_await:
- async_to_inject = await asyncio.gather(*(provide for _, provide in to_inject_await))
- for provide, (injection, _) in zip(async_to_inject, to_inject_await):
- to_inject[injection] = provide
-
- result = await fn(*args, **to_inject)
-
- if closings:
- for arg_key, provider in closings.items():
- if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker):
- continue
- if not isinstance(provider, Resource):
- continue
- shutdown = provider.shutdown()
- if _isawaitable(shutdown):
- to_close_await.append(shutdown)
-
- await asyncio.gather(*to_close_await)
-
- return result
+ cdef list _handle_injections_async(self):
+ cdef list to_await = []
+ cdef Provider provider
+
+ for name, provider in self.injections.items():
+ if _is_injectable(self.kwargs, name):
+ provide = provider()
+
+ if provider.is_async_mode_enabled() or _isawaitable(provide):
+ to_await.append(KWPair(name, provide))
+ else:
+ self.to_inject[name] = provide
+
+ return to_await
+
+ cdef void _handle_closings_sync(self):
+ cdef Provider provider
+
+ for name, provider in self.closings.items():
+ if _is_injectable(self.kwargs, name) and isinstance(provider, Resource):
+ provider.shutdown()
+
+ cdef list _handle_closings_async(self):
+ cdef list to_await = []
+ cdef Provider provider
+
+ for name, provider in self.closings.items():
+ if _is_injectable(self.kwargs, name) and isinstance(provider, Resource):
+ if _isawaitable(shutdown := provider.shutdown()):
+ to_await.append(shutdown)
+
+ return to_await
+
+ def __enter__(self):
+ self._handle_injections_sync()
+ return self.to_inject
+
+ def __exit__(self, *_):
+ self._handle_closings_sync()
+
+ async def __aenter__(self):
+ if to_await := self._handle_injections_async():
+ await self._await_injections(to_await)
+ return self.to_inject
+
+ def __aexit__(self, *_):
+ if to_await := self._handle_closings_async():
+ return gather(*to_await)
+ return NULL_AWAITABLE
cdef bint _isawaitable(object instance):
"""Return true if object can be passed to an ``await`` expression."""
- return (isinstance(instance, types.CoroutineType) or
- isinstance(instance, types.GeneratorType) and
- bool(instance.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) or
- isinstance(instance, collections.abc.Awaitable))
+ return (isinstance(instance, CoroutineType) or
+ isinstance(instance, GeneratorType) and
+ bool(instance.gi_code.co_flags & CO_ITERABLE_COROUTINE) or
+ isinstance(instance, Awaitable))
diff --git a/src/dependency_injector/containers.pyi b/src/dependency_injector/containers.pyi
index ec41ea8e6..ca608f28d 100644
--- a/src/dependency_injector/containers.pyi
+++ b/src/dependency_injector/containers.pyi
@@ -1,23 +1,28 @@
from pathlib import Path
from typing import (
- Generic,
- Type,
- Dict,
- List,
- Tuple,
- Optional,
Any,
- Union,
- ClassVar,
+ Awaitable,
Callable as _Callable,
+ ClassVar,
+ Dict,
+ Generic,
Iterable,
Iterator,
+ List,
+ Optional,
+ Tuple,
+ Type,
TypeVar,
- Awaitable,
+ Union,
overload,
)
-from .providers import Provider, Self, ProviderParent
+try:
+ from typing import Self as _Self
+except ImportError:
+ from typing_extensions import Self as _Self
+
+from .providers import Provider, ProviderParent, Self
C_Base = TypeVar("C_Base", bound="Container")
C = TypeVar("C", bound="DeclarativeContainer")
@@ -30,32 +35,34 @@ class WiringConfiguration:
packages: List[Any]
from_package: Optional[str]
auto_wire: bool
+ keep_cache: bool
def __init__(
self,
modules: Optional[Iterable[Any]] = None,
packages: Optional[Iterable[Any]] = None,
from_package: Optional[str] = None,
auto_wire: bool = True,
+ keep_cache: bool = False,
) -> None: ...
class Container:
- provider_type: Type[Provider] = Provider
- providers: Dict[str, Provider]
+ provider_type: Type[Provider[Any]] = Provider
+ providers: Dict[str, Provider[Any]]
dependencies: Dict[str, Provider[Any]]
- overridden: Tuple[Provider]
+ overridden: Tuple[Provider[Any], ...]
wiring_config: WiringConfiguration
auto_load_config: bool = True
__self__: Self
def __init__(self) -> None: ...
- def __deepcopy__(self, memo: Optional[Dict[str, Any]]) -> Provider: ...
- def __setattr__(self, name: str, value: Union[Provider, Any]) -> None: ...
- def __getattr__(self, name: str) -> Provider: ...
+ def __deepcopy__(self, memo: Optional[Dict[str, Any]]) -> _Self: ...
+ def __setattr__(self, name: str, value: Union[Provider[Any], Any]) -> None: ...
+ def __getattr__(self, name: str) -> Provider[Any]: ...
def __delattr__(self, name: str) -> None: ...
- def set_providers(self, **providers: Provider): ...
- def set_provider(self, name: str, provider: Provider) -> None: ...
+ def set_providers(self, **providers: Provider[Any]) -> None: ...
+ def set_provider(self, name: str, provider: Provider[Any]) -> None: ...
def override(self, overriding: Union[Container, Type[Container]]) -> None: ...
def override_providers(
- self, **overriding_providers: Union[Provider, Any]
+ self, **overriding_providers: Union[Provider[Any], Any]
) -> ProvidersOverridingContext[C_Base]: ...
def reset_last_overriding(self) -> None: ...
def reset_override(self) -> None: ...
@@ -67,8 +74,8 @@ class Container:
from_package: Optional[str] = None,
) -> None: ...
def unwire(self) -> None: ...
- def init_resources(self) -> Optional[Awaitable]: ...
- def shutdown_resources(self) -> Optional[Awaitable]: ...
+ def init_resources(self) -> Optional[Awaitable[None]]: ...
+ def shutdown_resources(self) -> Optional[Awaitable[None]]: ...
def load_config(self) -> None: ...
def apply_container_providers_overridings(self) -> None: ...
def reset_singletons(self) -> SingletonResetContext[C_Base]: ...
@@ -79,10 +86,10 @@ class Container:
) -> None: ...
def from_json_schema(self, filepath: Union[Path, str]) -> None: ...
@overload
- def resolve_provider_name(self, provider: Provider) -> str: ...
+ def resolve_provider_name(self, provider: Provider[Any]) -> str: ...
@classmethod
@overload
- def resolve_provider_name(cls, provider: Provider) -> str: ...
+ def resolve_provider_name(cls, provider: Provider[Any]) -> str: ...
@property
def parent(self) -> Optional[ProviderParent]: ...
@property
@@ -97,14 +104,14 @@ class Container:
class DynamicContainer(Container): ...
class DeclarativeContainer(Container):
- cls_providers: ClassVar[Dict[str, Provider]]
- inherited_providers: ClassVar[Dict[str, Provider]]
- def __init__(self, **overriding_providers: Union[Provider, Any]) -> None: ...
+ cls_providers: ClassVar[Dict[str, Provider[Any]]]
+ inherited_providers: ClassVar[Dict[str, Provider[Any]]]
+ def __init__(self, **overriding_providers: Union[Provider[Any], Any]) -> None: ...
@classmethod
def override(cls, overriding: Union[Container, Type[Container]]) -> None: ...
@classmethod
def override_providers(
- cls, **overriding_providers: Union[Provider, Any]
+ cls, **overriding_providers: Union[Provider[Any], Any]
) -> ProvidersOverridingContext[C_Base]: ...
@classmethod
def reset_last_overriding(cls) -> None: ...
@@ -113,7 +120,7 @@ class DeclarativeContainer(Container):
class ProvidersOverridingContext(Generic[T]):
def __init__(
- self, container: T, overridden_providers: Iterable[Union[Provider, Any]]
+ self, container: T, overridden_providers: Iterable[Union[Provider[Any], Any]]
) -> None: ...
def __enter__(self) -> T: ...
def __exit__(self, *_: Any) -> None: ...
diff --git a/src/dependency_injector/containers.pyx b/src/dependency_injector/containers.pyx
index 2f4c4af58..bd0a4821b 100644
--- a/src/dependency_injector/containers.pyx
+++ b/src/dependency_injector/containers.pyx
@@ -20,14 +20,15 @@ from .wiring import wire, unwire
class WiringConfiguration:
"""Container wiring configuration."""
- def __init__(self, modules=None, packages=None, from_package=None, auto_wire=True):
+ def __init__(self, modules=None, packages=None, from_package=None, auto_wire=True, keep_cache=False):
self.modules = [*modules] if modules else []
self.packages = [*packages] if packages else []
self.from_package = from_package
self.auto_wire = auto_wire
+ self.keep_cache = keep_cache
def __deepcopy__(self, memo=None):
- return self.__class__(self.modules, self.packages, self.from_package, self.auto_wire)
+ return self.__class__(self.modules, self.packages, self.from_package, self.auto_wire, self.keep_cache)
class Container:
@@ -258,7 +259,7 @@ class DynamicContainer(Container):
"""Check if auto wiring is needed."""
return self.wiring_config.auto_wire is True
- def wire(self, modules=None, packages=None, from_package=None):
+ def wire(self, modules=None, packages=None, from_package=None, keep_cache=None):
"""Wire container providers with provided packages and modules.
:rtype: None
@@ -289,10 +290,14 @@ class DynamicContainer(Container):
if not modules and not packages:
return
+ if keep_cache is None:
+ keep_cache = self.wiring_config.keep_cache
+
wire(
container=self,
modules=modules,
packages=packages,
+ keep_cache=keep_cache,
)
if modules:
diff --git a/src/dependency_injector/ext/aiohttp.py b/src/dependency_injector/ext/aiohttp.py
index 976089c32..43990a7d3 100644
--- a/src/dependency_injector/ext/aiohttp.py
+++ b/src/dependency_injector/ext/aiohttp.py
@@ -7,7 +7,6 @@
from dependency_injector import providers
-
warnings.warn(
'Module "dependency_injector.ext.aiohttp" is deprecated since '
'version 4.0.0. Use "dependency_injector.wiring" module instead.',
diff --git a/src/dependency_injector/ext/aiohttp.pyi b/src/dependency_injector/ext/aiohttp.pyi
index 370cc9b00..c524712c5 100644
--- a/src/dependency_injector/ext/aiohttp.pyi
+++ b/src/dependency_injector/ext/aiohttp.pyi
@@ -1,14 +1,16 @@
-from typing import Awaitable as _Awaitable
+from typing import Any, Awaitable as _Awaitable, TypeVar
from dependency_injector import providers
-class Application(providers.Singleton): ...
-class Extension(providers.Singleton): ...
-class Middleware(providers.DelegatedCallable): ...
-class MiddlewareFactory(providers.Factory): ...
+T = TypeVar("T")
-class View(providers.Callable):
- def as_view(self) -> _Awaitable: ...
+class Application(providers.Singleton[T]): ...
+class Extension(providers.Singleton[T]): ...
+class Middleware(providers.DelegatedCallable[T]): ...
+class MiddlewareFactory(providers.Factory[T]): ...
-class ClassBasedView(providers.Factory):
- def as_view(self) -> _Awaitable: ...
+class View(providers.Callable[T]):
+ def as_view(self) -> _Awaitable[T]: ...
+
+class ClassBasedView(providers.Factory[T]):
+ def as_view(self) -> _Awaitable[T]: ...
diff --git a/src/dependency_injector/ext/flask.py b/src/dependency_injector/ext/flask.py
index 498a9eee4..15b9df0aa 100644
--- a/src/dependency_injector/ext/flask.py
+++ b/src/dependency_injector/ext/flask.py
@@ -1,12 +1,12 @@
"""Flask extension module."""
from __future__ import absolute_import
+
import warnings
from flask import request as flask_request
-from dependency_injector import providers, errors
-
+from dependency_injector import errors, providers
warnings.warn(
'Module "dependency_injector.ext.flask" is deprecated since '
diff --git a/src/dependency_injector/ext/flask.pyi b/src/dependency_injector/ext/flask.pyi
index 9b180c895..1c791b88d 100644
--- a/src/dependency_injector/ext/flask.pyi
+++ b/src/dependency_injector/ext/flask.pyi
@@ -1,19 +1,21 @@
-from typing import Union, Optional, Callable as _Callable, Any
+from typing import Any, Callable as _Callable, Optional, TypeVar, Union
+
+from flask.wrappers import Request
-from flask import request as flask_request
from dependency_injector import providers
-request: providers.Object[flask_request]
+request: providers.Object[Request]
+T = TypeVar("T")
-class Application(providers.Singleton): ...
-class Extension(providers.Singleton): ...
+class Application(providers.Singleton[T]): ...
+class Extension(providers.Singleton[T]): ...
-class View(providers.Callable):
- def as_view(self) -> _Callable[..., Any]: ...
+class View(providers.Callable[T]):
+ def as_view(self) -> _Callable[..., T]: ...
-class ClassBasedView(providers.Factory):
- def as_view(self, name: str) -> _Callable[..., Any]: ...
+class ClassBasedView(providers.Factory[T]):
+ def as_view(self, name: str) -> _Callable[..., T]: ...
def as_view(
- provider: Union[View, ClassBasedView], name: Optional[str] = None
-) -> _Callable[..., Any]: ...
+ provider: Union[View[T], ClassBasedView[T]], name: Optional[str] = None
+) -> _Callable[..., T]: ...
diff --git a/src/dependency_injector/providers.pxd b/src/dependency_injector/providers.pxd
index b4eb471de..21ed7f229 100644
--- a/src/dependency_injector/providers.pxd
+++ b/src/dependency_injector/providers.pxd
@@ -697,3 +697,10 @@ cdef inline object __future_result(object instance):
future_result = asyncio.Future()
future_result.set_result(instance)
return future_result
+
+
+cdef class NullAwaitable:
+ pass
+
+
+cdef NullAwaitable NULL_AWAITABLE
diff --git a/src/dependency_injector/providers.pyx b/src/dependency_injector/providers.pyx
index d276903b1..43e49d7e3 100644
--- a/src/dependency_injector/providers.pyx
+++ b/src/dependency_injector/providers.pyx
@@ -15,8 +15,11 @@ import re
import sys
import threading
import warnings
+from asyncio import ensure_future
from configparser import ConfigParser as IniConfigParser
+from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
+from inspect import isasyncgenfunction, isgeneratorfunction
try:
from inspect import _is_coroutine_mark as _is_coroutine_marker
@@ -3598,6 +3601,17 @@ cdef class Dict(Provider):
return __provide_keyword_args(kwargs, self._kwargs, self._kwargs_len, self._async_mode)
+@cython.no_gc
+cdef class NullAwaitable:
+ def __next__(self):
+ raise StopIteration from None
+
+ def __await__(self):
+ return self
+
+
+cdef NullAwaitable NULL_AWAITABLE = NullAwaitable()
+
cdef class Resource(Provider):
"""Resource provider provides a component with initialization and shutdown."""
@@ -3653,6 +3667,12 @@ cdef class Resource(Provider):
def set_provides(self, provides):
"""Set provider provides."""
provides = _resolve_string_import(provides)
+
+ if isasyncgenfunction(provides):
+ provides = asynccontextmanager(provides)
+ elif isgeneratorfunction(provides):
+ provides = contextmanager(provides)
+
self._provides = provides
return self
@@ -3753,28 +3773,21 @@ cdef class Resource(Provider):
"""Shutdown resource."""
if not self._initialized:
if self._async_mode == ASYNC_MODE_ENABLED:
- result = asyncio.Future()
- result.set_result(None)
- return result
+ return NULL_AWAITABLE
return
if self._shutdowner:
- try:
- shutdown = self._shutdowner(self._resource)
- except StopIteration:
- pass
- else:
- if inspect.isawaitable(shutdown):
- return self._create_shutdown_future(shutdown)
+ future = self._shutdowner(None, None, None)
+
+ if __is_future_or_coroutine(future):
+ return ensure_future(self._shutdown_async(future))
self._resource = None
self._initialized = False
self._shutdowner = None
if self._async_mode == ASYNC_MODE_ENABLED:
- result = asyncio.Future()
- result.set_result(None)
- return result
+ return NULL_AWAITABLE
@property
def related(self):
@@ -3784,165 +3797,75 @@ cdef class Resource(Provider):
yield from filter(is_provider, self.kwargs.values())
yield from super().related
+ async def _shutdown_async(self, future) -> None:
+ try:
+ await future
+ finally:
+ self._resource = None
+ self._initialized = False
+ self._shutdowner = None
+
+ async def _handle_async_cm(self, obj) -> None:
+ try:
+ self._resource = resource = await obj.__aenter__()
+ self._shutdowner = obj.__aexit__
+ return resource
+ except:
+ self._initialized = False
+ raise
+
+ async def _provide_async(self, future) -> None:
+ try:
+ obj = await future
+
+ if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
+ self._resource = await obj.__aenter__()
+ self._shutdowner = obj.__aexit__
+ elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
+ self._resource = obj.__enter__()
+ self._shutdowner = obj.__exit__
+ else:
+ self._resource = obj
+ self._shutdowner = None
+
+ return self._resource
+ except:
+ self._initialized = False
+ raise
+
cpdef object _provide(self, tuple args, dict kwargs):
if self._initialized:
return self._resource
- if self._is_resource_subclass(self._provides):
- initializer = self._provides()
- self._resource = __call(
- initializer.init,
- args,
- self._args,
- self._args_len,
- kwargs,
- self._kwargs,
- self._kwargs_len,
- self._async_mode,
- )
- self._shutdowner = initializer.shutdown
- elif self._is_async_resource_subclass(self._provides):
- initializer = self._provides()
- async_init = __call(
- initializer.init,
- args,
- self._args,
- self._args_len,
- kwargs,
- self._kwargs,
- self._kwargs_len,
- self._async_mode,
- )
- self._initialized = True
- return self._create_init_future(async_init, initializer.shutdown)
- elif inspect.isgeneratorfunction(self._provides):
- initializer = __call(
- self._provides,
- args,
- self._args,
- self._args_len,
- kwargs,
- self._kwargs,
- self._kwargs_len,
- self._async_mode,
- )
- self._resource = next(initializer)
- self._shutdowner = initializer.send
- elif iscoroutinefunction(self._provides):
- initializer = __call(
- self._provides,
- args,
- self._args,
- self._args_len,
- kwargs,
- self._kwargs,
- self._kwargs_len,
- self._async_mode,
- )
+ obj = __call(
+ self._provides,
+ args,
+ self._args,
+ self._args_len,
+ kwargs,
+ self._kwargs,
+ self._kwargs_len,
+ self._async_mode,
+ )
+
+ if __is_future_or_coroutine(obj):
self._initialized = True
- return self._create_init_future(initializer)
- elif isasyncgenfunction(self._provides):
- initializer = __call(
- self._provides,
- args,
- self._args,
- self._args_len,
- kwargs,
- self._kwargs,
- self._kwargs_len,
- self._async_mode,
- )
+ self._resource = resource = ensure_future(self._provide_async(obj))
+ return resource
+ elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
+ self._resource = obj.__enter__()
+ self._shutdowner = obj.__exit__
+ elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
self._initialized = True
- return self._create_async_gen_init_future(initializer)
- elif callable(self._provides):
- self._resource = __call(
- self._provides,
- args,
- self._args,
- self._args_len,
- kwargs,
- self._kwargs,
- self._kwargs_len,
- self._async_mode,
- )
+ self._resource = resource = ensure_future(self._handle_async_cm(obj))
+ return resource
else:
- raise Error("Unknown type of resource initializer")
+ self._resource = obj
+ self._shutdowner = None
self._initialized = True
return self._resource
- def _create_init_future(self, future, shutdowner=None):
- callback = self._async_init_callback
- if shutdowner:
- callback = functools.partial(callback, shutdowner=shutdowner)
-
- future = asyncio.ensure_future(future)
- future.add_done_callback(callback)
- self._resource = future
-
- return future
-
- def _create_async_gen_init_future(self, initializer):
- if inspect.isasyncgen(initializer):
- return self._create_init_future(initializer.__anext__(), initializer.asend)
-
- future = asyncio.Future()
-
- create_initializer = asyncio.ensure_future(initializer)
- create_initializer.add_done_callback(functools.partial(self._async_create_gen_callback, future))
- self._resource = future
-
- return future
-
- def _async_init_callback(self, initializer, shutdowner=None):
- try:
- resource = initializer.result()
- except Exception:
- self._initialized = False
- else:
- self._resource = resource
- self._shutdowner = shutdowner
-
- def _async_create_gen_callback(self, future, initializer_future):
- initializer = initializer_future.result()
- init_future = self._create_init_future(initializer.__anext__(), initializer.asend)
- init_future.add_done_callback(functools.partial(self._async_trigger_result, future))
-
- def _async_trigger_result(self, future, future_result):
- future.set_result(future_result.result())
-
- def _create_shutdown_future(self, shutdown_future):
- future = asyncio.Future()
- shutdown_future = asyncio.ensure_future(shutdown_future)
- shutdown_future.add_done_callback(functools.partial(self._async_shutdown_callback, future))
- return future
-
- def _async_shutdown_callback(self, future_result, shutdowner):
- try:
- shutdowner.result()
- except StopAsyncIteration:
- pass
-
- self._resource = None
- self._initialized = False
- self._shutdowner = None
-
- future_result.set_result(None)
-
- @staticmethod
- def _is_resource_subclass(instance):
- if not isinstance(instance, type):
- return
- from . import resources
- return issubclass(instance, resources.Resource)
-
- @staticmethod
- def _is_async_resource_subclass(instance):
- if not isinstance(instance, type):
- return
- from . import resources
- return issubclass(instance, resources.AsyncResource)
-
cdef class Container(Provider):
"""Container provider provides an instance of declarative container.
@@ -4993,14 +4916,6 @@ def iscoroutinefunction(obj):
return False
-def isasyncgenfunction(obj):
- """Check if object is an asynchronous generator function."""
- try:
- return inspect.isasyncgenfunction(obj)
- except AttributeError:
- return False
-
-
def _resolve_string_import(provides):
if provides is None:
return provides
diff --git a/src/dependency_injector/resources.py b/src/dependency_injector/resources.py
index 7d71d4d82..8722af229 100644
--- a/src/dependency_injector/resources.py
+++ b/src/dependency_injector/resources.py
@@ -1,23 +1,54 @@
"""Resources module."""
-import abc
-from typing import TypeVar, Generic, Optional
-
+from abc import ABCMeta, abstractmethod
+from typing import Any, ClassVar, Generic, Optional, Tuple, TypeVar
T = TypeVar("T")
-class Resource(Generic[T], metaclass=abc.ABCMeta):
+class Resource(Generic[T], metaclass=ABCMeta):
+ __slots__: ClassVar[Tuple[str, ...]] = ("args", "kwargs", "obj")
+
+ obj: Optional[T]
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ self.args = args
+ self.kwargs = kwargs
+ self.obj = None
- @abc.abstractmethod
- def init(self, *args, **kwargs) -> Optional[T]: ...
+ @abstractmethod
+ def init(self, *args: Any, **kwargs: Any) -> Optional[T]: ...
def shutdown(self, resource: Optional[T]) -> None: ...
+ def __enter__(self) -> Optional[T]:
+ self.obj = obj = self.init(*self.args, **self.kwargs)
+ return obj
+
+ def __exit__(self, *exc_info: Any) -> None:
+ self.shutdown(self.obj)
+ self.obj = None
+
-class AsyncResource(Generic[T], metaclass=abc.ABCMeta):
+class AsyncResource(Generic[T], metaclass=ABCMeta):
+ __slots__: ClassVar[Tuple[str, ...]] = ("args", "kwargs", "obj")
- @abc.abstractmethod
- async def init(self, *args, **kwargs) -> Optional[T]: ...
+ obj: Optional[T]
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ self.args = args
+ self.kwargs = kwargs
+ self.obj = None
+
+ @abstractmethod
+ async def init(self, *args: Any, **kwargs: Any) -> Optional[T]: ...
async def shutdown(self, resource: Optional[T]) -> None: ...
+
+ async def __aenter__(self) -> Optional[T]:
+ self.obj = obj = await self.init(*self.args, **self.kwargs)
+ return obj
+
+ async def __aexit__(self, *exc_info: Any) -> None:
+ await self.shutdown(self.obj)
+ self.obj = None
diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py
index b8534ee52..0477eed48 100644
--- a/src/dependency_injector/wiring.py
+++ b/src/dependency_injector/wiring.py
@@ -10,6 +10,7 @@
from typing import (
TYPE_CHECKING,
Any,
+ AsyncIterator,
Callable,
Dict,
Iterable,
@@ -24,7 +25,17 @@
cast,
)
-from typing_extensions import Self
+try:
+ from typing import Self
+except ImportError:
+ from typing_extensions import Self
+
+try:
+ from functools import cache
+except ImportError:
+ from functools import lru_cache
+
+ cache = lru_cache(maxsize=None)
# Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362
if sys.version_info >= (3, 9):
@@ -409,6 +420,7 @@ def wire( # noqa: C901
*,
modules: Optional[Iterable[ModuleType]] = None,
packages: Optional[Iterable[ModuleType]] = None,
+ keep_cache: bool = False,
) -> None:
"""Wire container providers with provided packages and modules."""
modules = [*modules] if modules else []
@@ -449,6 +461,9 @@ def wire( # noqa: C901
for patched in _patched_registry.get_callables_from_module(module):
_bind_injections(patched, providers_map)
+ if not keep_cache:
+ clear_cache()
+
def unwire( # noqa: C901
*,
@@ -604,6 +619,7 @@ def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]:
return marker
+@cache
def _fetch_reference_injections( # noqa: C901
fn: Callable[..., Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
@@ -708,6 +724,8 @@ def _get_patched(
if inspect.iscoroutinefunction(fn):
patched = _get_async_patched(fn, patched_object)
+ elif inspect.isasyncgenfunction(fn):
+ patched = _get_async_gen_patched(fn, patched_object)
else:
patched = _get_sync_patched(fn, patched_object)
@@ -1023,36 +1041,41 @@ def is_loader_installed() -> bool:
_loader = AutoLoader()
# Optimizations
-from ._cwiring import _async_inject # noqa
-from ._cwiring import _sync_inject # noqa
+from ._cwiring import DependencyResolver # noqa: E402
# Wiring uses the following Python wrapper because there is
# no possibility to compile a first-type citizen coroutine in Cython.
def _get_async_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn)
- async def _patched(*args, **kwargs):
- return await _async_inject(
- fn,
- args,
- kwargs,
- patched.injections,
- patched.closing,
- )
+ async def _patched(*args: Any, **raw_kwargs: Any) -> Any:
+ resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing)
+
+ async with resolver as kwargs:
+ return await fn(*args, **kwargs)
+
+ return cast(F, _patched)
+
+
+def _get_async_gen_patched(fn: F, patched: PatchedCallable) -> F:
+ @functools.wraps(fn)
+ async def _patched(*args: Any, **raw_kwargs: Any) -> AsyncIterator[Any]:
+ resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing)
+
+ async with resolver as kwargs:
+ async for obj in fn(*args, **kwargs):
+ yield obj
return cast(F, _patched)
def _get_sync_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn)
- def _patched(*args, **kwargs):
- return _sync_inject(
- fn,
- args,
- kwargs,
- patched.injections,
- patched.closing,
- )
+ def _patched(*args: Any, **raw_kwargs: Any) -> Any:
+ resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing)
+
+ with resolver as kwargs:
+ return fn(*args, **kwargs)
return cast(F, _patched)
@@ -1078,3 +1101,8 @@ def _get_members_and_annotated(obj: Any) -> Iterable[Tuple[str, Any]]:
member = args[1]
members.append((annotation_name, member))
return members
+
+
+def clear_cache() -> None:
+ """Clear all caches used by :func:`wire`."""
+ _fetch_reference_injections.cache_clear()
diff --git a/tests/unit/providers/resource/test_async_resource_py35.py b/tests/unit/providers/resource/test_async_resource_py35.py
index 1ca950a88..6458584dd 100644
--- a/tests/unit/providers/resource/test_async_resource_py35.py
+++ b/tests/unit/providers/resource/test_async_resource_py35.py
@@ -2,12 +2,13 @@
import asyncio
import inspect
-import sys
+from contextlib import asynccontextmanager
from typing import Any
-from dependency_injector import containers, providers, resources
from pytest import mark, raises
+from dependency_injector import containers, providers, resources
+
@mark.asyncio
async def test_init_async_function():
@@ -70,6 +71,46 @@ async def _init():
assert _init.shutdown_counter == 2
+@mark.asyncio
+async def test_init_async_context_manager() -> None:
+ resource = object()
+
+ init_counter = 0
+ shutdown_counter = 0
+
+ @asynccontextmanager
+ async def _init():
+ nonlocal init_counter, shutdown_counter
+
+ await asyncio.sleep(0.001)
+ init_counter += 1
+
+ yield resource
+
+ await asyncio.sleep(0.001)
+ shutdown_counter += 1
+
+ provider = providers.Resource(_init)
+
+ result1 = await provider()
+ assert result1 is resource
+ assert init_counter == 1
+ assert shutdown_counter == 0
+
+ await provider.shutdown()
+ assert init_counter == 1
+ assert shutdown_counter == 1
+
+ result2 = await provider()
+ assert result2 is resource
+ assert init_counter == 2
+ assert shutdown_counter == 1
+
+ await provider.shutdown()
+ assert init_counter == 2
+ assert shutdown_counter == 2
+
+
@mark.asyncio
async def test_init_async_class():
resource = object()
diff --git a/tests/unit/providers/resource/test_resource_py35.py b/tests/unit/providers/resource/test_resource_py35.py
index 9b906bd75..842d8ba6f 100644
--- a/tests/unit/providers/resource/test_resource_py35.py
+++ b/tests/unit/providers/resource/test_resource_py35.py
@@ -2,10 +2,12 @@
import decimal
import sys
+from contextlib import contextmanager
from typing import Any
-from dependency_injector import containers, providers, resources, errors
-from pytest import raises, mark
+from pytest import mark, raises
+
+from dependency_injector import containers, errors, providers, resources
def init_fn(*args, **kwargs):
@@ -123,6 +125,41 @@ def _init():
assert _init.shutdown_counter == 2
+def test_init_context_manager() -> None:
+ init_counter, shutdown_counter = 0, 0
+
+ @contextmanager
+ def _init():
+ nonlocal init_counter, shutdown_counter
+
+ init_counter += 1
+ yield
+ shutdown_counter += 1
+
+ init_counter = 0
+ shutdown_counter = 0
+
+ provider = providers.Resource(_init)
+
+ result1 = provider()
+ assert result1 is None
+ assert init_counter == 1
+ assert shutdown_counter == 0
+
+ provider.shutdown()
+ assert init_counter == 1
+ assert shutdown_counter == 1
+
+ result2 = provider()
+ assert result2 is None
+ assert init_counter == 2
+ assert shutdown_counter == 1
+
+ provider.shutdown()
+ assert init_counter == 2
+ assert shutdown_counter == 2
+
+
def test_init_class():
class TestResource(resources.Resource):
init_counter = 0
@@ -190,7 +227,7 @@ def init(self):
def test_init_not_callable():
provider = providers.Resource(1)
- with raises(errors.Error):
+ with raises(TypeError, match=r"object is not callable"):
provider.init()
diff --git a/tests/unit/samples/wiring/asyncinjections.py b/tests/unit/samples/wiring/asyncinjections.py
index 204300e3a..e08610179 100644
--- a/tests/unit/samples/wiring/asyncinjections.py
+++ b/tests/unit/samples/wiring/asyncinjections.py
@@ -1,7 +1,9 @@
import asyncio
+from typing_extensions import Annotated
+
from dependency_injector import containers, providers
-from dependency_injector.wiring import inject, Provide, Closing
+from dependency_injector.wiring import Closing, Provide, inject
class TestResource:
@@ -42,6 +44,15 @@ async def async_injection(
return resource1, resource2
+@inject
+async def async_generator_injection(
+ resource1: object = Provide[Container.resource1],
+ resource2: object = Closing[Provide[Container.resource2]],
+):
+ yield resource1
+ yield resource2
+
+
@inject
async def async_injection_with_closing(
resource1: object = Closing[Provide[Container.resource1]],
diff --git a/tests/unit/wiring/provider_ids/test_async_injections_py36.py b/tests/unit/wiring/provider_ids/test_async_injections_py36.py
index f17f19c77..70f9eb171 100644
--- a/tests/unit/wiring/provider_ids/test_async_injections_py36.py
+++ b/tests/unit/wiring/provider_ids/test_async_injections_py36.py
@@ -32,6 +32,23 @@ async def test_async_injections():
assert asyncinjections.resource2.shutdown_counter == 0
+@mark.asyncio
+async def test_async_generator_injections() -> None:
+ resources = []
+
+ async for resource in asyncinjections.async_generator_injection():
+ resources.append(resource)
+
+ assert len(resources) == 2
+ assert resources[0] is asyncinjections.resource1
+ assert asyncinjections.resource1.init_counter == 1
+ assert asyncinjections.resource1.shutdown_counter == 0
+
+ assert resources[1] is asyncinjections.resource2
+ assert asyncinjections.resource2.init_counter == 1
+ assert asyncinjections.resource2.shutdown_counter == 1
+
+
@mark.asyncio
async def test_async_injections_with_closing():
resource1, resource2 = await asyncinjections.async_injection_with_closing()
diff --git a/tests/unit/wiring/test_cache.py b/tests/unit/wiring/test_cache.py
new file mode 100644
index 000000000..d6c1f45fd
--- /dev/null
+++ b/tests/unit/wiring/test_cache.py
@@ -0,0 +1,46 @@
+"""Tests for string module and package names."""
+
+from typing import Iterator, Optional
+
+from pytest import fixture, mark
+from samples.wiring.container import Container
+
+from dependency_injector.wiring import _fetch_reference_injections
+
+
+@fixture
+def container() -> Iterator[Container]:
+ container = Container()
+ yield container
+ container.unwire()
+
+
+@mark.parametrize(
+ ["arg_value", "wc_value", "empty_cache"],
+ [
+ (None, False, True),
+ (False, True, True),
+ (True, False, False),
+ (None, True, False),
+ ],
+)
+def test_fetch_reference_injections_cache(
+ container: Container,
+ arg_value: Optional[bool],
+ wc_value: bool,
+ empty_cache: bool,
+) -> None:
+ container.wiring_config.keep_cache = wc_value
+ container.wire(
+ modules=["samples.wiring.module"],
+ packages=["samples.wiring.package"],
+ keep_cache=arg_value,
+ )
+ cache_info = _fetch_reference_injections.cache_info()
+
+ if empty_cache:
+ assert cache_info == (0, 0, None, 0)
+ else:
+ assert cache_info.hits > 0
+ assert cache_info.misses > 0
+ assert cache_info.currsize > 0