
Add compile time type checks to arrayfire-rust #318

c0dearm started this conversation in Ideas

Hi all lovely Rust 🦀 users!

I would like to start a discussion on using Rust's strong type system to avoid undesired runtime panics in certain operations. As of now, nothing stops a library user from writing this:

```rust
use arrayfire::{dim4, randu, Array};

let a: Array<f64> = randu::<f64>(dim4!(1, 1, 3, 4));
let b: Array<f64> = randu::<f64>(dim4!(1, 1, 2, 5));
// This compiles, but panics at runtime because the dimensions of `a` and `b` do not match.
let c: Array<f64> = a + b;
```
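For comparison, here is a minimal std-only sketch (no ArrayFire involved; `DynMatrix` is a hypothetical type) of what the current API effectively does: the shape lives in runtime values, so a mismatch can only be caught by a runtime assertion inside the operator.

```rust
/// Hypothetical runtime-shaped matrix, standing in for today's `Array<T>`.
#[derive(Debug, Clone, PartialEq)]
pub struct DynMatrix {
    rows: usize,
    cols: usize,
    data: Vec<f64>, // row-major, length rows * cols
}

impl DynMatrix {
    pub fn filled(rows: usize, cols: usize, value: f64) -> Self {
        DynMatrix { rows, cols, data: vec![value; rows * cols] }
    }

    pub fn shape(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }
}

impl std::ops::Add for DynMatrix {
    type Output = DynMatrix;

    fn add(self, rhs: DynMatrix) -> DynMatrix {
        // Both operands have the same *type*, so the shape check can only happen here, at runtime.
        assert_eq!(self.shape(), rhs.shape(), "dimension mismatch");
        let data = self.data.iter().zip(&rhs.data).map(|(x, y)| x + y).collect();
        DynMatrix { rows: self.rows, cols: self.cols, data }
    }
}
```

Both `3 x 4 + 3 x 4` and `3 x 4 + 2 x 5` type-check; only the first one survives at runtime.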

In that regard, the API is not much different from what we would get in a weakly typed language. Unfortunately, we are not really taking advantage of the safety tools Rust gives us. One possible solution is to wrap Array<T> in a tuple struct parameterized by the array dimensions, something like:

```rust
use arrayfire::{Array, HasAfEnum};
struct Tensor<const W: u64, const X: u64, const Y: u64, const Z: u64, T: HasAfEnum>(Array<T>);
```
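With the dimensions in the type, the only place a runtime shape check is still needed is the boundary where untyped data gets wrapped. The sketch below is a std-only illustration under stated assumptions: `Tensor` here is a hypothetical stand-in that drops `T` (it stores `f64` in a flat `Vec`) and uses `usize` instead of ArrayFire's `u64` dimensions; `try_from_vec` and `zeros` are illustrative names, not part of arrayfire-rust.

```rust
/// Hypothetical std-only stand-in for `Tensor<W, X, Y, Z, T>`: the four dimensions
/// are const generic parameters and the data is a flat Vec of W * X * Y * Z values.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor<const W: usize, const X: usize, const Y: usize, const Z: usize>(Vec<f64>);

impl<const W: usize, const X: usize, const Y: usize, const Z: usize> Tensor<W, X, Y, Z> {
    /// The dimensions can be read from the type alone (e.g. to build a runtime dim4).
    pub const DIMS: [usize; 4] = [W, X, Y, Z];

    /// The single runtime check: wrapping untyped data must agree with the type.
    pub fn try_from_vec(data: Vec<f64>) -> Option<Self> {
        if data.len() == W * X * Y * Z {
            Some(Tensor(data))
        } else {
            None
        }
    }

    /// Constructors that generate data from the type need no check at all.
    pub fn zeros() -> Self {
        Tensor(vec![0.0; W * X * Y * Z])
    }
}
```

Once a value has been constructed, every later operation can trust the type parameters instead of re-checking shapes.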

Then implement, for Tensor<W, X, Y, Z, T>, the same traits and functions that Array<T> provides:

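```rust
use std::ops::Add;

// `Add`'s `Rhs` defaults to `Self`, so only a Tensor with identical W, X, Y, Z can be added.
impl<const W: u64, const X: u64, const Y: u64, const Z: u64, T: HasAfEnum> Add
    for Tensor<W, X, Y, Z, T>
{
    type Output = Self;
    fn add(self, rhs: Self) -> Self::Output {
        Self(self.0 + rhs.0)
    }
}
```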
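The same trick can be tried without ArrayFire at all. Below is a minimal std-only sketch, assuming a hypothetical 2-D `Matrix<R, C>` backed by a fixed-size array: because `Add` is implemented only with `Rhs = Self`, adding a `Matrix<2, 2>` to a `Matrix<2, 3>` is a type error rather than a runtime panic.

```rust
use std::ops::Add;

/// Hypothetical fixed-size matrix: R rows and C columns are part of the type.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<const R: usize, const C: usize>(pub [[f64; C]; R]);

// `Rhs` defaults to `Self`, so only a matrix with the same R and C is accepted.
impl<const R: usize, const C: usize> Add for Matrix<R, C> {
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        let mut out = self.0; // arrays of f64 are Copy
        for i in 0..R {
            for j in 0..C {
                out[i][j] += rhs.0[i][j];
            }
        }
        Matrix(out)
    }
}
```

Since the shapes are checked by the compiler, the loop bodies need no bounds or shape assertions beyond ordinary indexing.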

We can then use these types to get compile-time guarantees that we are not doing anything mathematically unsound. For example, this does NOT compile:

```rust
let a = Tensor::<1, 1, 3, 4, f64>::randu();
let b = Tensor::<1, 1, 2, 5, f64>::randu();
let c = a + b; // rejected by the compiler: mismatched types
```

The implementation of matmul would look something like this:

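```rust
fn matmul<const ZO: u64>(self, rhs: Tensor<W, X, Z, ZO, T>) -> Tensor<W, X, Y, ZO, T> {
    // The output type is not `Self`, so the body must build a new Tensor from a
    // matrix product (not `+`) of the wrapped Arrays.
    todo!()
}
```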

That is, if the left operand matrix has dimensions [1, 1, 3, 4], then the right operand matrix must have dimensions [1, 1, 4, ZO] and the result will have dimensions [1, 1, 3, ZO], where ZO is the number of columns of the right matrix.
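The inner-dimension rule above maps directly onto a method signature. A minimal std-only sketch, assuming a hypothetical 2-D `Matrix<R, C>` (only the last two dimensions of the post's `Tensor`): the shared `C` appears in both the receiver and the argument type, so an operand whose row count differs simply does not type-check.

```rust
/// Hypothetical fixed-size matrix: R rows and C columns are part of the type.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<const R: usize, const C: usize>(pub [[f64; C]; R]);

impl<const R: usize, const C: usize> Matrix<R, C> {
    /// (R x C) * (C x K) = (R x K); the result shape is computed by the type checker.
    pub fn matmul<const K: usize>(self, rhs: Matrix<C, K>) -> Matrix<R, K> {
        let mut out = [[0.0; K]; R];
        for i in 0..R {
            for j in 0..K {
                for p in 0..C {
                    out[i][j] += self.0[i][p] * rhs.0[p][j];
                }
            }
        }
        Matrix(out)
    }
}
```

For the 3 x 4 by 4 x 2 case from the post, the compiler infers `K = 2` and the result type `Matrix<3, 2>` without any annotation on the call.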

If something like this compiles, then we know for sure that it cannot panic at runtime because of mismatched dimensions!

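```rust
let a = Tensor::<1, 1, 3, 4, f64>::randu();
let b = Tensor::<1, 1, 4, 2, f64>::randu();
let c = Tensor::<1, 1, 3, 2, f64>::randu();
// (3 x 4) * (4 x 2) = (3 x 2), which matches `c`, so the addition type-checks too.
let z = a.matmul(b) + c;
```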
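To check that the whole expression composes, here is a std-only sketch of the four-dimensional version. Assumptions: `Tensor<W, X, Y, Z>` is a hypothetical stand-in without the element type `T` (it stores `f64` in nested arrays and uses `usize` dimensions), W and X are treated as batch dimensions with a Y x Z matrix per batch entry (the convention used in the post), and `filled` is a deterministic replacement for the post's hypothetical `randu`.

```rust
use std::ops::Add;

/// Hypothetical std-only model of `Tensor<W, X, Y, Z>`: W and X are batch
/// dimensions and each batch entry is a Y x Z matrix.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Tensor<const W: usize, const X: usize, const Y: usize, const Z: usize>(
    pub [[[[f64; Z]; Y]; X]; W],
);

impl<const W: usize, const X: usize, const Y: usize, const Z: usize> Tensor<W, X, Y, Z> {
    /// Every element set to `value` (a deterministic stand-in for `randu`).
    pub fn filled(value: f64) -> Self {
        Tensor([[[[value; Z]; Y]; X]; W])
    }

    /// Batched product: for each (w, x), (Y x Z) * (Z x ZO) = (Y x ZO).
    pub fn matmul<const ZO: usize>(self, rhs: Tensor<W, X, Z, ZO>) -> Tensor<W, X, Y, ZO> {
        let mut out = Tensor::<W, X, Y, ZO>::filled(0.0);
        for w in 0..W {
            for x in 0..X {
                for i in 0..Y {
                    for j in 0..ZO {
                        for p in 0..Z {
                            out.0[w][x][i][j] += self.0[w][x][i][p] * rhs.0[w][x][p][j];
                        }
                    }
                }
            }
        }
        out
    }
}

// Element-wise addition, only between tensors of identical shape.
impl<const W: usize, const X: usize, const Y: usize, const Z: usize> Add for Tensor<W, X, Y, Z> {
    type Output = Self;

    fn add(mut self, rhs: Self) -> Self {
        for w in 0..W {
            for x in 0..X {
                for i in 0..Y {
                    for j in 0..Z {
                        self.0[w][x][i][j] += rhs.0[w][x][i][j];
                    }
                }
            }
        }
        self
    }
}
```

Here `*` is deliberately not overloaded: a named `matmul` method keeps the matrix product visually distinct from element-wise operators.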

Looking forward to hearing your thoughts on whether you think this could be a good idea and what approach you would follow to design/implement the API. Thank you all! ❤️
