Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 32bc2f8

Browse files
committed
Update API to reflect ArrayFire 3.7.0 release
1 parent 0557ab4 commit 32bc2f8

File tree

19 files changed

+1468
-14
lines changed

19 files changed

+1468
-14
lines changed

‎Cargo.toml‎

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,25 @@ indexing = []
2727
graphics = []
2828
image = []
2929
lapack = []
30+
machine_learning = []
3031
macros = []
3132
random = []
3233
signal = []
3334
sparse = []
3435
statistics = []
3536
vision = []
3637
default = ["algorithm", "arithmetic", "blas", "data", "indexing", "graphics", "image", "lapack",
37-
"macros", "random", "signal", "sparse", "statistics", "vision"]
38+
"machine_learning", "macros", "random", "signal", "sparse", "statistics", "vision"]
3839

3940
[dependencies]
4041
libc = "0.2"
4142
num = "0.2"
4243
lazy_static = "1.0"
44+
half = "1.5.0"
4345

4446
[dev-dependencies]
4547
float-cmp = "0.6.0"
48+
half = "1.5.0"
4649

4750
[build-dependencies]
4851
serde_json = "1.0"
@@ -85,3 +88,7 @@ path = "examples/conway.rs"
8588
[[example]]
8689
name = "fft"
8790
path = "examples/fft.rs"
91+
92+
[[example]]
93+
name = "using_half"
94+
path = "examples/using_half.rs"

‎examples/conway.rs‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ fn main() {
77
}
88

99
fn normalise(a: &Array<f32>) -> Array<f32> {
10-
(a / (max_all(&abs(a)).0 as f32))
10+
a / (max_all(&abs(a)).0 as f32)
1111
}
1212

1313
fn conways_game_of_life() {

‎examples/using_half.rs‎

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
use arrayfire::*;
2+
use half::f16;
3+
4+
fn main() {
5+
set_device(0);
6+
info();
7+
8+
let values: Vec<_> = (1u8..101).map(f32::from).collect();
9+
10+
let half_values = values.iter().map(|&x| f16::from_f32(x)).collect::<Vec<_>>();
11+
12+
let hvals = Array::new(&half_values, Dim4::new(&[10, 10, 1, 1]));
13+
14+
print(&hvals);
15+
}

‎src/algorithm/mod.rs‎

Lines changed: 258 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::array::Array;
55
use crate::defines::{AfError, BinaryOp};
66
use crate::error::HANDLE_ERROR;
77
use crate::util::{AfArray, MutAfArray, MutDouble, MutUint};
8-
use crate::util::{HasAfEnum, RealNumber, Scanable};
8+
use crate::util::{HasAfEnum, RealNumber, ReduceByKeyInput,Scanable};
99

1010
#[allow(dead_code)]
1111
extern "C" {
@@ -59,6 +59,71 @@ extern "C" {
5959
op: c_uint,
6060
inclusive: c_int,
6161
) -> c_int;
62+
fn af_all_true_by_key(
63+
keys_out: MutAfArray,
64+
vals_out: MutAfArray,
65+
keys: AfArray,
66+
vals: AfArray,
67+
dim: c_int,
68+
) -> c_int;
69+
fn af_any_true_by_key(
70+
keys_out: MutAfArray,
71+
vals_out: MutAfArray,
72+
keys: AfArray,
73+
vals: AfArray,
74+
dim: c_int,
75+
) -> c_int;
76+
fn af_count_by_key(
77+
keys_out: MutAfArray,
78+
vals_out: MutAfArray,
79+
keys: AfArray,
80+
vals: AfArray,
81+
dim: c_int,
82+
) -> c_int;
83+
fn af_max_by_key(
84+
keys_out: MutAfArray,
85+
vals_out: MutAfArray,
86+
keys: AfArray,
87+
vals: AfArray,
88+
dim: c_int,
89+
) -> c_int;
90+
fn af_min_by_key(
91+
keys_out: MutAfArray,
92+
vals_out: MutAfArray,
93+
keys: AfArray,
94+
vals: AfArray,
95+
dim: c_int,
96+
) -> c_int;
97+
fn af_product_by_key(
98+
keys_out: MutAfArray,
99+
vals_out: MutAfArray,
100+
keys: AfArray,
101+
vals: AfArray,
102+
dim: c_int,
103+
) -> c_int;
104+
fn af_product_by_key_nan(
105+
keys_out: MutAfArray,
106+
vals_out: MutAfArray,
107+
keys: AfArray,
108+
vals: AfArray,
109+
dim: c_int,
110+
nan_val: c_double,
111+
) -> c_int;
112+
fn af_sum_by_key(
113+
keys_out: MutAfArray,
114+
vals_out: MutAfArray,
115+
keys: AfArray,
116+
vals: AfArray,
117+
dim: c_int,
118+
) -> c_int;
119+
fn af_sum_by_key_nan(
120+
keys_out: MutAfArray,
121+
vals_out: MutAfArray,
122+
keys: AfArray,
123+
vals: AfArray,
124+
dim: c_int,
125+
nan_val: c_double,
126+
) -> c_int;
62127
}
63128

64129
macro_rules! dim_reduce_func_def {
@@ -527,7 +592,8 @@ all_reduce_func_def!(
527592
let dims = Dim4::new(&[5, 5, 1, 1]);
528593
let a = randu::<f32>(dims);
529594
print(&a);
530-
println!(\"Result : {:?}\", product_all(&a));
595+
let res = product_all(&a);
596+
println!(\"Result : {:?}\", res);
531597
```
532598
",
533599
product_all,
@@ -1137,3 +1203,193 @@ where
11371203
}
11381204
temp.into()
11391205
}
1206+
1207+
macro_rules! dim_reduce_by_key_func_def {
1208+
($brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
1209+
#[doc=$brief_str]
1210+
/// # Parameters
1211+
///
1212+
/// - `keys` - key Array
1213+
/// - `vals` - value Array
1214+
/// - `dim` - Dimension along which the input Array is reduced
1215+
///
1216+
/// # Return Values
1217+
///
1218+
/// Tuple of Arrays, with output keys and values after reduction
1219+
///
1220+
#[doc=$ex_str]
1221+
pub fn $fn_name<KeyType, ValueType>(keys: &Array<KeyType>, vals: &Array<ValueType>,
1222+
dim: i32
1223+
) -> (Array<KeyType>, Array<$out_type>)
1224+
where
1225+
KeyType: ReduceByKeyInput,
1226+
ValueType: HasAfEnum,
1227+
$out_type: HasAfEnum,
1228+
{
1229+
let mut out_keys: i64 = 0;
1230+
let mut out_vals: i64 = 0;
1231+
unsafe {
1232+
let err_val = $ffi_name(
1233+
&mut out_keys as MutAfArray,
1234+
&mut out_vals as MutAfArray,
1235+
keys.get() as AfArray,
1236+
vals.get() as AfArray,
1237+
dim as c_int,
1238+
);
1239+
HANDLE_ERROR(AfError::from(err_val));
1240+
}
1241+
(out_keys.into(), out_vals.into())
1242+
}
1243+
};
1244+
}
1245+
1246+
dim_reduce_by_key_func_def!(
1247+
"
1248+
Key based AND of elements along a given dimension
1249+
1250+
All positive non-zero values are considered true, while negative and zero
1251+
values are considered as false.
1252+
",
1253+
"
1254+
# Examples
1255+
```rust
1256+
use arrayfire::{Dim4, print, randu, all_true_by_key};
1257+
let dims = Dim4::new(&[5, 3, 1, 1]);
1258+
let vals = randu::<f32>(dims);
1259+
let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1260+
print(&vals);
1261+
print(&keys);
1262+
let (out_keys, out_vals) = all_true_by_key(&keys, &vals, 0);
1263+
print(&out_keys);
1264+
print(&out_vals);
1265+
```
1266+
",
1267+
all_true_by_key,
1268+
af_all_true_by_key,
1269+
ValueType::AggregateOutType
1270+
);
1271+
1272+
dim_reduce_by_key_func_def!(
1273+
"
1274+
Key based OR of elements along a given dimension
1275+
1276+
All positive non-zero values are considered true, while negative and zero
1277+
values are considered as false.
1278+
",
1279+
"
1280+
# Examples
1281+
```rust
1282+
use arrayfire::{Dim4, print, randu, any_true_by_key};
1283+
let dims = Dim4::new(&[5, 3, 1, 1]);
1284+
let vals = randu::<f32>(dims);
1285+
let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1286+
print(&vals);
1287+
print(&keys);
1288+
let (out_keys, out_vals) = any_true_by_key(&keys, &vals, 0);
1289+
print(&out_keys);
1290+
print(&out_vals);
1291+
```
1292+
",
1293+
any_true_by_key,
1294+
af_any_true_by_key,
1295+
ValueType::AggregateOutType
1296+
);
1297+
1298+
dim_reduce_by_key_func_def!(
1299+
"Find total count of elements with similar keys along a given dimension",
1300+
"",
1301+
count_by_key,
1302+
af_count_by_key,
1303+
ValueType::AggregateOutType
1304+
);
1305+
1306+
dim_reduce_by_key_func_def!(
1307+
"Find maximum among values of similar keys along a given dimension",
1308+
"",
1309+
max_by_key,
1310+
af_max_by_key,
1311+
ValueType::AggregateOutType
1312+
);
1313+
1314+
dim_reduce_by_key_func_def!(
1315+
"Find minimum among values of similar keys along a given dimension",
1316+
"",
1317+
min_by_key,
1318+
af_min_by_key,
1319+
ValueType::AggregateOutType
1320+
);
1321+
1322+
dim_reduce_by_key_func_def!(
1323+
"Find product of all values with similar keys along a given dimension",
1324+
"",
1325+
product_by_key,
1326+
af_product_by_key,
1327+
ValueType::ProductOutType
1328+
);
1329+
1330+
dim_reduce_by_key_func_def!(
1331+
"Find sum of all values with similar keys along a given dimension",
1332+
"",
1333+
sum_by_key,
1334+
af_sum_by_key,
1335+
ValueType::AggregateOutType
1336+
);
1337+
1338+
macro_rules! dim_reduce_by_key_nan_func_def {
1339+
($brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
1340+
#[doc=$brief_str]
1341+
///
1342+
/// This version of sum by key can replaced all NaN values in the input
1343+
/// with a user provided value before performing the reduction operation.
1344+
/// # Parameters
1345+
///
1346+
/// - `keys` - key Array
1347+
/// - `vals` - value Array
1348+
/// - `dim` - Dimension along which the input Array is reduced
1349+
///
1350+
/// # Return Values
1351+
///
1352+
/// Tuple of Arrays, with output keys and values after reduction
1353+
///
1354+
#[doc=$ex_str]
1355+
pub fn $fn_name<KeyType, ValueType>(keys: &Array<KeyType>, vals: &Array<ValueType>,
1356+
dim: i32, replace_value: f64
1357+
) -> (Array<KeyType>, Array<$out_type>)
1358+
where
1359+
KeyType: ReduceByKeyInput,
1360+
ValueType: HasAfEnum,
1361+
$out_type: HasAfEnum,
1362+
{
1363+
let mut out_keys: i64 = 0;
1364+
let mut out_vals: i64 = 0;
1365+
unsafe {
1366+
let err_val = $ffi_name(
1367+
&mut out_keys as MutAfArray,
1368+
&mut out_vals as MutAfArray,
1369+
keys.get() as AfArray,
1370+
vals.get() as AfArray,
1371+
dim as c_int,
1372+
replace_value as c_double,
1373+
);
1374+
HANDLE_ERROR(AfError::from(err_val));
1375+
}
1376+
(out_keys.into(), out_vals.into())
1377+
}
1378+
};
1379+
}
1380+
1381+
dim_reduce_by_key_nan_func_def!(
1382+
"Compute sum of all values with similar keys along a given dimension",
1383+
"",
1384+
sum_by_key_nan,
1385+
af_sum_by_key_nan,
1386+
ValueType::AggregateOutType
1387+
);
1388+
1389+
dim_reduce_by_key_nan_func_def!(
1390+
"Compute product of all values with similar keys along a given dimension",
1391+
"",
1392+
product_by_key_nan,
1393+
af_product_by_key_nan,
1394+
ValueType::ProductOutType
1395+
);

‎src/arith/mod.rs‎

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ extern "C" {
8585
fn af_log10(out: MutAfArray, arr: AfArray) -> c_int;
8686
fn af_log2(out: MutAfArray, arr: AfArray) -> c_int;
8787
fn af_sqrt(out: MutAfArray, arr: AfArray) -> c_int;
88+
fn af_rsqrt(out: MutAfArray, arr: AfArray) -> c_int;
8889
fn af_cbrt(out: MutAfArray, arr: AfArray) -> c_int;
8990
fn af_factorial(out: MutAfArray, arr: AfArray) -> c_int;
9091
fn af_tgamma(out: MutAfArray, arr: AfArray) -> c_int;
@@ -199,6 +200,12 @@ unary_func!("Compute the natural logarithm", log, af_log, UnaryOutType);
199200
unary_func!("Compute sin", sin, af_sin, UnaryOutType);
200201
unary_func!("Compute sinh", sinh, af_sinh, UnaryOutType);
201202
unary_func!("Compute the square root", sqrt, af_sqrt, UnaryOutType);
203+
unary_func!(
204+
"Compute the reciprocal square root",
205+
rsqrt,
206+
af_rsqrt,
207+
UnaryOutType
208+
);
202209
unary_func!("Compute tan", tan, af_tan, UnaryOutType);
203210
unary_func!("Compute tanh", tanh, af_tanh, UnaryOutType);
204211

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /