Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit fe8dc7c

Browse files
committed
✨ add CanArray* unop and binop protocols
Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
1 parent 858a3db commit fe8dc7c

File tree

8 files changed

+364
-14
lines changed

8 files changed

+364
-14
lines changed

‎pyproject.toml‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ dependencies = [
2929
"typing-extensions>=4.14.1",
3030
"optype>=0.9.3; python_version < '3.11'",
3131
"optype>=0.12.2; python_version >= '3.11'",
32+
"tomli>=1.2.0 ; python_full_version < '3.11'",
3233
]
3334

3435
[project.urls]
@@ -123,9 +124,12 @@ ignore = [
123124
"D107", # Missing docstring in __init__
124125
"D203", # 1 blank line required before class docstring
125126
"D213", # Multi-line docstring summary should start at the second line
127+
"D401", # First line of docstring should be in imperative mood
126128
"FBT", # flake8-boolean-trap
127129
"FIX", # flake8-fixme
128130
"ISC001", # Conflicts with formatter
131+
"PLW1641", # Object does not implement `__hash__` method
132+
"PYI041", # Use `float` instead of `int | float`
129133
]
130134

131135
[tool.ruff.lint.pylint]

‎src/array_api_typing/_array.py‎

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,32 @@
11
__all__ = (
22
"Array",
3+
"BoolArray",
34
"HasArrayNamespace",
5+
"NumericArray",
46
)
57

8+
from pathlib import Path
69
from types import ModuleType
7-
from typing import Literal, Protocol
10+
from typing import Literal, Never, Protocol, TypeAlias
811
from typing_extensions import TypeVar
912

13+
import optype as op
14+
15+
from ._utils import docstring_setter
16+
17+
# Load docstrings from TOML file
18+
try:
19+
import tomllib
20+
except ImportError:
21+
import tomli as tomllib # type: ignore[import-not-found, no-redef]
22+
23+
_docstrings_path = Path(__file__).parent / "_array_docstrings.toml"
24+
with _docstrings_path.open("rb") as f:
25+
_array_docstrings = tomllib.load(f)["docstrings"]
26+
1027
NS_co = TypeVar("NS_co", covariant=True, default=ModuleType)
28+
T_contra = TypeVar("T_contra", contravariant=True)
29+
R_co = TypeVar("R_co", covariant=True, default=Never)
1130

1231

1332
class HasArrayNamespace(Protocol[NS_co]):
@@ -33,8 +52,37 @@ def __array_namespace__(
3352
) -> NS_co: ...
3453

3554

55+
@docstring_setter(**_array_docstrings)
3656
class Array(
3757
HasArrayNamespace[NS_co],
38-
Protocol[NS_co],
58+
op.CanPosSelf,
59+
op.CanNegSelf,
60+
op.CanAddSame[T_contra, R_co],
61+
op.CanSubSame[T_contra, R_co],
62+
op.CanMulSame[T_contra, R_co],
63+
op.CanTruedivSame[T_contra, R_co],
64+
op.CanFloordivSame[T_contra, R_co],
65+
op.CanModSame[T_contra, R_co],
66+
op.CanPowSame[T_contra, R_co],
67+
Protocol[T_contra, R_co, NS_co],
3968
):
4069
"""Array API specification for array object attributes and methods."""
70+
71+
72+
BoolArray: TypeAlias = Array[bool, Array[float, Never, NS_co], NS_co]
73+
"""Array API specification for boolean array object attributes and methods.
74+
75+
Specifically, this type alias fills the `T_contra` type variable with
76+
`bool`, allowing for `bool` objects to be added, subtracted, multiplied, etc. to
77+
the array object.
78+
79+
"""
80+
81+
NumericArray: TypeAlias = Array[float | int, NS_co]
82+
"""Array API specification for numeric array object attributes and methods.
83+
84+
Specifically, this type alias fills the `T_contra` type variable with `float
85+
| int`, allowing for `float | int` objects to be added, subtracted, multiplied,
86+
etc. to the array object.
87+
88+
"""
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
[docstrings]
2+
__pos__ = '''
3+
Evaluates `+self_i` for each element of an array instance.
4+
5+
Returns:
6+
Self: An array containing the evaluated result for each element.
7+
The returned array must have the same data type as `self`.
8+
9+
See Also:
10+
array_api_typing.Positive
11+
12+
'''
13+
14+
__neg__ = '''
15+
Evaluates `-self_i` for each element of an array instance.
16+
17+
Returns:
18+
Self: an array containing the evaluated result for each element in
19+
`self`. The returned array must have a data type determined by Type
20+
Promotion Rules.
21+
22+
See Also:
23+
array_api_typing.Negative
24+
25+
'''
26+
27+
__add__ = '''
28+
Calculates the sum for each element of an array instance with the respective
29+
element of the array `other`.
30+
31+
Args:
32+
other: addend array. Must be compatible with `self` (see
33+
Broadcasting). Should have a numeric data type.
34+
35+
Returns:
36+
Self: an array containing the element-wise sums. The returned array
37+
must have a data type determined by Type Promotion Rules.
38+
39+
See Also:
40+
array_api_typing.Add
41+
42+
'''
43+
44+
__sub__ = '''
45+
Calculates the difference for each element of an array instance with the
46+
respective element of the array other.
47+
48+
The result of `self_i - other_i` must be the same as `self_i +
49+
(-other_i)` and must be governed by the same floating-point rules as
50+
addition (see `CanArrayAdd`).
51+
52+
Args:
53+
other: subtrahend array. Must be compatible with `self` (see
54+
Broadcasting). Should have a numeric data type.
55+
56+
Returns:
57+
Self: an array containing the element-wise differences. The returned
58+
array must have a data type determined by Type Promotion Rules.
59+
60+
See Also:
61+
array_api_typing.Subtract
62+
63+
'''
64+
65+
__mul__ = '''
66+
Calculates the product for each element of an array instance with the
67+
respective element of the array `other`.
68+
69+
Args:
70+
other: multiplicand array. Must be compatible with `self` (see
71+
Broadcasting). Should have a numeric data type.
72+
73+
Returns:
74+
Self: an array containing the element-wise products. The returned
75+
array must have a data type determined by Type Promotion Rules.
76+
77+
See Also:
78+
array_api_typing.Multiply
79+
80+
'''
81+
82+
__truediv__ = '''
83+
Evaluates `self_i / other_i` for each element of an array instance with the
84+
respective element of the array `other`.
85+
86+
Args:
87+
other: Must be compatible with `self` (see Broadcasting). Should have a
88+
numeric data type.
89+
90+
Returns:
91+
Self: an array containing the element-wise results. The returned array
92+
should have a floating-point data type determined by Type Promotion
93+
Rules.
94+
95+
See Also:
96+
array_api_typing.TrueDiv
97+
98+
'''
99+
100+
__floordiv__ = '''
101+
Evaluates `self_i // other_i` for each element of an array instance with the
102+
respective element of the array `other`.
103+
104+
Args:
105+
other: Must be compatible with `self` (see Broadcasting). Should have a
106+
numeric data type.
107+
108+
Returns:
109+
Self: an array containing the element-wise results. The returned array
110+
must have a data type determined by Type Promotion Rules.
111+
112+
See Also:
113+
array_api_typing.FloorDiv
114+
115+
'''
116+
117+
__mod__ = '''
118+
Evaluates `self_i % other_i` for each element of an array instance with the
119+
respective element of the array `other`.
120+
121+
Args:
122+
other: Must be compatible with `self` (see Broadcasting). Should have a
123+
numeric data type.
124+
125+
Returns:
126+
Self: an array containing the element-wise results. Each element-wise
127+
result must have the same sign as the respective element `other_i`.
128+
The returned array must have a floating-point data type determined
129+
by Type Promotion Rules.
130+
131+
See Also:
132+
array_api_typing.Remainder
133+
134+
'''
135+
136+
__pow__ = '''
137+
Calculates an implementation-dependent approximation of exponentiation by
138+
raising each element (the base) of an array instance to the power of
139+
`other_i` (the exponent), where `other_i` is the corresponding element of
140+
the array `other`.
141+
142+
Args:
143+
other: array whose elements correspond to the exponentiation exponent.
144+
Must be compatible with `self` (see Broadcasting). Should have a
145+
numeric data type.
146+
147+
Returns:
148+
Self: an array containing the element-wise results. The returned array
149+
must have a data type determined by Type Promotion Rules.
150+
151+
'''

‎src/array_api_typing/_namespace.py‎

Whitespace-only changes.

‎src/array_api_typing/_utils.py‎

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Utility functions."""
2+
3+
from collections.abc import Callable
4+
from enum import Enum, auto
5+
from typing import Literal, TypeVar
6+
7+
ClassT = TypeVar("ClassT")
8+
DocstringTypes = str | None
9+
10+
11+
class _Sentinel(Enum):
12+
SKIP = auto()
13+
14+
15+
def set_docstrings(
16+
obj: type[ClassT],
17+
main: DocstringTypes | Literal[_Sentinel.SKIP] = _Sentinel.SKIP,
18+
/,
19+
**method_docs: DocstringTypes,
20+
) -> type[ClassT]:
21+
"""Set the docstring for a class and its methods.
22+
23+
Args:
24+
obj: The class to set the docstring for.
25+
main: The main docstring for the class. If not provided, the
26+
class docstring will not be modified.
27+
method_docs: A mapping of method names to their docstrings. If a method
28+
is not provided, its docstring will not be modified.
29+
30+
Returns:
31+
The class with updated docstrings.
32+
33+
"""
34+
if main is not _Sentinel.SKIP:
35+
obj.__doc__ = main
36+
37+
for name, doc in method_docs.items():
38+
method = getattr(obj, name)
39+
method.__doc__ = doc
40+
return obj
41+
42+
43+
def docstring_setter(
44+
main: DocstringTypes | Literal[_Sentinel.SKIP] = _Sentinel.SKIP,
45+
/,
46+
**method_docs: DocstringTypes,
47+
) -> Callable[[type[ClassT]], type[ClassT]]:
48+
"""Decorator to set docstrings for a class and its methods.
49+
50+
Args:
51+
main: The main docstring for the class. If not provided, the
52+
class docstring will not be modified.
53+
method_docs: A mapping of method names to their docstrings. If a method
54+
is not provided, its docstring will not be modified.
55+
56+
Returns:
57+
A decorator that sets the docstrings for the class and its methods.
58+
59+
"""
60+
61+
def decorator(cls: type[ClassT]) -> type[ClassT]:
62+
return set_docstrings(cls, main, **method_docs)
63+
64+
return decorator

‎tests/integration/test_numpy1.pyi‎

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,53 @@
1-
fromtypingimportAny
1+
# Test array_api_typing with numpy < 2.0
22

3-
# requires numpy < 2
4-
import numpy.array_api as np # type: ignore[import-not-found]
3+
from typing import Any, Never, TypeAlias
4+
5+
import numpy as np
6+
from numpy.array_api import asarray # type: ignore[import-not-found, unused-ignore]
57

68
import array_api_typing as xpt
9+
from array_api_typing._array import BoolArray, NumericArray
10+
11+
F: TypeAlias = np.floating[Any]
12+
F32: TypeAlias = np.float32
13+
I: TypeAlias = np.integer[Any]
14+
I32: TypeAlias = np.int32
15+
16+
# Define an NDArray against which we can test the protocols
17+
nparr = np.eye(2)
18+
nparr_i32 = asarray([1], dtype=np.int32)
19+
nparr_f32 = asarray([1.0], dtype=np.float32)
20+
nparr_b = asarray([True], dtype=np.bool)
21+
22+
# =========================================================
23+
# Ensure that `np.ndarray` instances are assignable to `xpt.HasArrayNamespace`
24+
25+
arr_ns: xpt.HasArrayNamespace[Any] = nparr
26+
arr_ns_i32: xpt.HasArrayNamespace[Any] = nparr_i32
27+
arr_ns_f32: xpt.HasArrayNamespace[Any] = nparr_f32
28+
29+
# =========================================================
30+
# Ensure that `np.ndarray` instances are assignable to `xpt.Array`.
31+
32+
# Generic Array type
33+
arr_array: xpt.Array[Never] = nparr
34+
35+
# Float Array types
36+
arr_float: xpt.Array[float] = nparr_f32
37+
arr_f: xpt.Array[F] = nparr_f32
38+
arr_f32: xpt.Array[F32] = nparr_f32
39+
40+
# Integer Array types
41+
arr_int: xpt.Array[int, xpt.Array[float | int]] = nparr_i32
42+
arr_i: xpt.Array[I, xpt.Array[float | int]] = nparr_i32
43+
arr_i32: xpt.Array[I32, xpt.Array[F32 | I32]] = nparr_i32
44+
45+
# Boolean Array types
46+
arr_bool: xpt.Array[bool, xpt.Array[float | int | bool]] = nparr_b
47+
arr_b: xpt.Array[np.bool_, xpt.Array[F | I | np.bool_]] = nparr_b
748

8-
###
9-
# Ensure that `np.ndarray` instances are assignable to `xpt.HasArrayNamespace`.
49+
# =========================================================
50+
# Check np.ndarray against BoolArray and NumericArray type aliases
1051

11-
arr=np.eye(2)
12-
arr_namespace: xpt.HasArrayNamespace[Any] = arr
52+
boolarray: BoolArray=nparr_b
53+
numericarray: NumericArray = nparr_f32

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /