Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 7bf28b7

Browse files
committed
FEAT: Adding support for dot
1 parent 8bf7c01 commit 7bf28b7

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

‎arrayfire/blas.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def matmulTT(lhs, rhs):
150150
MATPROP.TRANS.value, MATPROP.TRANS.value))
151151
return out
152152

153-
def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
153+
def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE, return_scalar=False):
154154
"""
155155
Dot product of two input vectors.
156156
@@ -173,10 +173,13 @@ def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
173173
- af.MATPROP.NONE - If no op should be done on `rhs`.
174174
- No other options are currently supported.
175175
176+
return_scalar: optional: bool. default: False.
177+
- When set to true, the input arrays are flattened and the output is a scalar
178+
176179
Returns
177180
-------
178181
179-
out : af.Array
182+
out : af.Array or scalar
180183
Output of dot product of `lhs` and `rhs`.
181184
182185
Note
@@ -186,7 +189,16 @@ def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
186189
- Batches are not supported.
187190
188191
"""
189-
out = Array()
190-
safe_call(backend.get().af_dot(c_pointer(out.arr), lhs.arr, rhs.arr,
191-
lhs_opts.value, rhs_opts.value))
192-
return out
192+
if return_scalar:
193+
real = c_double_t(0)
194+
imag = c_double_t(0)
195+
safe_call(backend.get().af_dot_all(c_pointer(real), c_pointer(imag),
196+
lhs.arr, rhs.arr, lhs_opts.value, rhs_opts.value))
197+
real = real.value
198+
imag = imag.value
199+
return real if imag == 0 else real + imag * 1j
200+
else:
201+
out = Array()
202+
safe_call(backend.get().af_dot(c_pointer(out.arr), lhs.arr, rhs.arr,
203+
lhs_opts.value, rhs_opts.value))
204+
return out

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /