Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 0dd898b

Browse files
committed
Round floats to prevent deltas across platforms
1 parent 43066a2 commit 0dd898b

File tree

11 files changed

+160
-134
lines changed

11 files changed

+160
-134
lines changed

‎tests/difftests/lib/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use-compiled-tools = [
1515
"spirv-builder/use-compiled-tools"
1616
]
1717

18-
[dependencies]
18+
[target.'cfg(not(target_arch="spirv"))'.dependencies]
1919
spirv-builder.workspace = true
2020
serde = { version = "1.0", features = ["derive"] }
2121
serde_json = "1.0"

‎tests/difftests/lib/src/lib.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,21 @@
1+
#![cfg_attr(target_arch = "spirv", no_std)]
2+
3+
#[cfg(not(target_arch = "spirv"))]
14
pub mod config;
5+
#[cfg(not(target_arch = "spirv"))]
26
pub mod scaffold;
37

8+
/// Macro to round a f32 value to 6 decimal places for cross-platform consistency
9+
/// in floating-point operations. This helps ensure difftest results are consistent
10+
/// across different platforms (Linux, Mac, Windows) which may have slight differences
11+
/// in floating-point implementations.
12+
#[macro_export]
13+
macro_rules! round6 {
14+
($v:expr) => {
15+
(($v) * 1_000_000.0).round() / 1_000_000.0
16+
};
17+
}
18+
419
#[cfg(test)]
520
mod tests {
621
use super::config::Config;

‎tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/Cargo.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,9 @@ crate-type = ["dylib"]
1010

1111
# Common deps
1212
[dependencies]
13-
14-
# GPU deps
1513
spirv-std.workspace = true
14+
difftest.workspace = true
1615

1716
# CPU deps
1817
[target.'cfg(not(target_arch = "spirv"))'.dependencies]
19-
difftest.workspace = true
20-
bytemuck.workspace = true
18+
bytemuck.workspace = true

‎tests/difftests/tests/lang/core/ops/math_ops/math_ops-rust/src/lib.rs

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![no_std]
22

3+
use difftest::round6;
34
#[allow(unused_imports)]
45
use spirv_std::num_traits::Float;
56
use spirv_std::spirv;
@@ -24,28 +25,28 @@ pub fn main_cs(
2425
}
2526

2627
// Basic arithmetic
27-
output[base_offset + 0] = x + 1.5;
28-
output[base_offset + 1] = x - 0.5;
29-
output[base_offset + 2] = x * 2.0;
30-
output[base_offset + 3] = x / 2.0;
31-
output[base_offset + 4] = x % 3.0;
28+
output[base_offset + 0] = round6!(x + 1.5);
29+
output[base_offset + 1] = round6!(x - 0.5);
30+
output[base_offset + 2] = round6!(x * 2.0);
31+
output[base_offset + 3] = round6!(x / 2.0);
32+
output[base_offset + 4] = round6!(x % 3.0);
3233

3334
// Trigonometric functions (simplified for consistent results)
34-
output[base_offset + 5] = x.sin();
35-
output[base_offset + 6] = x.cos();
36-
output[base_offset + 7] = x.tan().clamp(-10.0, 10.0);
35+
output[base_offset + 5] = round6!(x.sin());
36+
output[base_offset + 6] = round6!(x.cos());
37+
output[base_offset + 7] = round6!(x.tan().clamp(-10.0, 10.0));
3738
output[base_offset + 8] = 0.0;
3839
output[base_offset + 9] = 0.0;
39-
output[base_offset + 10] = x.atan();
40+
output[base_offset + 10] = round6!(x.atan());
4041

4142
// Exponential and logarithmic (simplified)
42-
output[base_offset + 11] = x.exp().min(1e6);
43-
output[base_offset + 12] = if x > 0.0 { x.ln() } else { -10.0 };
44-
output[base_offset + 13] = x.abs().sqrt();
45-
output[base_offset + 14] = x.abs().powf(2.0);
46-
output[base_offset + 15] = if x > 0.0 { x.log2() } else { -10.0 };
47-
output[base_offset + 16] = x.exp2().min(1e6);
48-
output[base_offset + 17] = x.floor();
43+
output[base_offset + 11] = round6!(x.exp().min(1e6));
44+
output[base_offset + 12] = round6!(if x > 0.0 { x.ln() } else { -10.0 });
45+
output[base_offset + 13] = round6!(x.abs().sqrt());
46+
output[base_offset + 14] = round6!(x.abs()* x.abs());// Use multiplication instead of powf
47+
output[base_offset + 15] = round6!(if x > 0.0 { x.log2() } else { -10.0 });
48+
output[base_offset + 16] = round6!(x.exp2().min(1e6));
49+
output[base_offset + 17] = x.floor();// floor/ceil/round are exact
4950
output[base_offset + 18] = x.ceil();
5051
output[base_offset + 19] = x.round();
5152

‎tests/difftests/tests/lang/core/ops/math_ops/math_ops-wgsl/shader.wgsl

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ var<storage, read> input: array<f32>;
44
@group(0) @binding(1)
55
var<storage, read_write> output: array<f32>;
66

7+
// Helper function to round to 6 decimal places for cross-platform consistency
8+
fn round6(v: f32) -> f32 {
9+
return round(v * 1000000.0) / 1000000.0;
10+
}
11+
712
@compute @workgroup_size(32, 1, 1)
813
fn main_cs(@builtin(global_invocation_id) global_id: vec3<u32>) {
914
let tid = global_id.x;
@@ -20,28 +25,28 @@ fn main_cs(@builtin(global_invocation_id) global_id: vec3<u32>) {
2025
}
2126

2227
// Basic arithmetic
23-
output[base_offset + 0u] = x + 1.5;
24-
output[base_offset + 1u] = x - 0.5;
25-
output[base_offset + 2u] = x * 2.0;
26-
output[base_offset + 3u] = x / 2.0;
27-
output[base_offset + 4u] = x % 3.0;
28+
output[base_offset + 0u] = round6(x + 1.5);
29+
output[base_offset + 1u] = round6(x - 0.5);
30+
output[base_offset + 2u] = round6(x * 2.0);
31+
output[base_offset + 3u] = round6(x / 2.0);
32+
output[base_offset + 4u] = round6(x % 3.0);
2833

2934
// Trigonometric functions (simplified for consistent results)
30-
output[base_offset + 5u] = sin(x);
31-
output[base_offset + 6u] = cos(x);
32-
output[base_offset + 7u] = clamp(tan(x), -10.0, 10.0);
35+
output[base_offset + 5u] = round6(sin(x));
36+
output[base_offset + 6u] = round6(cos(x));
37+
output[base_offset + 7u] = round6(clamp(tan(x), -10.0, 10.0));
3338
output[base_offset + 8u] = 0.0;
3439
output[base_offset + 9u] = 0.0;
35-
output[base_offset + 10u] = atan(x);
40+
output[base_offset + 10u] = round6(atan(x));
3641

3742
// Exponential and logarithmic (simplified)
38-
output[base_offset + 11u] = min(exp(x), 1e6);
39-
output[base_offset + 12u] = select(-10.0, log(x), x > 0.0);
40-
output[base_offset + 13u] = sqrt(abs(x));
41-
output[base_offset + 14u] = pow(abs(x), 2.0);
42-
output[base_offset + 15u] = select(-10.0, log2(x), x > 0.0);
43-
output[base_offset + 16u] = min(exp2(x), 1e6);
44-
output[base_offset + 17u] = floor(x);
43+
output[base_offset + 11u] = round6(min(exp(x), 1e6));
44+
output[base_offset + 12u] = round6(select(-10.0, log(x), x > 0.0));
45+
output[base_offset + 13u] = round6(sqrt(abs(x)));
46+
output[base_offset + 14u] = round6(abs(x)*abs(x)); // Use multiplication instead of pow
47+
output[base_offset + 15u] = round6(select(-10.0, log2(x), x > 0.0));
48+
output[base_offset + 16u] = round6(min(exp2(x), 1e6));
49+
output[base_offset + 17u] = floor(x); // floor/ceil/round are exact
4550
output[base_offset + 18u] = ceil(x);
4651
output[base_offset + 19u] = round(x);
4752

‎tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/Cargo.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,9 @@ crate-type = ["dylib"]
1010

1111
# Common deps
1212
[dependencies]
13-
14-
# GPU deps
1513
spirv-std.workspace = true
14+
difftest.workspace = true
1615

1716
# CPU deps
1817
[target.'cfg(not(target_arch = "spirv"))'.dependencies]
19-
difftest.workspace = true
20-
bytemuck.workspace = true
18+
bytemuck.workspace = true

‎tests/difftests/tests/lang/core/ops/matrix_ops/matrix_ops-rust/src/lib.rs

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![no_std]
22

3+
use difftest::round6;
34
use spirv_std::glam::{Mat2, Mat3, Mat4, UVec3, Vec2, Vec3, Vec4};
45
#[allow(unused_imports)]
56
use spirv_std::num_traits::Float;
@@ -35,10 +36,10 @@ pub fn main_cs(
3536

3637
// Mat2 multiplication
3738
let m2_mul = m2a * m2b;
38-
output[base_offset + 0] = m2_mul.col(0).x;
39-
output[base_offset + 1] = m2_mul.col(0).y;
40-
output[base_offset + 2] = m2_mul.col(1).x;
41-
output[base_offset + 3] = m2_mul.col(1).y;
39+
output[base_offset + 0] = round6!(m2_mul.col(0).x);
40+
output[base_offset + 1] = round6!(m2_mul.col(0).y);
41+
output[base_offset + 2] = round6!(m2_mul.col(1).x);
42+
output[base_offset + 3] = round6!(m2_mul.col(1).y);
4243

4344
// Mat2 transpose
4445
let m2_transpose = m2a.transpose();
@@ -48,29 +49,29 @@ pub fn main_cs(
4849
output[base_offset + 7] = m2_transpose.col(1).y;
4950

5051
// Mat2 determinant (with rounding for consistency)
51-
output[base_offset + 8] = (m2a.determinant()*1000.0).round() / 1000.0;
52+
output[base_offset + 8] = round6!(m2a.determinant());
5253

5354
// Mat2 * Vec2
5455
let v2 = Vec2::new(1.0, 2.0);
5556
let m2_v2 = m2a * v2;
56-
output[base_offset + 9] = m2_v2.x;
57-
output[base_offset + 10] = m2_v2.y;
57+
output[base_offset + 9] = round6!(m2_v2.x);
58+
output[base_offset + 10] = round6!(m2_v2.y);
5859

5960
// Mat3 operations
6061
let m3a = Mat3::from_cols(Vec3::new(a, b, c), Vec3::new(b, c, d), Vec3::new(c, d, a));
6162
let m3b = Mat3::from_cols(Vec3::new(d, c, b), Vec3::new(c, b, a), Vec3::new(b, a, d));
6263

6364
// Mat3 multiplication
6465
let m3_mul = m3a * m3b;
65-
output[base_offset + 11] = m3_mul.col(0).x;
66-
output[base_offset + 12] = m3_mul.col(0).y;
67-
output[base_offset + 13] = m3_mul.col(0).z;
68-
output[base_offset + 14] = m3_mul.col(1).x;
69-
output[base_offset + 15] = m3_mul.col(1).y;
70-
output[base_offset + 16] = m3_mul.col(1).z;
71-
output[base_offset + 17] = m3_mul.col(2).x;
72-
output[base_offset + 18] = m3_mul.col(2).y;
73-
output[base_offset + 19] = m3_mul.col(2).z;
66+
output[base_offset + 11] = round6!(m3_mul.col(0).x);
67+
output[base_offset + 12] = round6!(m3_mul.col(0).y);
68+
output[base_offset + 13] = round6!(m3_mul.col(0).z);
69+
output[base_offset + 14] = round6!(m3_mul.col(1).x);
70+
output[base_offset + 15] = round6!(m3_mul.col(1).y);
71+
output[base_offset + 16] = round6!(m3_mul.col(1).z);
72+
output[base_offset + 17] = round6!(m3_mul.col(2).x);
73+
output[base_offset + 18] = round6!(m3_mul.col(2).y);
74+
output[base_offset + 19] = round6!(m3_mul.col(2).z);
7475

7576
// Mat3 transpose - store just diagonal elements
7677
let m3_transpose = m3a.transpose();
@@ -79,14 +80,14 @@ pub fn main_cs(
7980
output[base_offset + 22] = m3_transpose.col(2).z;
8081

8182
// Mat3 determinant (with rounding for consistency)
82-
output[base_offset + 23] = (m3a.determinant()*1000.0).round() / 1000.0;
83+
output[base_offset + 23] = round6!(m3a.determinant());
8384

8485
// Mat3 * Vec3 (with rounding for consistency)
8586
let v3 = Vec3::new(1.0, 2.0, 3.0);
8687
let m3_v3 = m3a * v3;
87-
output[base_offset + 24] = (m3_v3.x*10000.0).round() / 10000.0;
88-
output[base_offset + 25] = (m3_v3.y*10000.0).round() / 10000.0;
89-
output[base_offset + 26] = (m3_v3.z*10000.0).round() / 10000.0;
88+
output[base_offset + 24] = round6!(m3_v3.x);
89+
output[base_offset + 25] = round6!(m3_v3.y);
90+
output[base_offset + 26] = round6!(m3_v3.z);
9091

9192
// Mat4 operations
9293
let m4a = Mat4::from_cols(
@@ -104,10 +105,10 @@ pub fn main_cs(
104105

105106
// Mat4 multiplication (just store diagonal for brevity)
106107
let m4_mul = m4a * m4b;
107-
output[base_offset + 27] = m4_mul.col(0).x;
108-
output[base_offset + 28] = m4_mul.col(1).y;
109-
output[base_offset + 29] = m4_mul.col(2).z;
110-
output[base_offset + 30] = m4_mul.col(3).w;
108+
output[base_offset + 27] = round6!(m4_mul.col(0).x);
109+
output[base_offset + 28] = round6!(m4_mul.col(1).y);
110+
output[base_offset + 29] = round6!(m4_mul.col(2).z);
111+
output[base_offset + 30] = round6!(m4_mul.col(3).w);
111112

112113
// Mat4 transpose (just store diagonal)
113114
let m4_transpose = m4a.transpose();
@@ -117,15 +118,15 @@ pub fn main_cs(
117118
output[base_offset + 34] = m4_transpose.col(3).w;
118119

119120
// Mat4 determinant (with rounding for consistency)
120-
output[base_offset + 35] = (m4a.determinant()*1000.0).round() / 1000.0;
121+
output[base_offset + 35] = round6!(m4a.determinant());
121122

122123
// Mat4 * Vec4 (with rounding for consistency)
123124
let v4 = Vec4::new(1.0, 2.0, 3.0, 4.0);
124125
let m4_v4 = m4a * v4;
125-
output[base_offset + 36] = (m4_v4.x*10000.0).round() / 10000.0;
126-
output[base_offset + 37] = (m4_v4.y*10000.0).round() / 10000.0;
127-
output[base_offset + 38] = (m4_v4.z*10000.0).round() / 10000.0;
128-
output[base_offset + 39] = (m4_v4.w*10000.0).round() / 10000.0;
126+
output[base_offset + 36] = round6!(m4_v4.x);
127+
output[base_offset + 37] = round6!(m4_v4.y);
128+
output[base_offset + 38] = round6!(m4_v4.z);
129+
output[base_offset + 39] = round6!(m4_v4.w);
129130

130131
// Identity matrices
131132
output[base_offset + 40] = Mat2::IDENTITY.col(0).x;
@@ -135,8 +136,8 @@ pub fn main_cs(
135136
// Matrix inverse
136137
if m2a.determinant().abs() > 0.0001 {
137138
let m2_inv = m2a.inverse();
138-
output[base_offset + 43] = m2_inv.col(0).x;
139-
output[base_offset + 44] = m2_inv.col(1).y;
139+
output[base_offset + 43] = round6!(m2_inv.col(0).x);
140+
output[base_offset + 44] = round6!(m2_inv.col(1).y);
140141
} else {
141142
output[base_offset + 43] = 0.0;
142143
output[base_offset + 44] = 0.0;

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /