Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit d4485f4

Browse files
committed
✨ add CanArray* unop and binop protocols
Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
1 parent 858a3db commit d4485f4

File tree

8 files changed

+364
-14
lines changed

8 files changed

+364
-14
lines changed

‎pyproject.toml‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ dependencies = [
2929
"typing-extensions>=4.14.1",
3030
"optype>=0.9.3; python_version < '3.11'",
3131
"optype>=0.12.2; python_version >= '3.11'",
32+
"tomli>=1.2.0 ; python_full_version < '3.11'",
3233
]
3334

3435
[project.urls]
@@ -123,9 +124,12 @@ ignore = [
123124
"D107", # Missing docstring in __init__
124125
"D203", # 1 blank line required before class docstring
125126
"D213", # Multi-line docstring summary should start at the second line
127+
"D401", # First line of docstring should be in imperative mood
126128
"FBT", # flake8-boolean-trap
127129
"FIX", # flake8-fixme
128130
"ISC001", # Conflicts with formatter
131+
"PLW1641", # Object does not implement `__hash__` method
132+
"PYI041", # Use `float` instead of `int | float`
129133
]
130134

131135
[tool.ruff.lint.pylint]

‎src/array_api_typing/_array.py‎

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,32 @@
11
__all__ = (
22
"Array",
3+
"BoolArray",
34
"HasArrayNamespace",
5+
"NumericArray",
46
)
57

8+
from pathlib import Path
69
from types import ModuleType
7-
from typing import Literal, Protocol
10+
from typing import Literal, Never, Protocol, TypeAlias
811
from typing_extensions import TypeVar
912

13+
import optype as op
14+
15+
from ._utils import docstring_setter
16+
17+
# Load docstrings from TOML file
18+
try:
19+
import tomllib
20+
except ImportError:
21+
import tomli as tomllib # type: ignore[import-not-found, no-redef]
22+
23+
_docstrings_path = Path(__file__).parent / "_array_docstrings.toml"
24+
with _docstrings_path.open("rb") as f:
25+
_array_docstrings = tomllib.load(f)["docstrings"]
26+
1027
NS_co = TypeVar("NS_co", covariant=True, default=ModuleType)
28+
T_contra = TypeVar("T_contra", contravariant=True)
29+
R_co = TypeVar("R_co", covariant=True, default=Never)
1130

1231

1332
class HasArrayNamespace(Protocol[NS_co]):
@@ -33,8 +52,37 @@ def __array_namespace__(
3352
) -> NS_co: ...
3453

3554

55+
@docstring_setter(**_array_docstrings)
3656
class Array(
3757
HasArrayNamespace[NS_co],
38-
Protocol[NS_co],
58+
op.CanPosSelf,
59+
op.CanNegSelf,
60+
op.CanAddSame[T_contra, R_co],
61+
op.CanSubSame[T_contra, R_co],
62+
op.CanMulSame[T_contra, R_co],
63+
op.CanTruedivSame[T_contra, R_co],
64+
op.CanFloordivSame[T_contra, R_co],
65+
op.CanModSame[T_contra, R_co],
66+
op.CanPowSame[T_contra, R_co],
67+
Protocol[T_contra, R_co, NS_co],
3968
):
4069
"""Array API specification for array object attributes and methods."""
70+
71+
72+
BoolArray: TypeAlias = Array[bool, Array[float, Never, NS_co], NS_co]
73+
"""Array API specification for boolean array object attributes and methods.
74+
75+
Specifically, this type alias fills the `T_contra` type variable with
76+
`bool`, allowing for `bool` objects to be added, subtracted, multiplied, etc. to
77+
the array object.
78+
79+
"""
80+
81+
NumericArray: TypeAlias = Array[float | int, NS_co]
82+
"""Array API specification for numeric array object attributes and methods.
83+
84+
Specifically, this type alias fills the `T_contra` type variable with `float
85+
| int`, allowing for `float | int` objects to be added, subtracted, multiplied,
86+
etc. to the array object.
87+
88+
"""
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
[docstrings]
2+
__pos__ = '''
3+
Evaluates `+self_i` for each element of an array instance.
4+
5+
Returns:
6+
Self: An array containing the evaluated result for each element.
7+
The returned array must have the same data type as `self`.
8+
9+
See Also:
10+
array_api_typing.Positive
11+
12+
'''
13+
14+
__neg__ = '''
15+
Evaluates `-self_i` for each element of an array instance.
16+
17+
Returns:
18+
Self: an array containing the evaluated result for each element in
19+
`self`. The returned array must have a data type determined by Type
20+
Promotion Rules.
21+
22+
See Also:
23+
array_api_typing.Negative
24+
25+
'''
26+
27+
__add__ = '''
28+
Calculates the sum for each element of an array instance with the respective
29+
element of the array `other`.
30+
31+
Args:
32+
other: addend array. Must be compatible with `self` (see
33+
Broadcasting). Should have a numeric data type.
34+
35+
Returns:
36+
Self: an array containing the element-wise sums. The returned array
37+
must have a data type determined by Type Promotion Rules.
38+
39+
See Also:
40+
array_api_typing.Add
41+
42+
'''
43+
44+
__sub__ = '''
45+
Calculates the difference for each element of an array instance with the
46+
respective element of the array other.
47+
48+
The result of `self_i - other_i` must be the same as `self_i +
49+
(-other_i)` and must be governed by the same floating-point rules as
50+
addition (see `CanArrayAdd`).
51+
52+
Args:
53+
other: subtrahend array. Must be compatible with `self` (see
54+
Broadcasting). Should have a numeric data type.
55+
56+
Returns:
57+
Self: an array containing the element-wise differences. The returned
58+
array must have a data type determined by Type Promotion Rules.
59+
60+
See Also:
61+
array_api_typing.Subtract
62+
63+
'''
64+
65+
__mul__ = '''
66+
Calculates the product for each element of an array instance with the
67+
respective element of the array `other`.
68+
69+
Args:
70+
other: multiplicand array. Must be compatible with `self` (see
71+
Broadcasting). Should have a numeric data type.
72+
73+
Returns:
74+
Self: an array containing the element-wise products. The returned
75+
array must have a data type determined by Type Promotion Rules.
76+
77+
See Also:
78+
array_api_typing.Multiply
79+
80+
'''
81+
82+
__truediv__ = '''
83+
Evaluates `self_i / other_i` for each element of an array instance with the
84+
respective element of the array `other`.
85+
86+
Args:
87+
other: Must be compatible with `self` (see Broadcasting). Should have a
88+
numeric data type.
89+
90+
Returns:
91+
Self: an array containing the element-wise results. The returned array
92+
should have a floating-point data type determined by Type Promotion
93+
Rules.
94+
95+
See Also:
96+
array_api_typing.TrueDiv
97+
98+
'''
99+
100+
__floordiv__ = '''
101+
Evaluates `self_i // other_i` for each element of an array instance with the
102+
respective element of the array `other`.
103+
104+
Args:
105+
other: Must be compatible with `self` (see Broadcasting). Should have a
106+
numeric data type.
107+
108+
Returns:
109+
Self: an array containing the element-wise results. The returned array
110+
must have a data type determined by Type Promotion Rules.
111+
112+
See Also:
113+
array_api_typing.FloorDiv
114+
115+
'''
116+
117+
__mod__ = '''
118+
Evaluates `self_i % other_i` for each element of an array instance with the
119+
respective element of the array `other`.
120+
121+
Args:
122+
other: Must be compatible with `self` (see Broadcasting). Should have a
123+
numeric data type.
124+
125+
Returns:
126+
Self: an array containing the element-wise results. Each element-wise
127+
result must have the same sign as the respective element `other_i`.
128+
The returned array must have a floating-point data type determined
129+
by Type Promotion Rules.
130+
131+
See Also:
132+
array_api_typing.Remainder
133+
134+
'''
135+
136+
__pow__ = '''
137+
Calculates an implementation-dependent approximation of exponentiation by
138+
raising each element (the base) of an array instance to the power of
139+
`other_i` (the exponent), where `other_i` is the corresponding element of
140+
the array `other`.
141+
142+
Args:
143+
other: array whose elements correspond to the exponentiation exponent.
144+
Must be compatible with `self` (see Broadcasting). Should have a
145+
numeric data type.
146+
147+
Returns:
148+
Self: an array containing the element-wise results. The returned array
149+
must have a data type determined by Type Promotion Rules.
150+
151+
'''

‎src/array_api_typing/_namespace.py‎

Whitespace-only changes.

‎src/array_api_typing/_utils.py‎

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Utility functions."""
2+
3+
from collections.abc import Callable
4+
from enum import Enum, auto
5+
from typing import Literal, TypeVar
6+
7+
ClassT = TypeVar("ClassT")
8+
DocstringTypes = str | None
9+
10+
11+
class _Sentinel(Enum):
12+
SKIP = auto()
13+
14+
15+
def set_docstrings(
16+
obj: type[ClassT],
17+
main: DocstringTypes | Literal[_Sentinel.SKIP] = _Sentinel.SKIP,
18+
/,
19+
**method_docs: DocstringTypes,
20+
) -> type[ClassT]:
21+
"""Set the docstring for a class and its methods.
22+
23+
Args:
24+
obj: The class to set the docstring for.
25+
main: The main docstring for the class. If not provided, the
26+
class docstring will not be modified.
27+
method_docs: A mapping of method names to their docstrings. If a method
28+
is not provided, its docstring will not be modified.
29+
30+
Returns:
31+
The class with updated docstrings.
32+
33+
"""
34+
if main is not _Sentinel.SKIP:
35+
obj.__doc__ = main
36+
37+
for name, doc in method_docs.items():
38+
method = getattr(obj, name)
39+
method.__doc__ = doc
40+
return obj
41+
42+
43+
def docstring_setter(
44+
main: DocstringTypes | Literal[_Sentinel.SKIP] = _Sentinel.SKIP,
45+
/,
46+
**method_docs: DocstringTypes,
47+
) -> Callable[[type[ClassT]], type[ClassT]]:
48+
"""Decorator to set docstrings for a class and its methods.
49+
50+
Args:
51+
main: The main docstring for the class. If not provided, the
52+
class docstring will not be modified.
53+
method_docs: A mapping of method names to their docstrings. If a method
54+
is not provided, its docstring will not be modified.
55+
56+
Returns:
57+
A decorator that sets the docstrings for the class and its methods.
58+
59+
"""
60+
61+
def decorator(cls: type[ClassT]) -> type[ClassT]:
62+
return set_docstrings(cls, main, **method_docs)
63+
64+
return decorator

‎tests/integration/test_numpy1.pyi‎

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,53 @@
1-
from typing import Any
1+
# Test array_api_typing with numpy < 2.0
2+
#
3+
# NOTES:
4+
# - `np.array_api` uses dtype objects instead of dtype classes, preventing the
5+
# use of `np.float32` and `np.int32` as type aliases in type annotations.
6+
# - `bool` doesn't seem to be a valid dtype in `np.array_api`. The valid dtypes
7+
# are signedinteger of 8, 16, 32, and 64 bits, unsignedinteger of 8, 16, 32,
8+
# and 64 bits, and floating of 32 and 64 bits and None.
29

3-
# requires numpy < 2
4-
import numpy.array_api as np # type: ignore[import-not-found]
10+
from typing import Any, Never, TypeAlias
11+
12+
import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
13+
from numpy import float32, floating, int32, integer
514

615
import array_api_typing as xpt
16+
from array_api_typing._array import NumericArray
17+
18+
F: TypeAlias = floating[Any]
19+
F32: TypeAlias = float32 # Note: np.array_api uses dtype objects.
20+
I: TypeAlias = integer[Any]
21+
I32: TypeAlias = int32 # Note: np.array_api uses dtype objects.
22+
23+
# Define an NDArray against which we can test the protocols
24+
nparr = np.eye(2)
25+
nparr_i32 = np.asarray([1], dtype=np.int32)
26+
nparr_f32 = np.asarray([1.0], dtype=np.float32)
27+
28+
# =========================================================
29+
# Ensure that `np.ndarray` instances are assignable to `xpt.HasArrayNamespace`
30+
31+
arr_ns: xpt.HasArrayNamespace[Any] = nparr
32+
arr_ns_i32: xpt.HasArrayNamespace[Any] = nparr_i32
33+
arr_ns_f32: xpt.HasArrayNamespace[Any] = nparr_f32
34+
35+
# =========================================================
36+
# Ensure that `np.ndarray` instances are assignable to `xpt.Array`.
37+
38+
# Generic Array type
39+
arr_array: xpt.Array[Never] = nparr
40+
41+
# Float Array types
42+
arr_float: xpt.Array[float] = nparr_f32
43+
arr_f: xpt.Array[F] = nparr_f32
44+
arr_f32: xpt.Array[F32] = nparr_f32
45+
46+
# Integer Array types
47+
arr_int: xpt.Array[int, xpt.Array[float | int]] = nparr_i32
48+
arr_i: xpt.Array[int | float, xpt.Array[float | int]] = nparr_i32
749

8-
###
9-
# Ensure that `np.ndarray` instances are assignable to `xpt.HasArrayNamespace`.
50+
# =========================================================
51+
# Check np.ndarray against BoolArray and NumericArray type aliases
1052

11-
arr = np.eye(2)
12-
arr_namespace: xpt.HasArrayNamespace[Any] = arr
53+
numericarray: NumericArray = nparr_f32

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /