
Optional dependencies #336

@shaunc


I am trying to create a container that optionally takes a dependency and otherwise provides a value derived from another provider. The (IMO) hacky solution I have so far is a custom provider that either returns the result of calling a Callable, or a default value if that call fails. I then use it with the dependency provider itself as the Callable.
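To make the intended semantics concrete, here is a library-free sketch in plain Python. `OptionalDependency`, `Lazy`, `UndefinedDependencyError` and `with_fallback` are hypothetical names, not dependency_injector API. It assumes two rules: the fallback triggers only when the dependency is undefined (other errors propagate), and the default is called only when it is explicitly marked as lazy, which is the role the provider check plays in `DefaultCallable` below.

```python
from typing import Any, Callable, Optional


class UndefinedDependencyError(Exception):
    """Hypothetical stand-in for the error an undefined dependency raises."""


class OptionalDependency:
    """Hypothetical, library-free model of an optional dependency slot."""

    def __init__(self) -> None:
        self._override: Optional[Callable[[], Any]] = None

    def override(self, factory: Callable[[], Any]) -> None:
        # Supplying a factory "defines" the dependency.
        self._override = factory

    def __call__(self) -> Any:
        if self._override is None:
            raise UndefinedDependencyError("dependency is not defined")
        return self._override()


class Lazy:
    """Hypothetical marker for a default that must be called to get a value."""

    def __init__(self, factory: Callable[[], Any]) -> None:
        self.factory = factory

    def __call__(self) -> Any:
        return self.factory()


def with_fallback(dependency: Callable[[], Any], default: Any) -> Any:
    """Return dependency() if it is defined, otherwise the resolved default."""
    try:
        return dependency()
    except UndefinedDependencyError:
        # Only the "not defined" case falls back; other errors propagate.
        # A Lazy default is resolved by calling it; plain values (even
        # callables such as functions) are returned unchanged.
        if isinstance(default, Lazy):
            return default()
        return default


dv_in = OptionalDependency()
derived_dv = Lazy(lambda: "derived dv")

print(with_fallback(dv_in, derived_dv))  # derived dv
dv_in.override(lambda: "explicit dv")
print(with_fallback(dv_in, derived_dv))  # explicit dv
```

Using an explicit marker rather than `callable(default)` keeps an ordinary callable value (for example a function passed as the default) from being invoked by accident.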

My questions are: (1) is there a better way? and (2) even with this approach, `DefaultCallable` (defined below) seems like a hack. How can I improve it?

```python
from __future__ import annotations

from typing import Any, Callable, TypeVar, cast

import xarray as xr
from dependency_injector import containers, errors, providers
from dependency_injector.providers import Provider

T = TypeVar("T")


class DefaultCallable(providers.Provider):
    """Call a provider; fall back to a default if the dependency is undefined."""

    __slots__ = ("_callable", "_default")

    def __init__(
        self, callable: Callable[..., T], default: T, *args, **kwargs
    ):
        self._default = default
        self._callable = providers.Callable(callable, *args, **kwargs)

        super().__init__()

    def __deepcopy__(self, memo):
        copied = memo.get(id(self))
        if copied is not None:
            return copied

        # TODO: type?
        copied = self.__class__(
            # Copy the wrapped provider through the same memo, so the copy
            # calls the copied container's dependency, not the original one.
            providers.deepcopy(self._callable.provides, memo),
            providers.deepcopy(self._default, memo),
            *providers.deepcopy(self._callable.args, memo),
            **providers.deepcopy(self._callable.kwargs, memo),
        )
        self._copy_overridings(copied, memo)
        return copied

    def _provide(self, args, kwargs):
        try:
            return self._callable(*args, **kwargs)
        except errors.Error:
            # An undefined Dependency raises errors.Error; unlike a bare
            # `except Exception`, unrelated errors from the callable propagate.
            # The default is stored as a plain attribute, so it is not
            # injected automatically: if it is a provider, call it.
            if providers.is_provider(self._default):
                return cast(Any, self._default)()
            return self._default


# Used like this (`problem` is defined elsewhere and not shown):

class Foo(containers.DeclarativeContainer):

    #: specify dv for pattern discovery (optional)
    dv_in: Provider[xr.DataArray] = providers.Dependency(
        instance_of=xr.DataArray
    )

    #: dv for pattern discovery (specified or default)
    dv: Provider[xr.DataArray] = DefaultCallable(
        # cast(Callable[..., xr.DataArray], dv_in), type??
        cast(Any, dv_in),
        problem.training.provided["dv"],
    )
```
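As far as I understand, instantiating a `DeclarativeContainer` deep-copies its class-level providers, so `DefaultCallable.__deepcopy__` should copy the provider it wraps through the same `memo`; otherwise the copied `dv` keeps calling the class-level `dv_in` instead of the instance's. The sketch below shows that memo mechanism with plain `copy.deepcopy`; `WrappingProvider` is a hypothetical stand-in, not part of dependency_injector.

```python
import copy
from typing import Any, Dict, Optional


class WrappingProvider:
    """Hypothetical stand-in for a provider that wraps another provider."""

    def __init__(self, name: str, target: Optional["WrappingProvider"] = None) -> None:
        self.name = name
        self.target = target

    def __deepcopy__(self, memo: Dict[int, Any]) -> "WrappingProvider":
        copied = memo.get(id(self))
        if copied is not None:
            return copied
        copied = WrappingProvider(self.name)
        # Register the copy first so shared references resolve to it.
        memo[id(self)] = copied
        # Copy the wrapped provider through the same memo. Keeping
        # self.target as-is would leave the copy pointing at the original.
        copied.target = copy.deepcopy(self.target, memo)
        return copied


# A "container" modeled as a dict: dv wraps dv_in, as DefaultCallable wraps dv_in.
dv_in = WrappingProvider("dv_in")
dv = WrappingProvider("dv", target=dv_in)
original = {"dv_in": dv_in, "dv": dv}

instance = copy.deepcopy(original)
print(instance["dv"].target is instance["dv_in"])  # True
print(instance["dv"].target is dv_in)              # False
```

Because both entries are copied with one shared `memo`, the wrapped reference in the copy resolves to the copied `dv_in`, which is what lets an override on the container instance reach `dv`.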
