Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 9232326

Browse files
committed
function overloading for arithmetic binary operations
Also, fixed a bug in dim4 module in Dim4::ndims method
1 parent 5e052ac commit 9232326

File tree

4 files changed

+95
-26
lines changed

4 files changed

+95
-26
lines changed

‎build.conf‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
{
2-
"use_backend": "cuda",
2+
"use_backend": "cpu",
33

44
"use_lib": false,
55
"lib_dir": "/usr/local/lib",
66
"inc_dir": "/usr/local/include",
77

88
"build_type": "Release",
99
"build_threads": "4",
10-
"build_cuda": "ON",
10+
"build_cuda": "OFF",
1111
"build_opencl": "ON",
1212
"build_cpu": "ON",
1313
"build_examples": "OFF",
@@ -28,7 +28,7 @@
2828
"glew_dir": "E:\\Libraries\\GLEW",
2929
"glfw_dir": "E:\\Libraries\\glfw3",
3030
"boost_dir": "E:\\Libraries\\boost_1_56_0",
31-
31+
3232
"cuda_sdk": "/usr/local/cuda",
3333
"opencl_sdk": "/usr",
3434
"sdk_lib_dir": "lib"

‎examples/helloworld.rs‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ extern crate arrayfire as af;
33
use af::Dim4;
44
use af::Array;
55

6+
#[allow(unused_must_use)]
67
fn main() {
78
af::set_device(0);
89
af::info();
@@ -14,10 +15,9 @@ fn main() {
1415
af::print(&a);
1516

1617
println!("Element-wise arithmetic");
17-
let sin_res = af::sin(&a).unwrap();
18-
let cos_res = af::cos(&a).unwrap();
19-
let b = &sin_res + 1.5;
20-
let b2 = &sin_res + &cos_res;
18+
let b = af::add(af::sin(&a), 1.5).unwrap();
19+
let b2 = af::add(af::sin(&a), af::cos(&a)).unwrap();
20+
2121
let b3 = ! &a;
2222
println!("sin(a) + 1.5 => "); af::print(&b);
2323
println!("sin(a) + cos(a) => "); af::print(&b2);

‎src/arith/mod.rs‎

Lines changed: 87 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
extern crate libc;
22
extern crate num;
33

4+
use dim4::Dim4;
45
use array::Array;
56
use defines::AfError;
67
use self::libc::{c_int};
7-
use data::constant;
8+
use data::{constant, tile};
89
use self::num::Complex;
910

1011
type MutAfArray = *mut self::libc::c_longlong;
@@ -182,32 +183,100 @@ macro_rules! binary_func {
182183
)
183184
}
184185

185-
binary_func!(add, af_add);
186-
binary_func!(sub, af_sub);
187-
binary_func!(mul, af_mul);
188-
binary_func!(div, af_div);
189-
binary_func!(rem, af_rem);
190186
binary_func!(bitand, af_bitand);
191187
binary_func!(bitor, af_bitor);
192188
binary_func!(bitxor, af_bitxor);
193-
binary_func!(shiftl, af_bitshiftl);
194-
binary_func!(shiftr, af_bitshiftr);
195-
binary_func!(lt, af_lt);
196-
binary_func!(gt, af_gt);
197-
binary_func!(le, af_le);
198-
binary_func!(ge, af_ge);
199-
binary_func!(eq, af_eq);
200189
binary_func!(neq, af_neq);
201190
binary_func!(and, af_and);
202191
binary_func!(or, af_or);
203192
binary_func!(minof, af_minof);
204193
binary_func!(maxof, af_maxof);
205-
binary_func!(modulo, af_mod);
206194
binary_func!(hypot, af_hypot);
207-
binary_func!(atan2, af_atan2);
208-
binary_func!(cplx2, af_cplx2);
209-
binary_func!(root, af_root);
210-
binary_func!(pow, af_pow);
195+
196+
pub trait Convertable {
197+
fn convert(&self) -> Array;
198+
}
199+
200+
macro_rules! convertable_type_def {
201+
($rust_type: ty) => (
202+
impl Convertable for $rust_type {
203+
fn convert(&self) -> Array {
204+
constant(*self, Dim4::new(&[1,1,1,1])).unwrap()
205+
}
206+
}
207+
)
208+
}
209+
210+
convertable_type_def!(f64);
211+
convertable_type_def!(f32);
212+
convertable_type_def!(i32);
213+
convertable_type_def!(u32);
214+
convertable_type_def!(u8);
215+
216+
impl Convertable for Array {
217+
fn convert(&self) -> Array {
218+
self.clone()
219+
}
220+
}
221+
222+
impl Convertable for Result<Array, AfError> {
223+
fn convert(&self) -> Array {
224+
self.clone().unwrap()
225+
}
226+
}
227+
228+
macro_rules! overloaded_binary_func {
229+
($fn_name: ident, $help_name: ident, $ffi_name: ident) => (
230+
fn $help_name(lhs: &Array, rhs: &Array) -> Result<Array, AfError> {
231+
unsafe {
232+
let mut temp: i64 = 0;
233+
let err_val = $ffi_name(&mut temp as MutAfArray,
234+
lhs.get() as AfArray, rhs.get() as AfArray,
235+
0);
236+
match err_val {
237+
0 => Ok(Array::from(temp)),
238+
_ => Err(AfError::from(err_val)),
239+
}
240+
}
241+
}
242+
243+
pub fn $fn_name<T: Convertable, U: Convertable> (arg1: T, arg2: U) -> Result<Array, AfError> {
244+
let lhs = arg1.convert();
245+
let rhs = arg2.convert();
246+
match (lhs.is_scalar().unwrap(), rhs.is_scalar().unwrap()) {
247+
( true, false) => {
248+
let l = tile(&lhs, rhs.dims().unwrap()).unwrap();
249+
$help_name(&l, &rhs)
250+
},
251+
(false, true) => {
252+
let r = tile(&rhs, lhs.dims().unwrap()).unwrap();
253+
$help_name(&lhs, &r)
254+
},
255+
_ => $help_name(&lhs, &rhs),
256+
}
257+
}
258+
)
259+
}
260+
261+
// thanks to Umar Arshad for the idea on how to
262+
// implement overloaded function
263+
overloaded_binary_func!(add, add_helper, af_add);
264+
overloaded_binary_func!(sub, sub_helper, af_sub);
265+
overloaded_binary_func!(mul, mul_helper, af_mul);
266+
overloaded_binary_func!(div, div_helper, af_div);
267+
overloaded_binary_func!(rem, rem_helper, af_rem);
268+
overloaded_binary_func!(shiftl, shiftl_helper, af_bitshiftl);
269+
overloaded_binary_func!(shiftr, shiftr_helper, af_bitshiftr);
270+
overloaded_binary_func!(lt, lt_helper, af_lt);
271+
overloaded_binary_func!(gt, gt_helper, af_gt);
272+
overloaded_binary_func!(le, le_helper, af_le);
273+
overloaded_binary_func!(ge, ge_helper, af_ge);
274+
overloaded_binary_func!(eq, eq_helper, af_eq);
275+
overloaded_binary_func!(modulo, modulo_helper, af_mod);
276+
overloaded_binary_func!(atan2, atan2_helper, af_atan2);
277+
overloaded_binary_func!(cplx2, cplx2_helper, af_cplx2);
278+
overloaded_binary_func!(root, root_helper, af_root);
279+
overloaded_binary_func!(pow, pow_helper, af_pow);
211280

212281
macro_rules! arith_scalar_func {
213282
($rust_type: ty, $op_name:ident, $fn_name: ident, $ffi_fn: ident) => (

‎src/dim4.rs‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ impl Dim4 {
3939
let nelems = self.elements();
4040
match nelems {
4141
0 => 0,
42-
1 => 0,
42+
1 => 1,
4343
_ => {
4444
if self.dims[3] != 1 { 4 }
4545
else if self.dims[2] != 1 { 3 }

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /