|
1 | 1 | extern crate libc; |
2 | 2 | extern crate num; |
3 | 3 |
|
| 4 | +use dim4::Dim4; |
4 | 5 | use array::Array; |
5 | 6 | use defines::AfError; |
6 | 7 | use self::libc::{c_int}; |
7 | | -use data::constant; |
| 8 | +use data::{constant, tile}; |
8 | 9 | use self::num::Complex; |
9 | 10 |
|
10 | 11 | type MutAfArray = *mut self::libc::c_longlong; |
@@ -182,32 +183,100 @@ macro_rules! binary_func { |
182 | 183 | ) |
183 | 184 | } |
184 | 185 |
|
185 | | -binary_func!(add, af_add); |
186 | | -binary_func!(sub, af_sub); |
187 | | -binary_func!(mul, af_mul); |
188 | | -binary_func!(div, af_div); |
189 | | -binary_func!(rem, af_rem); |
190 | 186 | binary_func!(bitand, af_bitand); |
191 | 187 | binary_func!(bitor, af_bitor); |
192 | 188 | binary_func!(bitxor, af_bitxor); |
193 | | -binary_func!(shiftl, af_bitshiftl); |
194 | | -binary_func!(shiftr, af_bitshiftr); |
195 | | -binary_func!(lt, af_lt); |
196 | | -binary_func!(gt, af_gt); |
197 | | -binary_func!(le, af_le); |
198 | | -binary_func!(ge, af_ge); |
199 | | -binary_func!(eq, af_eq); |
200 | 189 | binary_func!(neq, af_neq); |
201 | 190 | binary_func!(and, af_and); |
202 | 191 | binary_func!(or, af_or); |
203 | 192 | binary_func!(minof, af_minof); |
204 | 193 | binary_func!(maxof, af_maxof); |
205 | | -binary_func!(modulo, af_mod); |
206 | 194 | binary_func!(hypot, af_hypot); |
207 | | -binary_func!(atan2, af_atan2); |
208 | | -binary_func!(cplx2, af_cplx2); |
209 | | -binary_func!(root, af_root); |
210 | | -binary_func!(pow, af_pow); |
| 195 | + |
| 196 | +pub trait Convertable { |
| 197 | + fn convert(&self) -> Array; |
| 198 | +} |
| 199 | + |
| 200 | +macro_rules! convertable_type_def { |
| 201 | + ($rust_type: ty) => ( |
| 202 | + impl Convertable for $rust_type { |
| 203 | + fn convert(&self) -> Array { |
| 204 | + constant(*self, Dim4::new(&[1,1,1,1])).unwrap() |
| 205 | + } |
| 206 | + } |
| 207 | + ) |
| 208 | +} |
| 209 | + |
| 210 | +convertable_type_def!(f64); |
| 211 | +convertable_type_def!(f32); |
| 212 | +convertable_type_def!(i32); |
| 213 | +convertable_type_def!(u32); |
| 214 | +convertable_type_def!(u8); |
| 215 | + |
| 216 | +impl Convertable for Array { |
| 217 | + fn convert(&self) -> Array { |
| 218 | + self.clone() |
| 219 | + } |
| 220 | +} |
| 221 | + |
| 222 | +impl Convertable for Result<Array, AfError> { |
| 223 | + fn convert(&self) -> Array { |
| 224 | + self.clone().unwrap() |
| 225 | + } |
| 226 | +} |
| 227 | + |
| 228 | +macro_rules! overloaded_binary_func { |
| 229 | + ($fn_name: ident, $help_name: ident, $ffi_name: ident) => ( |
| 230 | + fn $help_name(lhs: &Array, rhs: &Array) -> Result<Array, AfError> { |
| 231 | + unsafe { |
| 232 | + let mut temp: i64 = 0; |
| 233 | + let err_val = $ffi_name(&mut temp as MutAfArray, |
| 234 | + lhs.get() as AfArray, rhs.get() as AfArray, |
| 235 | + 0); |
| 236 | + match err_val { |
| 237 | + 0 => Ok(Array::from(temp)), |
| 238 | + _ => Err(AfError::from(err_val)), |
| 239 | + } |
| 240 | + } |
| 241 | + } |
| 242 | + |
| 243 | + pub fn $fn_name<T: Convertable, U: Convertable> (arg1: T, arg2: U) -> Result<Array, AfError> { |
| 244 | + let lhs = arg1.convert(); |
| 245 | + let rhs = arg2.convert(); |
| 246 | + match (lhs.is_scalar().unwrap(), rhs.is_scalar().unwrap()) { |
| 247 | + ( true, false) => { |
| 248 | + let l = tile(&lhs, rhs.dims().unwrap()).unwrap(); |
| 249 | + $help_name(&l, &rhs) |
| 250 | + }, |
| 251 | + (false, true) => { |
| 252 | + let r = tile(&rhs, lhs.dims().unwrap()).unwrap(); |
| 253 | + $help_name(&lhs, &r) |
| 254 | + }, |
| 255 | + _ => $help_name(&lhs, &rhs), |
| 256 | + } |
| 257 | + } |
| 258 | + ) |
| 259 | +} |
| 260 | + |
| 261 | +// thanks to Umar Arshad for the idea on how to |
| 262 | +// implement overloaded function |
| 263 | +overloaded_binary_func!(add, add_helper, af_add); |
| 264 | +overloaded_binary_func!(sub, sub_helper, af_sub); |
| 265 | +overloaded_binary_func!(mul, mul_helper, af_mul); |
| 266 | +overloaded_binary_func!(div, div_helper, af_div); |
| 267 | +overloaded_binary_func!(rem, rem_helper, af_rem); |
| 268 | +overloaded_binary_func!(shiftl, shiftl_helper, af_bitshiftl); |
| 269 | +overloaded_binary_func!(shiftr, shiftr_helper, af_bitshiftr); |
| 270 | +overloaded_binary_func!(lt, lt_helper, af_lt); |
| 271 | +overloaded_binary_func!(gt, gt_helper, af_gt); |
| 272 | +overloaded_binary_func!(le, le_helper, af_le); |
| 273 | +overloaded_binary_func!(ge, ge_helper, af_ge); |
| 274 | +overloaded_binary_func!(eq, eq_helper, af_eq); |
| 275 | +overloaded_binary_func!(modulo, modulo_helper, af_mod); |
| 276 | +overloaded_binary_func!(atan2, atan2_helper, af_atan2); |
| 277 | +overloaded_binary_func!(cplx2, cplx2_helper, af_cplx2); |
| 278 | +overloaded_binary_func!(root, root_helper, af_root); |
| 279 | +overloaded_binary_func!(pow, pow_helper, af_pow); |
211 | 280 |
|
212 | 281 | macro_rules! arith_scalar_func { |
213 | 282 | ($rust_type: ty, $op_name:ident, $fn_name: ident, $ffi_fn: ident) => ( |
|
0 commit comments