Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 8d3a82f

Browse files
committed
New API to check for half support on given device
Also fixes a doc test in Array struct implementation
1 parent 65cc4ff commit 8d3a82f

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

‎src/core/array.rs‎

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,16 +189,20 @@ where
189189
/// An example of creating an Array from half::f16 array
190190
///
191191
/// ```rust
192-
/// use arrayfire::{Array, Dim4, print};
192+
/// use arrayfire::{Array, Dim4, is_half_available, print};
193193
/// use half::f16;
194194
///
195195
/// let values: [f32; 3] = [1.0, 2.0, 3.0];
196196
///
197-
/// let half_values = values.iter().map(|&x| f16::from_f32(x)).collect::<Vec<_>>();
197+
/// if is_half_available(0) { // Default device is 0, hence the argument
198+
/// let half_values = values.iter().map(|&x| f16::from_f32(x)).collect::<Vec<_>>();
198199
///
199-
/// let hvals = Array::new(&half_values, Dim4::new(&[3, 1, 1, 1]));
200+
/// let hvals = Array::new(&half_values, Dim4::new(&[3, 1, 1, 1]));
200201
///
201-
/// print(&hvals);
202+
/// print(&hvals);
203+
/// } else {
204+
/// println!("Half support isn't available on this device");
205+
/// }
202206
/// ```
203207
///
204208
pub fn new(slice: &[T], dims: Dim4) -> Self {

‎src/core/device.rs‎

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ extern "C" {
3636

3737
fn af_alloc_pinned(non_pagable_ptr: *mut void_ptr, bytes: dim_t) -> c_int;
3838
fn af_free_pinned(non_pagable_ptr: void_ptr) -> c_int;
39+
fn af_get_half_support(available: *mut c_int, device: c_int) -> c_int;
3940
}
4041

4142
/// Get ArrayFire Version Number
@@ -331,3 +332,21 @@ pub unsafe fn free_pinned(ptr: void_ptr) {
331332
let err_val = af_free_pinned(ptr);
332333
HANDLE_ERROR(AfError::from(err_val));
333334
}
335+
336+
/// Check if a device has half support
337+
///
338+
/// # Parameters
339+
///
340+
/// - `device` is the device for which half precision support is checked for
341+
///
342+
/// # Return Values
343+
///
344+
/// `True` if `device` device has half support, `False` otherwise.
345+
pub fn is_half_available(device: i32) -> bool {
346+
unsafe {
347+
let mut temp: i32 = 0;
348+
let err_val = af_get_half_support(&mut temp as *mut c_int, device as c_int);
349+
HANDLE_ERROR(AfError::from(err_val));
350+
temp > 0
351+
}
352+
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /