Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit f203d12

Browse files
committed
✨HasDType
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent b12ea8e commit f203d12

File tree

4 files changed

+65
-7
lines changed

4 files changed

+65
-7
lines changed

‎src/array_api_typing/__init__.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
__all__ = (
44
"Array",
55
"HasArrayNamespace",
6+
"HasDType",
67
"__version__",
78
"__version_tuple__",
89
)
910

10-
from ._array import Array, HasArrayNamespace
11+
from ._array import Array, HasArrayNamespace, HasDType
1112
from ._version import version as __version__, version_tuple as __version_tuple__

‎src/array_api_typing/_array.py‎

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing_extensions import TypeVar
99

1010
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
11+
DTypeT_co = TypeVar("DTypeT_co", covariant=True)
1112

1213

1314
class HasArrayNamespace(Protocol[NamespaceT_co]):
@@ -38,8 +39,32 @@ def __array_namespace__(
3839
) -> NamespaceT_co: ...
3940

4041

42+
class HasDType(Protocol[DTypeT_co]):
43+
"""Protocol for array classes that have a data type attribute."""
44+
45+
@property
46+
def dtype(self, /) -> DTypeT_co:
47+
"""Data type of the array elements."""
48+
...
49+
50+
4151
class Array(
42-
HasArrayNamespace[NamespaceT_co],
43-
Protocol[NamespaceT_co],
52+
# ------ Attributes -------
53+
HasDType[DTypeT_co],
54+
# -------------------------
55+
Protocol[DTypeT_co, NamespaceT_co],
4456
):
45-
"""Array API specification for array object attributes and methods."""
57+
"""Array API specification for array object attributes and methods.
58+
59+
The type is: ``Array[+DTypeT, +NamespaceT = ModuleType] = Array[DTypeT,
60+
NamespaceT]`` where:
61+
62+
- `DTypeT` is the data type of the array elements.
63+
- `NamespaceT` is the type of the array namespace. It defaults to
64+
`ModuleType`, which is the most common form of array namespace (e.g.,
65+
`numpy`, `cupy`, etc.). However, it can be any type, e.g. a
66+
`types.SimpleNamespace`, to allow for wrapper libraries to
67+
semi-dynamically define their own array namespaces based on the wrapped
68+
array type.
69+
70+
"""

‎tests/integration/test_numpy1p0.pyi‎

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# mypy: disable-error-code="no-redef"
22

33
from types import ModuleType
4-
from typing import TypeAlias
4+
from typing import Any
55

66
import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
77

@@ -29,8 +29,23 @@ ns: ModuleType = a_ns.__array_namespace__()
2929
# backpropagated to the type of `a_ns`
3030
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
3131

32+
# =========================================================
33+
# `xpt.HasDType`
34+
35+
# Check DTypeT_co assignment
36+
_: xpt.HasDType[Any] = nparr
37+
_: xpt.HasDType[np.dtype[I32]] = nparr_i32
38+
_: xpt.HasDType[np.dtype[F32]] = nparr_f32
39+
_: xpt.HasDType[np.dtype[np.bool_]] = nparr_b
40+
3241
# =========================================================
3342
# `xpt.Array`
3443

3544
# Check NamespaceT_co assignment
36-
a_ns: xpt.Array[ModuleType] = nparr
45+
a_ns: xpt.Array[Any, ModuleType] = nparr
46+
47+
# Check DTypeT_co assignment
48+
_: xpt.Array[Any] = nparr
49+
_: xpt.Array[np.dtype[np.int32]] = nparr_i32
50+
_: xpt.Array[np.dtype[np.float32]] = nparr_f32
51+
_: xpt.Array[np.dtype[np.bool_]] = nparr_b

‎tests/integration/test_numpy2p0.pyi‎

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ import numpy.typing as npt
99
import array_api_typing as xpt
1010

1111
# DType aliases
12+
F: TypeAlias = np.floating[Any]
1213
F32: TypeAlias = np.float32
14+
I: TypeAlias = np.integer[Any]
1315
I32: TypeAlias = np.int32
1416

1517
# Define NDArrays against which we can test the protocols
@@ -35,8 +37,23 @@ ns: ModuleType = a_ns.__array_namespace__()
3537
# backpropagated to the type of `a_ns`
3638
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
3739

40+
# =========================================================
41+
# `xpt.HasDType`
42+
43+
# Check DTypeT_co assignment
44+
_: xpt.HasDType[Any] = nparr
45+
_: xpt.HasDType[np.dtype[I32]] = nparr_i32
46+
_: xpt.HasDType[np.dtype[F32]] = nparr_f32
47+
_: xpt.HasDType[np.dtype[np.bool_]] = nparr_b
48+
3849
# =========================================================
3950
# `xpt.Array`
4051

4152
# Check NamespaceT_co assignment
42-
a_ns: xpt.Array[ModuleType] = nparr
53+
a_ns: xpt.Array[Any, ModuleType] = nparr
54+
55+
# Check DTypeT_co assignment
56+
_: xpt.Array[Any] = nparr
57+
_: xpt.Array[np.dtype[I32]] = nparr_i32
58+
_: xpt.Array[np.dtype[F32]] = nparr_f32
59+
_: xpt.Array[np.dtype[np.bool_]] = nparr_b

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /