Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit e1e7644

Browse files
committed
✨HasDType
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent eaa42ce commit e1e7644

File tree

4 files changed

+66
-7
lines changed

4 files changed

+66
-7
lines changed

‎src/array_api_typing/__init__.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
__all__ = (
44
"Array",
55
"HasArrayNamespace",
6+
"HasDType",
67
"__version__",
78
"__version_tuple__",
89
)
910

10-
from ._array import Array, HasArrayNamespace
11+
from ._array import Array, HasArrayNamespace, HasDType
1112
from ._version import version as __version__, version_tuple as __version_tuple__

‎src/array_api_typing/_array.py‎

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing_extensions import TypeVar
99

1010
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
11+
DTypeT_co = TypeVar("DTypeT_co", covariant=True)
1112

1213

1314
class HasArrayNamespace(Protocol[NamespaceT_co]):
@@ -57,8 +58,32 @@ def __array_namespace__(
5758
...
5859

5960

61+
class HasDType(Protocol[DTypeT_co]):
62+
"""Protocol for array classes that have a data type attribute."""
63+
64+
@property
65+
def dtype(self, /) -> DTypeT_co:
66+
"""Data type of the array elements."""
67+
...
68+
69+
6070
class Array(
61-
HasArrayNamespace[NamespaceT_co],
62-
Protocol[NamespaceT_co],
71+
# ------ Attributes -------
72+
HasDType[DTypeT_co],
73+
# -------------------------
74+
Protocol[DTypeT_co, NamespaceT_co],
6375
):
64-
"""Array API specification for array object attributes and methods."""
76+
"""Array API specification for array object attributes and methods.
77+
78+
The type is: ``Array[+DTypeT, +NamespaceT = ModuleType] = Array[DTypeT,
79+
NamespaceT]`` where:
80+
81+
- `DTypeT` is the data type of the array elements.
82+
- `NamespaceT` is the type of the array namespace. It defaults to
83+
`ModuleType`, which is the most common form of array namespace (e.g.,
84+
`numpy`, `cupy`, etc.). However, it can be any type, e.g. a
85+
`types.SimpleNamespace`, to allow for wrapper libraries to
86+
semi-dynamically define their own array namespaces based on the wrapped
87+
array type.
88+
89+
"""

‎tests/integration/test_numpy1p0.pyi‎

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# mypy: disable-error-code="no-redef"
22

33
from types import ModuleType
4-
from typing import TypeAlias
4+
from typing import Any
55

66
import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
7+
from numpy import dtype
78

89
import array_api_typing as xpt
910

@@ -28,8 +29,25 @@ ns: ModuleType = a_ns.__array_namespace__()
2829
# backpropagated to the type of `a_ns`
2930
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
3031

32+
# =========================================================
33+
# `xpt.HasDType`
34+
35+
# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't
36+
# type annotate specific dtypes like `np.float32` or `np.int32`.
37+
38+
_: xpt.HasDType[dtype[Any]] = nparr
39+
_: xpt.HasDType[dtype[Any]] = nparr_i32
40+
_: xpt.HasDType[dtype[Any]] = nparr_f32
41+
3142
# =========================================================
3243
# `xpt.Array`
3344

3445
# Check NamespaceT_co assignment
35-
a_ns: xpt.Array[ModuleType] = nparr
46+
a_ns: xpt.Array[Any, ModuleType] = nparr
47+
48+
# Check DTypeT_co assignment
49+
# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't
50+
# type annotate specific dtypes like `np.float32` or `np.int32`.
51+
_: xpt.Array[dtype[Any]] = nparr
52+
_: xpt.Array[dtype[Any]] = nparr_i32
53+
_: xpt.Array[dtype[Any]] = nparr_f32

‎tests/integration/test_numpy2p0.pyi‎

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,23 @@ ns: ModuleType = a_ns.__array_namespace__()
3535
# backpropagated to the type of `a_ns`
3636
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
3737

38+
# =========================================================
39+
# `xpt.HasDType`
40+
41+
# Check DTypeT_co assignment
42+
_: xpt.HasDType[Any] = nparr
43+
_: xpt.HasDType[np.dtype[I32]] = nparr_i32
44+
_: xpt.HasDType[np.dtype[F32]] = nparr_f32
45+
_: xpt.HasDType[np.dtype[np.bool_]] = nparr_b
46+
3847
# =========================================================
3948
# `xpt.Array`
4049

4150
# Check NamespaceT_co assignment
42-
a_ns: xpt.Array[ModuleType] = nparr
51+
a_ns: xpt.Array[Any, ModuleType] = nparr
52+
53+
# Check DTypeT_co assignment
54+
_: xpt.Array[Any] = nparr
55+
_: xpt.Array[np.dtype[I32]] = nparr_i32
56+
_: xpt.Array[np.dtype[F32]] = nparr_f32
57+
_: xpt.Array[np.dtype[np.bool_]] = nparr_b

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /