Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit ec86afa

Browse files
Add 3.7 features to python wrapper (#221)
* adds af_pad to python wrapper * adds meanvar to python wrapper * adds inverse square root to python wrapper * adds pinverse to python wrapper * adds NN convolve and gradient functions to wrapper * adds reduce by key to python wrapper missing convolve gradient function * adds confidenceCC to python wrapper * adds fp16 support to python wrapper * update version * remove stray print statements * adds axes_label_format to python wrapper, removes mistakenly copied code
1 parent aead039 commit ec86afa

23 files changed

+636
-7
lines changed

‎__af_version__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@
99
# http://arrayfire.com/licenses/BSD-3-Clause
1010
########################################################
1111

12-
version = "3.5"
13-
release = "20170718"
12+
version = "3.7"
13+
release = "20200213"
1414
full_version = version + "." + release

‎arrayfire/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
from .timer import *
7575
from .random import *
7676
from .sparse import *
77+
from .ml import *
7778

7879
# do not export default modules as part of arrayfire
7980
del ct

‎arrayfire/algorithm.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,31 @@ def _nan_reduce_all(a, c_func, nan_val):
4444
imag = imag.value
4545
return real if imag == 0 else real + imag * 1j
4646

47+
def _FNSD(dim, dims):
48+
if dim >= 0:
49+
return int(dim)
50+
51+
fnsd = 0
52+
for i, d in enumerate(dims):
53+
if d > 1:
54+
fnsd = i
55+
break
56+
return int(fnsd)
57+
58+
def _rbk_dim(keys, vals, dim, c_func):
59+
keys_out = Array()
60+
vals_out = Array()
61+
rdim = _FNSD(dim, vals.dims())
62+
safe_call(c_func(c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim)))
63+
return keys_out, vals_out
64+
65+
def _nan_rbk_dim(a, dim, c_func, nan_val):
66+
keys_out = Array()
67+
vals_out = Array()
68+
rdim = _FNSD(dim, vals.dims())
69+
safe_call(c_func(c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim), c_double_t(nan_val)))
70+
return keys_out, vals_out
71+
4772
def sum(a, dim=None, nan_val=None):
4873
"""
4974
Calculate the sum of all the elements along a specified dimension.
@@ -74,6 +99,34 @@ def sum(a, dim=None, nan_val=None):
7499
else:
75100
return _reduce_all(a, backend.get().af_sum_all)
76101

102+
103+
def sumByKey(keys, vals, dim=-1, nan_val=None):
104+
"""
105+
Calculate the sum of elements along a specified dimension according to a key.
106+
107+
Parameters
108+
----------
109+
keys : af.Array
110+
One dimensional arrayfire array with reduction keys.
111+
vals : af.Array
112+
Multi dimensional arrayfire array that will be reduced.
113+
dim: optional: int. default: -1
114+
Dimension along which the sum will occur.
115+
nan_val: optional: scalar. default: None
116+
The value that replaces NaN in the array
117+
118+
Returns
119+
-------
120+
keys: af.Array or scalar number
121+
The reduced keys of all elements in `vals` along dimension `dim`.
122+
values: af.Array or scalar number
123+
The sum of all elements in `vals` along dimension `dim` according to keys
124+
"""
125+
if (nan_val is not None):
126+
return _nan_rbk_dim(keys, vals, dim, backend.get().af_sum_by_key_nan, nan_val)
127+
else:
128+
return _rbk_dim(keys, vals, dim, backend.get().af_sum_by_key)
129+
77130
def product(a, dim=None, nan_val=None):
78131
"""
79132
Calculate the product of all the elements along a specified dimension.
@@ -104,6 +157,33 @@ def product(a, dim=None, nan_val=None):
104157
else:
105158
return _reduce_all(a, backend.get().af_product_all)
106159

160+
def productByKey(keys, vals, dim=-1, nan_val=None):
161+
"""
162+
Calculate the product of elements along a specified dimension according to a key.
163+
164+
Parameters
165+
----------
166+
keys : af.Array
167+
One dimensional arrayfire array with reduction keys.
168+
vals : af.Array
169+
Multi dimensional arrayfire array that will be reduced.
170+
dim: optional: int. default: -1
171+
Dimension along which the product will occur.
172+
nan_val: optional: scalar. default: None
173+
The value that replaces NaN in the array
174+
175+
Returns
176+
-------
177+
keys: af.Array or scalar number
178+
The reduced keys of all elements in `vals` along dimension `dim`.
179+
values: af.Array or scalar number
180+
The product of all elements in `vals` along dimension `dim` according to keys
181+
"""
182+
if (nan_val is not None):
183+
return _nan_rbk_dim(keys, vals, dim, backend.get().af_product_by_key_nan, nan_val)
184+
else:
185+
return _rbk_dim(keys, vals, dim, backend.get().af_product_by_key)
186+
107187
def min(a, dim=None):
108188
"""
109189
Find the minimum value of all the elements along a specified dimension.
@@ -126,6 +206,28 @@ def min(a, dim=None):
126206
else:
127207
return _reduce_all(a, backend.get().af_min_all)
128208

209+
def minByKey(keys, vals, dim=-1):
210+
"""
211+
Calculate the min of elements along a specified dimension according to a key.
212+
213+
Parameters
214+
----------
215+
keys : af.Array
216+
One dimensional arrayfire array with reduction keys.
217+
vals : af.Array
218+
Multi dimensional arrayfire array that will be reduced.
219+
dim: optional: int. default: -1
220+
Dimension along which the min will occur.
221+
222+
Returns
223+
-------
224+
keys: af.Array or scalar number
225+
The reduced keys of all elements in `vals` along dimension `dim`.
226+
values: af.Array or scalar number
227+
The min of all elements in `vals` along dimension `dim` according to keys
228+
"""
229+
return _rbk_dim(keys, vals, dim, backend.get().af_min_by_key)
230+
129231
def max(a, dim=None):
130232
"""
131233
Find the maximum value of all the elements along a specified dimension.
@@ -148,6 +250,28 @@ def max(a, dim=None):
148250
else:
149251
return _reduce_all(a, backend.get().af_max_all)
150252

253+
def maxByKey(keys, vals, dim=-1):
254+
"""
255+
Calculate the max of elements along a specified dimension according to a key.
256+
257+
Parameters
258+
----------
259+
keys : af.Array
260+
One dimensional arrayfire array with reduction keys.
261+
vals : af.Array
262+
Multi dimensional arrayfire array that will be reduced.
263+
dim: optional: int. default: -1
264+
Dimension along which the max will occur.
265+
266+
Returns
267+
-------
268+
keys: af.Array or scalar number
269+
The reduced keys of all elements in `vals` along dimension `dim`.
270+
values: af.Array or scalar number
271+
The max of all elements in `vals` along dimension `dim` according to keys.
272+
"""
273+
return _rbk_dim(keys, vals, dim, backend.get().af_max_by_key)
274+
151275
def all_true(a, dim=None):
152276
"""
153277
Check if all the elements along a specified dimension are true.
@@ -170,6 +294,28 @@ def all_true(a, dim=None):
170294
else:
171295
return _reduce_all(a, backend.get().af_all_true_all)
172296

297+
def allTrueByKey(keys, vals, dim=-1):
298+
"""
299+
Calculate if all elements are true along a specified dimension according to a key.
300+
301+
Parameters
302+
----------
303+
keys : af.Array
304+
One dimensional arrayfire array with reduction keys.
305+
vals : af.Array
306+
Multi dimensional arrayfire array that will be reduced.
307+
dim: optional: int. default: -1
308+
Dimension along which the all true check will occur.
309+
310+
Returns
311+
-------
312+
keys: af.Array or scalar number
313+
The reduced keys of all true check in `vals` along dimension `dim`.
314+
values: af.Array or scalar number
315+
Booleans denoting if all elements are true in `vals` along dimension `dim` according to keys
316+
"""
317+
return _rbk_dim(keys, vals, dim, backend.get().af_all_true_by_key)
318+
173319
def any_true(a, dim=None):
174320
"""
175321
Check if any the elements along a specified dimension are true.
@@ -192,6 +338,28 @@ def any_true(a, dim=None):
192338
else:
193339
return _reduce_all(a, backend.get().af_any_true_all)
194340

341+
def anyTrueByKey(keys, vals, dim=-1):
342+
"""
343+
Calculate if any elements are true along a specified dimension according to a key.
344+
345+
Parameters
346+
----------
347+
keys : af.Array
348+
One dimensional arrayfire array with reduction keys.
349+
vals : af.Array
350+
Multi dimensional arrayfire array that will be reduced.
351+
dim: optional: int. default: -1
352+
Dimension along which the any true check will occur.
353+
354+
Returns
355+
-------
356+
keys: af.Array or scalar number
357+
The reduced keys of any true check in `vals` along dimension `dim`.
358+
values: af.Array or scalar number
359+
Booleans denoting if any elements are true in `vals` along dimension `dim` according to keys.
360+
"""
361+
return _rbk_dim(keys, vals, dim, backend.get().af_any_true_by_key)
362+
195363
def count(a, dim=None):
196364
"""
197365
Count the number of non zero elements in an array along a specified dimension.
@@ -214,6 +382,28 @@ def count(a, dim=None):
214382
else:
215383
return _reduce_all(a, backend.get().af_count_all)
216384

385+
def countByKey(keys, vals, dim=-1):
386+
"""
387+
Counts non-zero elements along a specified dimension according to a key.
388+
389+
Parameters
390+
----------
391+
keys : af.Array
392+
One dimensional arrayfire array with reduction keys.
393+
vals : af.Array
394+
Multi dimensional arrayfire array that will be reduced.
395+
dim: optional: int. default: -1
396+
Dimension along which to count elements.
397+
398+
Returns
399+
-------
400+
keys: af.Array or scalar number
401+
The reduced keys of count in `vals` along dimension `dim`.
402+
values: af.Array or scalar number
403+
Count of non-zero elements in `vals` along dimension `dim` according to keys.
404+
"""
405+
return _rbk_dim(keys, vals, dim, backend.get().af_count_by_key)
406+
217407
def imin(a, dim=None):
218408
"""
219409
Find the value and location of the minimum value along a specified dimension

‎arrayfire/arith.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,26 @@ def sqrt(a):
958958
"""
959959
return _arith_unary_func(a, backend.get().af_sqrt)
960960

961+
def rsqrt(a):
962+
"""
963+
Reciprocal or inverse square root of each element in the array.
964+
965+
Parameters
966+
----------
967+
a : af.Array
968+
Multi dimensional arrayfire array.
969+
970+
Returns
971+
--------
972+
out : af.Array
973+
array containing the inverse square root of each value from `a`.
974+
975+
Note
976+
-------
977+
`a` must not be complex.
978+
"""
979+
return _arith_unary_func(a, backend.get().af_rsqrt)
980+
961981
def cbrt(a):
962982
"""
963983
Cube root of each element in the array.

‎arrayfire/array.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,14 @@ def is_single(self):
783783
safe_call(backend.get().af_is_single(c_pointer(res), self.arr))
784784
return res.value
785785

786+
def is_half(self):
787+
"""
788+
Check if the array is of half floating point type (fp16).
789+
"""
790+
res = c_bool_t(False)
791+
safe_call(backend.get().af_is_half(c_pointer(res), self.arr))
792+
return res.value
793+
786794
def is_real_floating(self):
787795
"""
788796
Check if the array is real and of floating point type.

‎arrayfire/data.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,58 @@ def replace(lhs, cond, rhs):
799799
else:
800800
safe_call(backend.get().af_replace_scalar(lhs.arr, cond.arr, c_double_t(rhs)))
801801

802+
def pad(a, beginPadding, endPadding, padFillType = PAD.ZERO):
803+
"""
804+
Pad an array
805+
806+
This function will pad an array with the specified border size.
807+
Newly padded values can be filled in several different ways.
808+
809+
Parameters
810+
----------
811+
812+
a: af.Array
813+
A multi dimensional input arrayfire array.
814+
815+
beginPadding: tuple of ints. default: (0, 0, 0, 0).
816+
817+
endPadding: tuple of ints. default: (0, 0, 0, 0).
818+
819+
padFillType: optional af.PAD default: af.PAD.ZERO
820+
specifies type of values to fill padded border with
821+
822+
Returns
823+
-------
824+
output: af.Array
825+
A padded array
826+
827+
Examples
828+
---------
829+
>>> import arrayfire as af
830+
>>> a = af.randu(3,3)
831+
>>> af.display(a)
832+
[3 3 1 1]
833+
0.4107 0.1794 0.3775
834+
0.8224 0.4198 0.3027
835+
0.9518 0.0081 0.6456
836+
837+
>>> padded = af.pad(a, (1, 1), (1, 1), af.ZERO)
838+
>>> af.display(padded)
839+
[5 5 1 1]
840+
0.0000 0.0000 0.0000 0.0000 0.0000
841+
0.0000 0.4107 0.1794 0.3775 0.0000
842+
0.0000 0.8224 0.4198 0.3027 0.0000
843+
0.0000 0.9518 0.0081 0.6456 0.0000
844+
0.0000 0.0000 0.0000 0.0000 0.0000
845+
"""
846+
out = Array()
847+
begin_dims = dim4(beginPadding[0], beginPadding[1], beginPadding[2], beginPadding[3])
848+
end_dims = dim4(endPadding[0], endPadding[1], endPadding[2], endPadding[3])
849+
850+
safe_call(backend.get().af_pad(c_pointer(out.arr), a.arr, 4, c_pointer(begin_dims), 4, c_pointer(end_dims), padFillType.value))
851+
return out
852+
853+
802854
def lookup(a, idx, dim=0):
803855
"""
804856
Lookup the values of input array based on index.

‎arrayfire/device.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,25 @@ def is_dbl_supported(device=None):
150150
safe_call(backend.get().af_get_dbl_support(c_pointer(res), dev))
151151
return res.value
152152

153+
def is_half_supported(device=None):
154+
"""
155+
Check if half precision is supported on specified device.
156+
157+
Parameters
158+
-----------
159+
device: optional: int. default: None.
160+
id of the desired device.
161+
162+
Returns
163+
--------
164+
- True if half precision supported.
165+
- False if half precision not supported.
166+
"""
167+
dev = device if device is not None else get_device()
168+
res = c_bool_t(False)
169+
safe_call(backend.get().af_get_half_support(c_pointer(res), dev))
170+
return res.value
171+
153172
def sync(device=None):
154173
"""
155174
Block until all the functions on the device have completed execution.

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /