Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 05a4da5

Browse files
committed
Add optional serde serialization support
- Update ci to run serde tests - Add serialization support for Enums except the enum `arrayfire::Scalar` - Structs with serde support added - [x] Array - [x] Dim4 - [x] Seq - [x] RandomEngine - Structs without serde support - Features - currently not possible as `af_features` can't be recreated from individual `af_arrays` with current upstream API - Indexer - not possible with current API. Also, any subarray when fetched to host for serialization results in separate owned copy this making serde support for this unnecessary. - Callback - Event - Window
1 parent 97d097b commit 05a4da5

File tree

8 files changed

+275
-4
lines changed

8 files changed

+275
-4
lines changed

‎.github/workflows/ci.yml‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ jobs:
4848
export AF_PATH=${GITHUB_WORKSPACE}/afbin
4949
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${AF_PATH}/lib64
5050
echo "Using cargo version: $(cargo --version)"
51-
cargo build --all
52-
cargo test --no-fail-fast
51+
cargo build --all --all-features
52+
cargo test --no-fail-fast --all-features
5353
5454
format:
5555
name: Format Check

‎Cargo.toml‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,19 @@ statistics = []
4646
vision = []
4747
default = ["algorithm", "arithmetic", "blas", "data", "indexing", "graphics", "image", "lapack",
4848
"ml", "macros", "random", "signal", "sparse", "statistics", "vision"]
49+
afserde = ["serde"]
4950

5051
[dependencies]
5152
libc = "0.2"
5253
num = "0.2"
5354
lazy_static = "1.0"
5455
half = "1.5.0"
56+
serde = { version = "1.0", features = ["derive"], optional = true }
5557

5658
[dev-dependencies]
5759
half = "1.5.0"
60+
serde_json = "1.0"
61+
bincode = "1.3"
5862

5963
[build-dependencies]
6064
serde_json = "1.0"

‎README.md‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Only, Major(M) & Minor(m) version numbers need to match. *p1* and *p2* are patch
1616

1717
## Supported platforms
1818

19-
Linux, Windows and OSX. Rust 1.15.1 or higher is required.
19+
Linux, Windows and OSX. Rust 1.31 or newer is required.
2020

2121
## Use from Crates.io [![][6]][7] [![][8]][9]
2222

‎src/core/array.rs‎

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,73 @@ pub fn is_eval_manual() -> bool {
851851
}
852852
}
853853

854+
#[cfg(feature = "afserde")]
855+
mod afserde {
856+
// Reimport required from super scope
857+
use super::{Array, DType, Dim4, HasAfEnum};
858+
859+
use serde::de::{Deserializer, Error, Unexpected};
860+
use serde::ser::Serializer;
861+
use serde::{Deserialize, Serialize};
862+
863+
#[derive(Debug, Serialize, Deserialize)]
864+
struct ArrayOnHost<T: HasAfEnum + std::fmt::Debug> {
865+
dtype: DType,
866+
shape: Dim4,
867+
data: Vec<T>,
868+
}
869+
870+
/// Serialize Implementation of Array
871+
impl<T> Serialize for Array<T>
872+
where
873+
T: std::default::Default + std::clone::Clone + Serialize + HasAfEnum + std::fmt::Debug,
874+
{
875+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
876+
where
877+
S: Serializer,
878+
{
879+
let mut vec = vec![T::default(); self.elements()];
880+
self.host(&mut vec);
881+
let arr_on_host = ArrayOnHost {
882+
dtype: self.get_type(),
883+
shape: self.dims().clone(),
884+
data: vec,
885+
};
886+
arr_on_host.serialize(serializer)
887+
}
888+
}
889+
890+
/// Deserialize Implementation of Array
891+
impl<'de, T> Deserialize<'de> for Array<T>
892+
where
893+
T: Deserialize<'de> + HasAfEnum + std::fmt::Debug,
894+
{
895+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
896+
where
897+
D: Deserializer<'de>,
898+
{
899+
match ArrayOnHost::<T>::deserialize(deserializer) {
900+
Ok(arr_on_host) => {
901+
let read_dtype = arr_on_host.dtype;
902+
let expected_dtype = T::get_af_dtype();
903+
if expected_dtype != read_dtype {
904+
let error_msg = format!(
905+
"data type is {:?}, deserialized type is {:?}",
906+
expected_dtype, read_dtype
907+
);
908+
return Err(Error::invalid_value(Unexpected::Enum, &error_msg.as_str()));
909+
}
910+
Ok(Array::<T>::new(
911+
&arr_on_host.data,
912+
arr_on_host.shape.clone(),
913+
))
914+
}
915+
Err(err) => Err(err),
916+
}
917+
}
918+
}
919+
}
920+
854921
#[cfg(test)]
855922
mod tests {
856923
use super::super::array::print;
@@ -1082,4 +1149,37 @@ mod tests {
10821149
// 8.0000 8.0000 8.0000
10831150
// ANCHOR_END: accum_using_channel
10841151
}
1152+
1153+
#[cfg(feature = "afserde")]
1154+
mod serde_tests {
1155+
use super::super::Array;
1156+
use crate::algorithm::sum_all;
1157+
use crate::randu;
1158+
1159+
#[test]
1160+
fn array_serde_json() {
1161+
let input = randu!(u8; 2, 2);
1162+
let serd = match serde_json::to_string(&input) {
1163+
Ok(serialized_str) => serialized_str,
1164+
Err(e) => e.to_string(),
1165+
};
1166+
1167+
let deserd: Array<u8> = serde_json::from_str(&serd).unwrap();
1168+
1169+
assert_eq!(sum_all(&(input - deserd)), (0u32, 0u32));
1170+
}
1171+
1172+
#[test]
1173+
fn array_serde_bincode() {
1174+
let input = randu!(u8; 2, 2);
1175+
let encoded = match bincode::serialize(&input) {
1176+
Ok(encoded) => encoded,
1177+
Err(_) => vec![],
1178+
};
1179+
1180+
let decoded: Array<u8> = bincode::deserialize(&encoded).unwrap();
1181+
1182+
assert_eq!(sum_all(&(input - decoded)), (0u32, 0u32));
1183+
}
1184+
}
10851185
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /