Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 0a64db6

Browse files
committed
Merge pull request #20 from 9prady9/function_overloads
Function overloads for binary operations
2 parents 21fc3e6 + 9232326 commit 0a64db6

File tree

6 files changed

+116
-30
lines changed

6 files changed

+116
-30
lines changed

‎build.conf‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
{
2-
"use_backend": "cuda",
2+
"use_backend": "cpu",
33

44
"use_lib": false,
55
"lib_dir": "/usr/local/lib",
66
"inc_dir": "/usr/local/include",
77

88
"build_type": "Release",
99
"build_threads": "4",
10-
"build_cuda": "ON",
10+
"build_cuda": "OFF",
1111
"build_opencl": "ON",
1212
"build_cpu": "ON",
1313
"build_examples": "OFF",
@@ -28,7 +28,7 @@
2828
"glew_dir": "E:\\Libraries\\GLEW",
2929
"glfw_dir": "E:\\Libraries\\glfw3",
3030
"boost_dir": "E:\\Libraries\\boost_1_56_0",
31-
31+
3232
"cuda_sdk": "/usr/local/cuda",
3333
"opencl_sdk": "/usr",
3434
"sdk_lib_dir": "lib"

‎examples/helloworld.rs‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ extern crate arrayfire as af;
33
use af::Dim4;
44
use af::Array;
55

6+
#[allow(unused_must_use)]
67
fn main() {
78
af::set_device(0);
89
af::info();
@@ -14,10 +15,9 @@ fn main() {
1415
af::print(&a);
1516

1617
println!("Element-wise arithmetic");
17-
let sin_res = af::sin(&a).unwrap();
18-
let cos_res = af::cos(&a).unwrap();
19-
let b = &sin_res + 1.5;
20-
let b2 = &sin_res + &cos_res;
18+
let b = af::add(af::sin(&a), 1.5).unwrap();
19+
let b2 = af::add(af::sin(&a), af::cos(&a)).unwrap();
20+
2121
let b3 = ! &a;
2222
println!("sin(a) + 1.5 => "); af::print(&b);
2323
println!("sin(a) + cos(a) => "); af::print(&b2);

‎src/arith/mod.rs‎

Lines changed: 87 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
extern crate libc;
22
extern crate num;
33

4+
use dim4::Dim4;
45
use array::Array;
56
use defines::AfError;
67
use self::libc::{c_int};
7-
use data::constant;
8+
use data::{constant, tile};
89
use self::num::Complex;
910

1011
type MutAfArray = *mut self::libc::c_longlong;
@@ -182,32 +183,100 @@ macro_rules! binary_func {
182183
)
183184
}
184185

185-
binary_func!(add, af_add);
186-
binary_func!(sub, af_sub);
187-
binary_func!(mul, af_mul);
188-
binary_func!(div, af_div);
189-
binary_func!(rem, af_rem);
190186
binary_func!(bitand, af_bitand);
191187
binary_func!(bitor, af_bitor);
192188
binary_func!(bitxor, af_bitxor);
193-
binary_func!(shiftl, af_bitshiftl);
194-
binary_func!(shiftr, af_bitshiftr);
195-
binary_func!(lt, af_lt);
196-
binary_func!(gt, af_gt);
197-
binary_func!(le, af_le);
198-
binary_func!(ge, af_ge);
199-
binary_func!(eq, af_eq);
200189
binary_func!(neq, af_neq);
201190
binary_func!(and, af_and);
202191
binary_func!(or, af_or);
203192
binary_func!(minof, af_minof);
204193
binary_func!(maxof, af_maxof);
205-
binary_func!(modulo, af_mod);
206194
binary_func!(hypot, af_hypot);
207-
binary_func!(atan2, af_atan2);
208-
binary_func!(cplx2, af_cplx2);
209-
binary_func!(root, af_root);
210-
binary_func!(pow, af_pow);
195+
196+
pub trait Convertable {
197+
fn convert(&self) -> Array;
198+
}
199+
200+
macro_rules! convertable_type_def {
201+
($rust_type: ty) => (
202+
impl Convertable for $rust_type {
203+
fn convert(&self) -> Array {
204+
constant(*self, Dim4::new(&[1,1,1,1])).unwrap()
205+
}
206+
}
207+
)
208+
}
209+
210+
convertable_type_def!(f64);
211+
convertable_type_def!(f32);
212+
convertable_type_def!(i32);
213+
convertable_type_def!(u32);
214+
convertable_type_def!(u8);
215+
216+
impl Convertable for Array {
217+
fn convert(&self) -> Array {
218+
self.clone()
219+
}
220+
}
221+
222+
impl Convertable for Result<Array, AfError> {
223+
fn convert(&self) -> Array {
224+
self.clone().unwrap()
225+
}
226+
}
227+
228+
macro_rules! overloaded_binary_func {
229+
($fn_name: ident, $help_name: ident, $ffi_name: ident) => (
230+
fn $help_name(lhs: &Array, rhs: &Array) -> Result<Array, AfError> {
231+
unsafe {
232+
let mut temp: i64 = 0;
233+
let err_val = $ffi_name(&mut temp as MutAfArray,
234+
lhs.get() as AfArray, rhs.get() as AfArray,
235+
0);
236+
match err_val {
237+
0 => Ok(Array::from(temp)),
238+
_ => Err(AfError::from(err_val)),
239+
}
240+
}
241+
}
242+
243+
pub fn $fn_name<T: Convertable, U: Convertable> (arg1: T, arg2: U) -> Result<Array, AfError> {
244+
let lhs = arg1.convert();
245+
let rhs = arg2.convert();
246+
match (lhs.is_scalar().unwrap(), rhs.is_scalar().unwrap()) {
247+
( true, false) => {
248+
let l = tile(&lhs, rhs.dims().unwrap()).unwrap();
249+
$help_name(&l, &rhs)
250+
},
251+
(false, true) => {
252+
let r = tile(&rhs, lhs.dims().unwrap()).unwrap();
253+
$help_name(&lhs, &r)
254+
},
255+
_ => $help_name(&lhs, &rhs),
256+
}
257+
}
258+
)
259+
}
260+
261+
// thanks to Umar Arshad for the idea on how to
262+
// implement overloaded function
263+
overloaded_binary_func!(add, add_helper, af_add);
264+
overloaded_binary_func!(sub, sub_helper, af_sub);
265+
overloaded_binary_func!(mul, mul_helper, af_mul);
266+
overloaded_binary_func!(div, div_helper, af_div);
267+
overloaded_binary_func!(rem, rem_helper, af_rem);
268+
overloaded_binary_func!(shiftl, shiftl_helper, af_bitshiftl);
269+
overloaded_binary_func!(shiftr, shiftr_helper, af_bitshiftr);
270+
overloaded_binary_func!(lt, lt_helper, af_lt);
271+
overloaded_binary_func!(gt, gt_helper, af_gt);
272+
overloaded_binary_func!(le, le_helper, af_le);
273+
overloaded_binary_func!(ge, ge_helper, af_ge);
274+
overloaded_binary_func!(eq, eq_helper, af_eq);
275+
overloaded_binary_func!(modulo, modulo_helper, af_mod);
276+
overloaded_binary_func!(atan2, atan2_helper, af_atan2);
277+
overloaded_binary_func!(cplx2, cplx2_helper, af_cplx2);
278+
overloaded_binary_func!(root, root_helper, af_root);
279+
overloaded_binary_func!(pow, pow_helper, af_pow);
211280

212281
macro_rules! arith_scalar_func {
213282
($rust_type: ty, $op_name:ident, $fn_name: ident, $ffi_fn: ident) => (

‎src/array.rs‎

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ extern {
5757

5858
fn af_retain_array(out: MutAfArray, arr: AfArray) -> c_int;
5959

60+
fn af_copy_array(out: MutAfArray, arr: AfArray) -> c_int;
61+
6062
fn af_release_array(arr: AfArray) -> c_int;
6163

6264
fn af_print_array(arr: AfArray) -> c_int;
@@ -171,6 +173,17 @@ impl Array {
171173
}
172174
}
173175

176+
pub fn copy(&self) -> Result<Array, AfError> {
177+
unsafe {
178+
let mut temp: i64 = 0;
179+
let err_val = af_copy_array(&mut temp as MutAfArray, self.handle as AfArray);
180+
match err_val {
181+
0 => Ok(Array::from(temp)),
182+
_ => Err(AfError::from(err_val)),
183+
}
184+
}
185+
}
186+
174187
is_func!(is_empty, af_is_empty);
175188
is_func!(is_scalar, af_is_scalar);
176189
is_func!(is_row, af_is_row);

‎src/data/mod.rs‎

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,16 @@ impl ConstGenerator for Complex<f64> {
133133

134134
#[allow(unused_mut)]
135135
impl ConstGenerator for bool {
136-
fn generate(&self, dims: Dim4) -> Array {
136+
fn generate(&self, dims: Dim4) -> Result<Array,AfError> {
137137
unsafe {
138138
let mut temp: i64 = 0;
139-
af_constant(&mut temp as MutAfArray, *self as c_int as c_double,
140-
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT, 4);
141-
Array::from(temp)
139+
let err_val = af_constant(&mut temp as MutAfArray, *self as c_int as c_double,
140+
dims.ndims() as c_uint,
141+
dims.get().as_ptr() as *const DimT, 4);
142+
match err_val {
143+
0 => Ok(Array::from(temp)),
144+
_ => Err(AfError::from(err_val)),
145+
}
142146
}
143147
}
144148
}

‎src/dim4.rs‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ impl Dim4 {
3939
let nelems = self.elements();
4040
match nelems {
4141
0 => 0,
42-
1 => 0,
42+
1 => 1,
4343
_ => {
4444
if self.dims[3] != 1 { 4 }
4545
else if self.dims[2] != 1 { 3 }

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /