@@ -5,7 +5,7 @@ use crate::array::Array;
55use crate :: defines:: { AfError , BinaryOp } ;
66use crate :: error:: HANDLE_ERROR ;
77use crate :: util:: { AfArray , MutAfArray , MutDouble , MutUint } ;
8- use crate :: util:: { HasAfEnum , RealNumber , Scanable } ;
8+ use crate :: util:: { HasAfEnum , RealNumber , ReduceByKeyInput , Scanable } ;
99
1010#[ allow( dead_code) ]
1111extern "C" {
@@ -59,6 +59,71 @@ extern "C" {
5959 op : c_uint ,
6060 inclusive : c_int ,
6161 ) -> c_int ;
62+ fn af_all_true_by_key (
63+ keys_out : MutAfArray ,
64+ vals_out : MutAfArray ,
65+ keys : AfArray ,
66+ vals : AfArray ,
67+ dim : c_int ,
68+ ) -> c_int ;
69+ fn af_any_true_by_key (
70+ keys_out : MutAfArray ,
71+ vals_out : MutAfArray ,
72+ keys : AfArray ,
73+ vals : AfArray ,
74+ dim : c_int ,
75+ ) -> c_int ;
76+ fn af_count_by_key (
77+ keys_out : MutAfArray ,
78+ vals_out : MutAfArray ,
79+ keys : AfArray ,
80+ vals : AfArray ,
81+ dim : c_int ,
82+ ) -> c_int ;
83+ fn af_max_by_key (
84+ keys_out : MutAfArray ,
85+ vals_out : MutAfArray ,
86+ keys : AfArray ,
87+ vals : AfArray ,
88+ dim : c_int ,
89+ ) -> c_int ;
90+ fn af_min_by_key (
91+ keys_out : MutAfArray ,
92+ vals_out : MutAfArray ,
93+ keys : AfArray ,
94+ vals : AfArray ,
95+ dim : c_int ,
96+ ) -> c_int ;
97+ fn af_product_by_key (
98+ keys_out : MutAfArray ,
99+ vals_out : MutAfArray ,
100+ keys : AfArray ,
101+ vals : AfArray ,
102+ dim : c_int ,
103+ ) -> c_int ;
104+ fn af_product_by_key_nan (
105+ keys_out : MutAfArray ,
106+ vals_out : MutAfArray ,
107+ keys : AfArray ,
108+ vals : AfArray ,
109+ dim : c_int ,
110+ nan_val : c_double ,
111+ ) -> c_int ;
112+ fn af_sum_by_key (
113+ keys_out : MutAfArray ,
114+ vals_out : MutAfArray ,
115+ keys : AfArray ,
116+ vals : AfArray ,
117+ dim : c_int ,
118+ ) -> c_int ;
119+ fn af_sum_by_key_nan (
120+ keys_out : MutAfArray ,
121+ vals_out : MutAfArray ,
122+ keys : AfArray ,
123+ vals : AfArray ,
124+ dim : c_int ,
125+ nan_val : c_double ,
126+ ) -> c_int ;
62127}
63128
64129macro_rules! dim_reduce_func_def {
@@ -527,7 +592,8 @@ all_reduce_func_def!(
527592 let dims = Dim4::new(&[5, 5, 1, 1]);
528593 let a = randu::<f32>(dims);
529594 print(&a);
530- println!(\" Result : {:?}\" , product_all(&a));
595+ let res = product_all(&a);
596+ println!(\" Result : {:?}\" , res);
531597 ```
532598 " ,
533599 product_all,
@@ -1137,3 +1203,193 @@ where
11371203 }
11381204 temp. into ( )
11391205}
1206+ 1207+ macro_rules! dim_reduce_by_key_func_def {
1208+ ( $brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
1209+ #[ doc=$brief_str]
1210+ /// # Parameters
1211+ ///
1212+ /// - `keys` - key Array
1213+ /// - `vals` - value Array
1214+ /// - `dim` - Dimension along which the input Array is reduced
1215+ ///
1216+ /// # Return Values
1217+ ///
1218+ /// Tuple of Arrays, with output keys and values after reduction
1219+ ///
1220+ #[ doc=$ex_str]
1221+ pub fn $fn_name<KeyType , ValueType >( keys: & Array <KeyType >, vals: & Array <ValueType >,
1222+ dim: i32
1223+ ) -> ( Array <KeyType >, Array <$out_type>)
1224+ where
1225+ KeyType : ReduceByKeyInput ,
1226+ ValueType : HasAfEnum ,
1227+ $out_type: HasAfEnum ,
1228+ {
1229+ let mut out_keys: i64 = 0 ;
1230+ let mut out_vals: i64 = 0 ;
1231+ unsafe {
1232+ let err_val = $ffi_name(
1233+ & mut out_keys as MutAfArray ,
1234+ & mut out_vals as MutAfArray ,
1235+ keys. get( ) as AfArray ,
1236+ vals. get( ) as AfArray ,
1237+ dim as c_int,
1238+ ) ;
1239+ HANDLE_ERROR ( AfError :: from( err_val) ) ;
1240+ }
1241+ ( out_keys. into( ) , out_vals. into( ) )
1242+ }
1243+ } ;
1244+ }
1245+ 1246+ dim_reduce_by_key_func_def ! (
1247+ "
1248+ Key based AND of elements along a given dimension
1249+
1250+ All positive non-zero values are considered true, while negative and zero
1251+ values are considered as false.
1252+ " ,
1253+ "
1254+ # Examples
1255+ ```rust
1256+ use arrayfire::{Dim4, print, randu, all_true_by_key};
1257+ let dims = Dim4::new(&[5, 3, 1, 1]);
1258+ let vals = randu::<f32>(dims);
1259+ let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1260+ print(&vals);
1261+ print(&keys);
1262+ let (out_keys, out_vals) = all_true_by_key(&keys, &vals, 0);
1263+ print(&out_keys);
1264+ print(&out_vals);
1265+ ```
1266+ " ,
1267+ all_true_by_key,
1268+ af_all_true_by_key,
1269+ ValueType :: AggregateOutType
1270+ ) ;
1271+ 1272+ dim_reduce_by_key_func_def ! (
1273+ "
1274+ Key based OR of elements along a given dimension
1275+
1276+ All positive non-zero values are considered true, while negative and zero
1277+ values are considered as false.
1278+ " ,
1279+ "
1280+ # Examples
1281+ ```rust
1282+ use arrayfire::{Dim4, print, randu, any_true_by_key};
1283+ let dims = Dim4::new(&[5, 3, 1, 1]);
1284+ let vals = randu::<f32>(dims);
1285+ let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1286+ print(&vals);
1287+ print(&keys);
1288+ let (out_keys, out_vals) = any_true_by_key(&keys, &vals, 0);
1289+ print(&out_keys);
1290+ print(&out_vals);
1291+ ```
1292+ " ,
1293+ any_true_by_key,
1294+ af_any_true_by_key,
1295+ ValueType :: AggregateOutType
1296+ ) ;
1297+ 1298+ dim_reduce_by_key_func_def ! (
1299+ "Find total count of elements with similar keys along a given dimension" ,
1300+ "" ,
1301+ count_by_key,
1302+ af_count_by_key,
1303+ ValueType :: AggregateOutType
1304+ ) ;
1305+ 1306+ dim_reduce_by_key_func_def ! (
1307+ "Find maximum among values of similar keys along a given dimension" ,
1308+ "" ,
1309+ max_by_key,
1310+ af_max_by_key,
1311+ ValueType :: AggregateOutType
1312+ ) ;
1313+ 1314+ dim_reduce_by_key_func_def ! (
1315+ "Find minimum among values of similar keys along a given dimension" ,
1316+ "" ,
1317+ min_by_key,
1318+ af_min_by_key,
1319+ ValueType :: AggregateOutType
1320+ ) ;
1321+ 1322+ dim_reduce_by_key_func_def ! (
1323+ "Find product of all values with similar keys along a given dimension" ,
1324+ "" ,
1325+ product_by_key,
1326+ af_product_by_key,
1327+ ValueType :: ProductOutType
1328+ ) ;
1329+ 1330+ dim_reduce_by_key_func_def ! (
1331+ "Find sum of all values with similar keys along a given dimension" ,
1332+ "" ,
1333+ sum_by_key,
1334+ af_sum_by_key,
1335+ ValueType :: AggregateOutType
1336+ ) ;
1337+ 1338+ macro_rules! dim_reduce_by_key_nan_func_def {
1339+ ( $brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
1340+ #[ doc=$brief_str]
1341+ ///
1342+ /// This version of sum by key can replaced all NaN values in the input
1343+ /// with a user provided value before performing the reduction operation.
1344+ /// # Parameters
1345+ ///
1346+ /// - `keys` - key Array
1347+ /// - `vals` - value Array
1348+ /// - `dim` - Dimension along which the input Array is reduced
1349+ ///
1350+ /// # Return Values
1351+ ///
1352+ /// Tuple of Arrays, with output keys and values after reduction
1353+ ///
1354+ #[ doc=$ex_str]
1355+ pub fn $fn_name<KeyType , ValueType >( keys: & Array <KeyType >, vals: & Array <ValueType >,
1356+ dim: i32 , replace_value: f64
1357+ ) -> ( Array <KeyType >, Array <$out_type>)
1358+ where
1359+ KeyType : ReduceByKeyInput ,
1360+ ValueType : HasAfEnum ,
1361+ $out_type: HasAfEnum ,
1362+ {
1363+ let mut out_keys: i64 = 0 ;
1364+ let mut out_vals: i64 = 0 ;
1365+ unsafe {
1366+ let err_val = $ffi_name(
1367+ & mut out_keys as MutAfArray ,
1368+ & mut out_vals as MutAfArray ,
1369+ keys. get( ) as AfArray ,
1370+ vals. get( ) as AfArray ,
1371+ dim as c_int,
1372+ replace_value as c_double,
1373+ ) ;
1374+ HANDLE_ERROR ( AfError :: from( err_val) ) ;
1375+ }
1376+ ( out_keys. into( ) , out_vals. into( ) )
1377+ }
1378+ } ;
1379+ }
1380+ 1381+ dim_reduce_by_key_nan_func_def ! (
1382+ "Compute sum of all values with similar keys along a given dimension" ,
1383+ "" ,
1384+ sum_by_key_nan,
1385+ af_sum_by_key_nan,
1386+ ValueType :: AggregateOutType
1387+ ) ;
1388+ 1389+ dim_reduce_by_key_nan_func_def ! (
1390+ "Compute product of all values with similar keys along a given dimension" ,
1391+ "" ,
1392+ product_by_key_nan,
1393+ af_product_by_key_nan,
1394+ ValueType :: ProductOutType
1395+ ) ;
0 commit comments