From b48832f97562e6c7342a167bb10ac592e32eca54 Mon Sep 17 00:00:00 2001 From: Perchun Pak Date: Wed, 10 Jun 2026 21:09:18 +0200 Subject: [PATCH] Support all context managers --- src/inject/__init__.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/inject/__init__.py b/src/inject/__init__.py index 851f310..8f3628c 100644 --- a/src/inject/__init__.py +++ b/src/inject/__init__.py @@ -112,7 +112,14 @@ def my_config(binder): Injectable = t.Union[object, t.Any] T = t.TypeVar("T", bound=Injectable) Binding = t.Union[type[Injectable], t.Hashable] -Constructor = t.Callable[[], Injectable] +Constructor = t.Callable[ + [], + t.Union[ + Injectable, + contextlib.AbstractContextManager[Injectable], + contextlib.AbstractAsyncContextManager[Injectable], + ], +] Provider = Constructor BinderCallable = t.Callable[["Binder"], t.Optional["Binder"]] @@ -411,7 +418,7 @@ def _aggregate_sync_stack( param: sync_stack.enter_context(inst) for param, inst in kwargs.items() if param not in provided_params - and isinstance(inst, contextlib._GeneratorContextManager) # noqa: SLF001 + and isinstance(inst, contextlib.AbstractContextManager) } kwargs.update(executed_kwargs) @@ -426,7 +433,7 @@ async def _aggregate_async_stack( param: await async_stack.enter_async_context(inst) for param, inst in kwargs.items() if param not in provided_params - and isinstance(inst, contextlib._AsyncGeneratorContextManager) # noqa: SLF001 + and isinstance(inst, contextlib.AbstractAsyncContextManager) } kwargs.update(executed_kwargs)