Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit ad77595

Browse files
committed
Fix flake8 errors. Minor refactoring and bugfixes
1 parent 28a0f61 commit ad77595

File tree

18 files changed

+493
-546
lines changed

18 files changed

+493
-546
lines changed

‎arrayfire/algorithm.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#######################################################
2-
# Copyright (c) 2019, ArrayFire
2+
# Copyright (c) 2020, ArrayFire
33
# All rights reserved.
44
#
55
# This file is distributed under 3-clause BSD license.
@@ -14,11 +14,13 @@
1414
from .array import Array
1515
from .library import backend, safe_call, BINARYOP, c_bool_t, c_double_t, c_int_t, c_pointer, c_uint_t
1616

17+
1718
def _parallel_dim(a, dim, c_func):
1819
out = Array()
1920
safe_call(c_func(c_pointer(out.arr), a.arr, c_int_t(dim)))
2021
return out
2122

23+
2224
def _reduce_all(a, c_func):
2325
real = c_double_t(0)
2426
imag = c_double_t(0)
@@ -29,11 +31,13 @@ def _reduce_all(a, c_func):
2931
imag = imag.value
3032
return real if imag == 0 else real + imag * 1j
3133

34+
3235
def _nan_parallel_dim(a, dim, c_func, nan_val):
3336
out = Array()
3437
safe_call(c_func(c_pointer(out.arr), a.arr, c_int_t(dim), c_double_t(nan_val)))
3538
return out
3639

40+
3741
def _nan_reduce_all(a, c_func, nan_val):
3842
real = c_double_t(0)
3943
imag = c_double_t(0)
@@ -44,6 +48,7 @@ def _nan_reduce_all(a, c_func, nan_val):
4448
imag = imag.value
4549
return real if imag == 0 else real + imag * 1j
4650

51+
4752
def _FNSD(dim, dims):
4853
if dim >= 0:
4954
return int(dim)
@@ -55,20 +60,26 @@ def _FNSD(dim, dims):
5560
break
5661
return int(fnsd)
5762

63+
5864
def _rbk_dim(keys, vals, dim, c_func):
5965
keys_out = Array()
6066
vals_out = Array()
6167
rdim = _FNSD(dim, vals.dims())
6268
safe_call(c_func(c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim)))
6369
return keys_out, vals_out
6470

71+
6572
def _nan_rbk_dim(a, dim, c_func, nan_val):
6673
keys_out = Array()
6774
vals_out = Array()
75+
# FIXME: vals is undefined
6876
rdim = _FNSD(dim, vals.dims())
69-
safe_call(c_func(c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim), c_double_t(nan_val)))
77+
# FIXME: keys is undefined
78+
safe_call(c_func(
79+
c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim), c_double_t(nan_val)))
7080
return keys_out, vals_out
7181

82+
7283
def sum(a, dim=None, nan_val=None):
7384
"""
7485
Calculate the sum of all the elements along a specified dimension.
@@ -88,18 +99,16 @@ def sum(a, dim=None, nan_val=None):
8899
The sum of all elements in `a` along dimension `dim`.
89100
If `dim` is `None`, sum of the entire Array is returned.
90101
"""
91-
if nan_valisnotNone:
92-
if dimisnotNone:
102+
if nan_val:
103+
if dim:
93104
return _nan_parallel_dim(a, dim, backend.get().af_sum_nan, nan_val)
94105
return _nan_reduce_all(a, backend.get().af_sum_nan_all, nan_val)
95106

96-
if dimisnotNone:
107+
if dim:
97108
return _parallel_dim(a, dim, backend.get().af_sum)
98109
return _reduce_all(a, backend.get().af_sum_all)
99110

100111

101-
102-
103112
def sumByKey(keys, vals, dim=-1, nan_val=None):
104113
"""
105114
Calculate the sum of elements along a specified dimension according to a key.
@@ -122,10 +131,10 @@ def sumByKey(keys, vals, dim=-1, nan_val=None):
122131
values: af.Array or scalar number
123132
The sum of all elements in `vals` along dimension `dim` according to keys
124133
"""
125-
if (nan_valisnotNone):
134+
if nan_val:
126135
return _nan_rbk_dim(keys, vals, dim, backend.get().af_sum_by_key_nan, nan_val)
127-
else:
128-
return_rbk_dim(keys, vals, dim, backend.get().af_sum_by_key)
136+
return_rbk_dim(keys, vals, dim, backend.get().af_sum_by_key)
137+
129138

130139
def product(a, dim=None, nan_val=None):
131140
"""
@@ -178,10 +187,10 @@ def productByKey(keys, vals, dim=-1, nan_val=None):
178187
values: af.Array or scalar number
179188
The product of all elements in `vals` along dimension `dim` according to keys
180189
"""
181-
if (nan_val is not None):
190+
if nan_val is not None:
182191
return _nan_rbk_dim(keys, vals, dim, backend.get().af_product_by_key_nan, nan_val)
183-
else:
184-
return_rbk_dim(keys, vals, dim, backend.get().af_product_by_key)
192+
return_rbk_dim(keys, vals, dim, backend.get().af_product_by_key)
193+
185194

186195
def min(a, dim=None):
187196
"""
@@ -227,6 +236,7 @@ def minByKey(keys, vals, dim=-1):
227236
"""
228237
return _rbk_dim(keys, vals, dim, backend.get().af_min_by_key)
229238

239+
230240
def max(a, dim=None):
231241
"""
232242
Find the maximum value of all the elements along a specified dimension.
@@ -271,6 +281,7 @@ def maxByKey(keys, vals, dim=-1):
271281
"""
272282
return _rbk_dim(keys, vals, dim, backend.get().af_max_by_key)
273283

284+
274285
def all_true(a, dim=None):
275286
"""
276287
Check if all the elements along a specified dimension are true.
@@ -315,6 +326,7 @@ def allTrueByKey(keys, vals, dim=-1):
315326
"""
316327
return _rbk_dim(keys, vals, dim, backend.get().af_all_true_by_key)
317328

329+
318330
def any_true(a, dim=None):
319331
"""
320332
Check if any the elements along a specified dimension are true.
@@ -334,8 +346,8 @@ def any_true(a, dim=None):
334346
"""
335347
if dim is not None:
336348
return _parallel_dim(a, dim, backend.get().af_any_true)
337-
else:
338-
return_reduce_all(a, backend.get().af_any_true_all)
349+
return_reduce_all(a, backend.get().af_any_true_all)
350+
339351

340352
def anyTrueByKey(keys, vals, dim=-1):
341353
"""
@@ -359,6 +371,7 @@ def anyTrueByKey(keys, vals, dim=-1):
359371
"""
360372
return _rbk_dim(keys, vals, dim, backend.get().af_any_true_by_key)
361373

374+
362375
def count(a, dim=None):
363376
"""
364377
Count the number of non zero elements in an array along a specified dimension.
@@ -378,8 +391,7 @@ def count(a, dim=None):
378391
"""
379392
if dim is not None:
380393
return _parallel_dim(a, dim, backend.get().af_count)
381-
else:
382-
return _reduce_all(a, backend.get().af_count_all)
394+
return _reduce_all(a, backend.get().af_count_all)
383395

384396

385397
def countByKey(keys, vals, dim=-1):
@@ -404,6 +416,7 @@ def countByKey(keys, vals, dim=-1):
404416
"""
405417
return _rbk_dim(keys, vals, dim, backend.get().af_count_by_key)
406418

419+
407420
def imin(a, dim=None):
408421
"""
409422
Find the value and location of the minimum value along a specified dimension

‎arrayfire/arith.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def cast(a, dtype):
7777
out : af.Array
7878
array containing the values from `a` after converting to `dtype`.
7979
"""
80-
out=Array()
80+
out=Array()
8181
safe_call(backend.get().af_cast(c_pointer(out.arr), a.arr, dtype.value))
8282
return out
8383

@@ -156,15 +156,8 @@ def clamp(val, low, high):
156156
vdims = dim4_to_tuple(val.dims())
157157
vty = val.type()
158158

159-
if not is_low_array:
160-
low_arr = constant_array(low, vdims[0], vdims[1], vdims[2], vdims[3], vty)
161-
else:
162-
low_arr = low.arr
163-
164-
if not is_high_array:
165-
high_arr = constant_array(high, vdims[0], vdims[1], vdims[2], vdims[3], vty)
166-
else:
167-
high_arr = high.arr
159+
low_arr = low.arr if is_low_array else constant_array(low, vdims[0], vdims[1], vdims[2], vdims[3], vty)
160+
high_arr = high.arr if is_high_array else constant_array(high, vdims[0], vdims[1], vdims[2], vdims[3], vty)
168161

169162
safe_call(backend.get().af_clamp(c_pointer(out.arr), val.arr, low_arr, high_arr, _bcast_var.get()))
170163

@@ -1003,6 +996,7 @@ def sqrt(a):
1003996
"""
1004997
return _arith_unary_func(a, backend.get().af_sqrt)
1005998

999+
10061000
def rsqrt(a):
10071001
"""
10081002
Reciprocal or inverse square root of each element in the array.

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /