Annotate a Python function so that it narrows down the return type of a union involving a TypeVar #1639
Hello,
I'm currently in the process of adding type annotations for a medium-sized library. The library has some quite complex cases that I would like to handle. I managed to simplify these complex cases to the following, much easier, one.
Consider the following pseudo-identity function:
```python
def g(x, y):
    return x, y
```
Let's assume that this function accepts either two arguments of the same type (in practice, any pair of objects of the same type implementing a Comparable protocol I defined, but let's consider subclasses of int for this example) OR a special string literal in place of either argument (for this example, let's simply consider a str). To summarize, these calls are valid: g(1, 2), g('hi', 'world'), g(1, 'world'), etc., but these calls aren't: g({}, 3), g([], True), etc.
I'm looking for a way to annotate this function so that mypy can deduce the return type precisely. For instance, I expect g(1, 2) to return a tuple of ints, g('hi', 'world') a tuple of str, g(1, 'world') a tuple[int, str], and the invalid calls to be reported as errors.
Naively, I wrote:
```python
from typing import TypeVar

T = TypeVar('T', bound=int)

def g(x: T | str, y: T | str) -> tuple[T | str, T | str]:
    return x, y
```
AFAIK, this should be enough to ensure that both parameters have the same type (and are instances of int) OR that one (or both) of them is/are str.
However, when I run mypy on this snippet with the following calls:
```python
reveal_type(g(3, 4))
reveal_type(g('hello', 'world'))
reveal_type(g('hello', 2))
```
I got this output:
```
main.py:8: note: Revealed type is "tuple[Union[builtins.int, builtins.str], Union[builtins.int, builtins.str]]"
main.py:9: note: Revealed type is "tuple[builtins.str, builtins.str]"
main.py:10: note: Revealed type is "tuple[Union[builtins.int, builtins.str], Union[builtins.int, builtins.str]]"
Success: no issues found in 1 source file
```
You can see this live in mypy Playground: https://mypy-play.net/?mypy=latest&python=3.12&gist=a5063b88271ddd94a58c82e3376deb5c
The second line is what I expected (a tuple of strings), but the other two calls should (or could) be refined to tuple[int, int] and tuple[str, int]. The result is the same with Pyright, and mostly the same with Pyre (the latter reports tuple[str, Literal[2]] instead of tuple[str, int] for the last call, for some unknown reason).
Is this out of scope for the typing module, or for mypy? Am I missing something?
Notice that I cannot write T = TypeVar('T', int, str). While this works for the above example, remember that I used int to simplify the case: in practice I want T to be any class implementing my Comparable protocol, so I need to define T with an upper bound rather than with a list of exact types, that is, T = TypeVar('T', bound=Comparable).
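For readers unfamiliar with the distinction, here is a minimal sketch of a TypeVar bounded by a protocol. The `Comparable` protocol below is a hypothetical stand-in for the one mentioned above (only the ordering methods are declared, with positional-only parameters so that builtins like `int` match structurally), and `smallest` is an illustrative helper, not part of the library.

```python
from datetime import date
from typing import Any, Protocol, TypeVar


class Comparable(Protocol):
    """Hypothetical stand-in for the question's Comparable protocol."""

    def __lt__(self, other: Any, /) -> bool: ...
    def __le__(self, other: Any, /) -> bool: ...
    def __gt__(self, other: Any, /) -> bool: ...
    def __ge__(self, other: Any, /) -> bool: ...


# Upper bound: C may be any type that structurally satisfies Comparable
# (int, float, str, date, ...). A value-restricted TypeVar('C', int, str)
# could only ever be solved to exactly int or exactly str.
C = TypeVar("C", bound=Comparable)


def smallest(x: C, y: C) -> C:
    """Hypothetical helper: return the smaller of two values of the same type."""
    return x if x <= y else y


print(smallest(3, 1))
print(smallest(date(2024, 1, 2), date(2024, 1, 1)))
```

The bound only limits what `C` may be; the checker still solves it to the concrete argument type, so `smallest(3, 1)` is typed as `int` rather than as `Comparable`.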
To give some more context for this question (skip this if you don't care ;-), the library I want to annotate defines an Interval class made of two bounds (the lower and upper ones). These bounds can be any objects that support comparison (e.g., ints, floats, dates). To handle infinite and semi-infinite intervals, the library defines two special objects, namely inf and -inf, which are respectively singleton instances of _PInf and _NInf. My goal is to annotate this Interval class so that mypy can check whether Interval(x, y) is valid (i.e., x and y are of the same Comparable type, or x, y, or both are instances of _PInf or _NInf).
Generalizing the above example, I currently have something like this:
```python
Comparable = ...  # Protocol with __eq__, __lt__, __le__, ...
T = TypeVar('T', bound=Comparable)
Bound: TypeAlias = T | _PInf | _NInf

class Interval(Generic[T]):
    def __init__(self, lower: Bound[T], upper: Bound[T]) -> None: ...
```
The goal is (1) to make sure that both bounds are "compatible" (as explained in the previous paragraph), and (2) that, for example, Interval(1, 2).lower + 1 does not trigger any complaint from mypy, while Interval(-inf, 2).lower + 1 does.
For polymorphic functions like this, typing.overload is usually the best approach.
Code sample in pyright playground:
```python
from typing import TypeVar, overload

T = TypeVar("T", bound=int)

@overload
def g(x: T, y: T) -> tuple[T, T]: ...
@overload
def g(x: T, y: str) -> tuple[T, str]: ...
@overload
def g(x: str, y: T) -> tuple[str, T]: ...
@overload
def g(x: str, y: str) -> tuple[str, str]: ...
def g(x: T | str, y: T | str) -> tuple[T | str, T | str]:
    return x, y
```
Thanks for your answer!
I'll try this approach in the context of my generic Interval class and provide some feedback ;)
I couldn't make it work in my context, but for unrelated reasons that I'll try to address soon :-)
Thanks again for your answer! :)