How to use functools.wraps with method decorators #2127

@sfc-gh-bchinn

Say I have a noop decorator:

import functools
from typing import Callable

def trace[**P, T](func: Callable[P, T]) -> Callable[P, T]:
    @functools.wraps(func)
    def func_with_log(*args: P.args, **kwargs: P.kwargs) -> T:
        return func(*args, **kwargs)
    return func_with_log

class Foo:
    @trace
    def foo(self, a: int) -> str:
        return "foo"

Foo().foo(1)

Great. Now let's replace Callable with a Protocol that defines __call__, say because we want to access func.__name__ or inspect specific args/kwargs being passed in:

import functools
from typing import Protocol

class MyCallable[**P, T](Protocol):
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

def trace[**P, T](func: MyCallable[P, T]) -> MyCallable[P, T]:
    @functools.wraps(func)
    def func_with_log(*args: P.args, **kwargs: P.kwargs) -> T:
        return func(*args, **kwargs)
    return func_with_log

class Foo:
    @trace
    def foo(self, a: int) -> str:
        return "foo"

Foo().foo(1)

This fails in mypy with:

error: Missing positional argument "a" in call to "__call__" of "MyCallable"  [call-arg]
error: Argument 1 to "__call__" of "MyCallable" has incompatible type "int"; expected "Foo"  [arg-type]

Per #1040, we should add a __get__ method that returns a Protocol with the post-bound signature (i.e. with self already bound):

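import functools
from typing import Any, Protocol

# Defined first so MyCallable.__get__'s return annotation resolves when
# annotations are evaluated eagerly (before Python 3.14).
class MyCallableBound[**P, T](Protocol):
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

class MyCallable[**P, T](Protocol):
    # Unbound form: the instance is passed explicitly as `self`.
    def __call__(self_, self: Any, *args: P.args, **kwargs: P.kwargs) -> T: ...
    # Attribute access through an instance binds `self`, leaving the bound signature.
    def __get__(self_, *args: Any, **kwargs: Any) -> MyCallableBound[P, T]: ...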

def trace[**P, T](func: MyCallable[P, T]) -> MyCallable[P, T]:
    @functools.wraps(func)
    def func_with_log(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
        return func(self, *args, **kwargs)
    return func_with_log

class Foo:
    @trace
    def foo(self, a: int) -> str:
        return "foo"

Foo().foo(1)

Now this fails with:

error: Incompatible return value type (got "_Wrapped[[Any, **P], T, [Any, **P], T]", expected "MyCallable[P, T]")  [return-value]
note: "_Wrapped" is missing following "MyCallable" protocol member:
note:     __get__

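It type-checks if I comment out @functools.wraps(func). For now, we can work around it with a hand-rolled wraps that copies the same metadata but is typed to return the wrapper's own type: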

import contextlib
import functools
from typing import Callable

def wraps[T](func: T) -> Callable[[T], T]:
    # Typed as T -> T, so the decorated wrapper keeps its declared type
    # instead of becoming typeshed's _Wrapped.
    def decorator(new_func: T) -> T:
        # Reuse the stdlib's attribute lists so the same metadata is copied as
        # functools.wraps would on the running Python version; skip any func lacks.
        for attr in functools.WRAPPER_ASSIGNMENTS:
            with contextlib.suppress(AttributeError):
                setattr(new_func, attr, getattr(func, attr))
        # Merge func's __dict__ into the wrapper's __dict__.
        for attr in functools.WRAPPER_UPDATES:
            getattr(new_func, attr).update(getattr(func, attr, {}))
        # Expose the original, as functools.wraps does (inspect.unwrap follows it).
        setattr(new_func, "__wrapped__", func)
        return new_func

    return decorator
