Closed
@sfc-gh-bchinn
Description
Say I have a noop decorator:
```python
import functools
from typing import Callable


def trace[**P, T](func: Callable[P, T]) -> Callable[P, T]:
    @functools.wraps(func)
    def func_with_log(*args: P.args, **kwargs: P.kwargs) -> T:
        return func(*args, **kwargs)
    return func_with_log


class Foo:
    @trace
    def foo(self, a: int) -> str:
        return "foo"


Foo().foo(1)
```
Great. Now let's replace `Callable` with a Protocol that defines `__call__`, for example if we want to access `func.__name__` or the specific args/kwargs being passed in:
```python
import functools
from typing import Protocol


class MyCallable[**P, T](Protocol):
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...


def trace[**P, T](func: MyCallable[P, T]) -> MyCallable[P, T]:
    @functools.wraps(func)
    def func_with_log(*args: P.args, **kwargs: P.kwargs) -> T:
        return func(*args, **kwargs)
    return func_with_log


class Foo:
    @trace
    def foo(self, a: int) -> str:
        return "foo"


Foo().foo(1)
```
This fails on mypy with:

```
error: Missing positional argument "a" in call to "__call__" of "MyCallable"  [call-arg]
error: Argument 1 to "__call__" of "MyCallable" has incompatible type "int"; expected "Foo"  [arg-type]
```
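For context (my reading of the errors, not part of the original report): at runtime, `Foo().foo` goes through the plain function's `__get__`, which binds the instance as the first argument. A Protocol that declares only `__call__` tells mypy nothing about that binding step, so, as I understand it, mypy checks `Foo().foo(1)` against `__call__(self: Foo, a: int)` with `1` landing in the `self` slot, which matches both errors. The untyped sketch below makes the runtime binding explicit; `instance`, `raw`, and `bound` are illustrative names only.

```python
import functools
import types


def trace(func):
    # Same no-op decorator as above, without annotations.
    @functools.wraps(func)
    def func_with_log(*args, **kwargs):
        return func(*args, **kwargs)
    return func_with_log


class Foo:
    @trace
    def foo(self, a: int) -> str:
        return "foo"


instance = Foo()
raw = Foo.__dict__["foo"]            # the plain function stored on the class
bound = raw.__get__(instance, Foo)   # what `instance.foo` does under the hood

assert isinstance(bound, types.MethodType)
assert bound.__self__ is instance    # `self` is pre-filled...
assert bound(1) == "foo"             # ...so only `a` is passed explicitly
```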
Per #1040, we should add a `__get__` that returns a Protocol with the post-binding signature:
```python
import functools
from typing import Any, Protocol


class MyCallableBound[**P, T](Protocol):
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...


class MyCallable[**P, T](Protocol):
    def __call__(self_, self: Any, *args: P.args, **kwargs: P.kwargs) -> T: ...
    def __get__(self_, *args: Any, **kwargs: Any) -> MyCallableBound[P, T]: ...


def trace[**P, T](func: MyCallable[P, T]) -> MyCallable[P, T]:
    @functools.wraps(func)
    def func_with_log(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
        return func(self, *args, **kwargs)
    return func_with_log


class Foo:
    @trace
    def foo(self, a: int) -> str:
        return "foo"


Foo().foo(1)
```
Now this fails with:

```
error: Incompatible return value type (got "_Wrapped[[Any, **P], T, [Any, **P], T]", expected "MyCallable[P, T]")  [return-value]
note: "_Wrapped" is missing following "MyCallable" protocol member:
note:     __get__
```
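The mismatch appears to be purely static: mypy's bundled stubs type the result of `functools.wraps(...)(...)` as `_Wrapped`, which (per the note above) declares no `__get__`, whereas at runtime `functools.update_wrapper` returns the very wrapper function it was given, and plain functions are descriptors. A quick runtime check, with hypothetical functions `inner` and `wrapper`:

```python
import functools


def inner(self, a: int) -> str:
    return "inner"


def wrapper(*args, **kwargs):
    return inner(*args, **kwargs)


result = functools.wraps(inner)(wrapper)

assert result is wrapper                 # same function object; no separate wrapper type
assert hasattr(type(result), "__get__")  # plain functions are descriptors
assert result.__wrapped__ is inner       # metadata copied onto the wrapper
assert result.__name__ == "inner"
```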
It works if I comment out `@functools.wraps(func)`. For now, we can work around it with a custom `wraps` that keeps the wrapped function's type:
```python
import contextlib
from typing import Callable

# Mirrors the attribute lists used by functools.wraps.
WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__', '__annotate__', '__type_params__')
WRAPPER_UPDATES = ('__dict__',)


def wraps[T](func: T) -> Callable[[T], T]:
    # Unlike functools.wraps, the decorated function keeps the wrapped function's type T.
    def decorator(new_func: T) -> T:
        for attr in WRAPPER_ASSIGNMENTS:
            # Skip attributes the wrapped object doesn't have.
            with contextlib.suppress(AttributeError):
                setattr(new_func, attr, getattr(func, attr))
        for attr in WRAPPER_UPDATES:
            getattr(new_func, attr).update(getattr(func, attr, {}))
        setattr(new_func, "__wrapped__", func)
        return new_func
    return decorator
```