I have implemented an integer square root function that is branch-free and runs in constant time, using the first variant found in this answer as a base. All possible values for the types `byte`, `ushort`, and `uint` have been exhaustively verified against the `Math.Sqrt` function. Validating `ulong` and `UInt128` completely is not feasible, but I have yet to find any edge cases that fail.
It would be nice to add support for types that are larger than 128 bits, but I was unable to come up with a way to calculate the constant required. I am curious whether anyone has ideas on how one could solve that problem or otherwise improve the function.
C#
/// <summary>
/// Holds per-type constants for generic code written against <see cref="IBinaryInteger{TSelf}"/>.
/// </summary>
/// <typeparam name="T">The binary integer type the constants describe.</typeparam>
public static class BinaryIntegerConstants<T> where T : IBinaryInteger<T>
{
    // Evaluated once per closed generic type, when the type's static fields are initialized.
    private static readonly T s_size = T.PopCount(T.AllBitsSet);

    /// <summary>
    /// Gets the number of set bits in <c>T.AllBitsSet</c>, which for a fixed-width
    /// integer type is its width in bits (for example 32 for <see cref="uint"/>).
    /// </summary>
    public static T Size => s_size;
}
/// <summary>
/// Converts a <see cref="bool"/> to <typeparamref name="T"/> without a conditional:
/// the byte backing the flag is read directly and converted with truncation.
/// </summary>
/// <typeparam name="T">The binary integer type to produce.</typeparam>
/// <param name="value">The flag to convert.</param>
/// <returns>The raw byte of <paramref name="value"/> (normally 0 or 1) as a <typeparamref name="T"/>.</returns>
private static T As<T>(this bool value) where T : IBinaryInteger<T> {
    // Reinterpret the bool's storage instead of using `value ? one : zero`, so no branch is required.
    byte raw = Unsafe.As<bool, byte>(ref value);
    return T.CreateTruncating(raw);
}
/// <summary>
/// Returns the bit width of <typeparamref name="T"/> minus the number of leading zero bits in
/// <paramref name="value"/>. That is the 1-based position of the highest set bit, and 0 when
/// <paramref name="value"/> has no bits set.
/// </summary>
/// <typeparam name="T">The binary integer type of <paramref name="value"/>.</typeparam>
/// <param name="value">The value to inspect.</param>
/// <returns>The width of <typeparamref name="T"/> less the leading-zero count of <paramref name="value"/>.</returns>
public static T MostSignificantBit<T>(this T value) where T : IBinaryInteger<T> {
    T width = BinaryIntegerConstants<T>.Size;
    T leadingZeros = T.LeadingZeroCount(value);
    return width - leadingZeros;
}
/// <summary>
/// Computes the integer square root of an unsigned binary integer with straight-line arithmetic:
/// a seed derived from the bit length of <paramref name="value"/>, a number of refinement steps
/// fixed by the bit width of <typeparamref name="T"/>, a fixed-point correction that applies
/// when the bit length is odd, and finally a fixed number of comparison-based decrements.
/// </summary>
/// <typeparam name="T">An unsigned binary integer type that is 8, 16, 32, 64 or 128 bits wide.</typeparam>
/// <param name="value">The value whose square root is computed.</param>
/// <returns>The computed root, with its top bit cleared.</returns>
/// <exception cref="NotSupportedException">
/// The bit width of <typeparamref name="T"/> is not 8, 16, 32, 64 or 128.
/// </exception>
public static T SquareRoot<T>(this T value) where T : IBinaryInteger<T>, IUnsignedNumber<T> {
    // One refinement step: square the running term, rescale it by the shift, and add the seed back.
    static T Refine(T term, T seed, int shift) => ((term * term) >> shift) + seed;

    // One fix-up step: subtract the 0/1 result of comparing (value - root * root) with the threshold.
    static T StepDown(T root, T value, T threshold) => root - ((value - (root * root)) > threshold).As<T>();

    T bits = BinaryIntegerConstants<T>.Size;
    int bitLength = int.CreateTruncating(value.MostSignificantBit());
    int oddLength = bitLength & 1;
    int half = (bitLength + 1) >> 1;
    int shift = half + 1;

    // The scale is built as 2 * (1 << (half - 1)) rather than 1 << half; keep it that way,
    // because the two forms differ when half is 0 (value == 0) and the shift count is -1.
    T halfScale = T.One << (half - 1);
    T seed = halfScale - (value >> (shift - oddLength));
    T scale = halfScale + halfScale;
    T term = seed;

    // The number of refinement steps depends only on the width of T, never on the value.
    if (bits > T.CreateChecked(8UL)) {
        term = Refine(term, seed, shift);
        term = Refine(term, seed, shift);
    }

    if (bits > T.CreateChecked(16UL)) {
        term = Refine(term, seed, shift);
        term = Refine(term, seed, shift);
        term = Refine(term, seed, shift);
        term = Refine(term, seed, shift);
    }

    if (bits > T.CreateChecked(32UL)) {
        term = Refine(term, seed, shift);
        term = Refine(term, seed, shift);
        term = Refine(term, seed, shift);
        term = Refine(term, seed, shift);
        term = Refine(term, seed, shift);
        term = Refine(term, seed, shift);
        term = Refine(term, seed, shift);
        term = Refine(term, seed, shift);
    }

    if (bits > T.CreateChecked(64UL)) {
        // Eight more steps per pass; the counter starts at bits / 8 and drops by 8 each pass.
        T remaining = bits >> 3;
        do {
            remaining -= (T.One << 3);
            term = Refine(term, seed, shift);
            term = Refine(term, seed, shift);
            term = Refine(term, seed, shift);
            term = Refine(term, seed, shift);
            term = Refine(term, seed, shift);
            term = Refine(term, seed, shift);
            term = Refine(term, seed, shift);
            term = Refine(term, seed, shift);
        } while (remaining != T.Zero);
    }

    T root = scale - term;
    T oddFlag = T.CreateTruncating(oddLength);

    // Width-specific fixed-point multiplier with bits / 2 fractional bits. Each value is close
    // to 1 - 1/sqrt(2) (5/16, 0x4B/2^8, 0x4AFB/2^16, ...). Multiplying by oddFlag makes the
    // correction a no-op for even bit lengths without branching.
    (T factor, int fractionBits) = uint.CreateChecked(bits) switch {
        8U => (T.CreateChecked(0x5UL), 4),
        16U => (T.CreateChecked(0x4BUL), 8),
        32U => (T.CreateChecked(0x4AFBUL), 16),
        64U => (T.CreateChecked(0x4AFB0CCCUL), 32),
        128U => (T.CreateChecked(0x4AFB0CCC06219B7BUL), 64),
        _ => throw new NotSupportedException(), // TODO: Research a way to calculate the proper constant at runtime.
    };
    root -= oddFlag * ((root * factor) >> fractionBits);

    // Threshold is the top bit of T: when root * root exceeds value, the unsigned difference
    // wraps around to a value above it and the step subtracts one.
    T threshold = T.One << int.CreateTruncating(bits - T.One);
    root = StepDown(root, value, threshold);

    if (bits > T.CreateChecked(8UL)) {
        root = StepDown(root, value, threshold);
        root = StepDown(root, value, threshold);
    }

    if (bits > T.CreateChecked(32UL)) {
        root = StepDown(root, value, threshold);
        root = StepDown(root, value, threshold);
        root = StepDown(root, value, threshold);
    }

    // Clear the most significant bit of the result.
    return root & (T.AllBitsSet >> 1);
}
32-Bit Asm | .NET 7.0.0 (7.0.22.51805), X64 RyuJIT AVX2
; SquareRoot[[System.UInt32, System.Private.CoreLib]](UInt32)
; JIT listing for the uint instantiation. esi keeps the input value for the whole method.
push rsi
sub rsp,20
mov esi,ecx
mov ecx,esi
; MostSignificantBit is reached through a call here (not inlined); its result arrives in eax.
call qword ptr [MostSignificantBit[[System.UInt32, System.Private.CoreLib]](UInt32)]
; edx = msbIsOdd (msb & 1); eax = m = (msb + 1) >> 1.
mov edx,eax
and edx,1
inc eax
shr eax,1
; ecx = m - 1, then eax = m + 1.
lea ecx,[rax-1]
inc eax
; ecx = x = 1 << (m - 1).
mov r8d,1
shlx ecx,r8d,ecx
; r9d = z = x - (value >> (mPlusOne - msbIsOdd)).
mov r8d,eax
sub r8d,edx
shrx r8d,esi,r8d
mov r9d,ecx
sub r9d,r8d
; x += x.
add ecx,ecx
; Six unrolled steps of y = ((y * y) >> mPlusOne) + z, with y in r8d
; (the two from the Size > 8 group plus the four from the Size > 16 group).
mov r8d,r9d
imul r8d,r9d
; Shift count masked to the low five bits once, then reused by every shrx below.
and eax,1F
shrx r8d,r8d,eax
add r8d,r9d
imul r8d,r8d
shrx r8d,r8d,eax
add r8d,r9d
imul r8d,r8d
shrx r8d,r8d,eax
add r8d,r9d
imul r8d,r8d
shrx r8d,r8d,eax
add r8d,r9d
imul r8d,r8d
shrx r8d,r8d,eax
add r8d,r9d
imul r8d,r8d
shrx r8d,r8d,eax
add r8d,r9d
; y = x - y.
mov eax,ecx
sub eax,r8d
mov r8d,eax
; Odd-length correction: y -= msbIsOdd * ((y * 0x4AFB) >> 16).
imul eax,r8d,4AFB
shr eax,10
imul eax,edx
sub r8d,eax
; Three branch-free fix-ups: y -= ((value - y * y) > 0x80000000), materialized with seta.
mov eax,r8d
imul eax,r8d
mov edx,esi
sub edx,eax
xor eax,eax
cmp edx,80000000
seta al
sub r8d,eax
mov eax,r8d
imul eax,r8d
mov edx,esi
sub edx,eax
xor eax,eax
cmp edx,80000000
seta al
sub r8d,eax
mov eax,r8d
imul eax,r8d
sub esi,eax
xor eax,eax
cmp esi,80000000
seta al
sub r8d,eax
; Clear the top bit of the result and return it in eax.
mov eax,r8d
and eax,7FFFFFFF
add rsp,20
pop rsi
ret
; Total bytes of code 248
64-Bit Asm | .NET 7.0.0 (7.0.22.51805), X64 RyuJIT AVX2
; SquareRoot[[System.UInt64, System.Private.CoreLib]](UInt64)
; JIT listing for the ulong instantiation. rsi keeps the input value for the whole method.
push rsi
sub rsp,20
mov rsi,rcx
mov rcx,rsi
; MostSignificantBit is reached through a call here (not inlined); its result arrives in rax.
call qword ptr [MostSignificantBit[[System.UInt64, System.Private.CoreLib]](UInt64)]
; rdx = msbIsOdd (msb & 1); rax = m = (msb + 1) >> 1.
mov rdx,rax
and rdx,1
inc rax
shr rax,1
; rcx = m - 1, then rax = m + 1.
lea rcx,[rax-1]
inc rax
; rcx = x = 1 << (m - 1).
mov r8d,1
shlx rcx,r8,rcx
; r9 = z = x - (value >> (mPlusOne - msbIsOdd)).
mov r8d,eax
sub r8d,edx
shrx r8,rsi,r8
mov r9,rcx
sub r9,r8
; x += x.
add rcx,rcx
; Fourteen unrolled steps of y = ((y * y) >> mPlusOne) + z, with y in r8
; (the Size > 8, Size > 16 and Size > 32 groups: 2 + 4 + 8).
mov r8,r9
imul r8,r9
; Shift count masked to the low six bits once, then reused by every shrx below.
and eax,3F
shrx r8,r8,rax
add r8,r9
imul r8,r8
shrx r8,r8,rax
add r8,r9
imul r8,r8
shrx r8,r8,rax
add r8,r9
imul r8,r8
shrx r8,r8,rax
add r8,r9
imul r8,r8
shrx r8,r8,rax
add r8,r9
imul r8,r8
shrx r8,r8,rax
add r8,r9
imul r8,r8
shrx r8,r8,rax
add r8,r9
imul r8,r8
shrx r8,r8,rax
add r8,r9
imul r8,r8
shrx r8,r8,rax
add r8,r9
imul r8,r8
shrx r8,r8,rax
add r8,r9
imul r8,r8
shrx r8,r8,rax
add r8,r9
imul r8,r8
shrx r8,r8,rax
add r8,r9
imul r8,r8
shrx r8,r8,rax
add r8,r9
imul r8,r8
shrx r8,r8,rax
add r8,r9
; y = x - y.
mov rax,rcx
sub rax,r8
mov r8,rax
; Odd-length correction: y -= msbIsOdd * ((y * 0x4AFB0CCC) >> 32).
movsxd rax,edx
imul rdx,r8,4AFB0CCC
shr rdx,20
imul rax,rdx
sub r8,rax
; Six branch-free fix-ups: y -= ((value - y * y) > 0x8000000000000000), materialized with seta.
mov rax,r8
imul rax,r8
mov rdx,rsi
sub rdx,rax
mov rax,8000000000000000
cmp rdx,rax
seta al
movzx eax,al
sub r8,rax
mov rax,r8
imul rax,r8
mov rdx,rsi
sub rdx,rax
mov rax,8000000000000000
cmp rdx,rax
seta al
movzx eax,al
sub r8,rax
mov rax,r8
imul rax,r8
mov rdx,rsi
sub rdx,rax
mov rax,8000000000000000
cmp rdx,rax
seta al
movzx eax,al
sub r8,rax
mov rax,r8
imul rax,r8
mov rdx,rsi
sub rdx,rax
mov rax,8000000000000000
cmp rdx,rax
seta al
movzx eax,al
sub r8,rax
mov rax,r8
imul rax,r8
mov rdx,rsi
sub rdx,rax
mov rax,8000000000000000
cmp rdx,rax
seta al
movzx eax,al
sub r8,rax
mov rax,r8
imul rax,r8
sub rsi,rax
mov rax,8000000000000000
cmp rsi,rax
seta al
movzx eax,al
sub r8,rax
; Clear the top bit of the result and return it in rax.
mov rax,7FFFFFFFFFFFFFFF
and rax,r8
add rsp,20
pop rsi
ret
; Total bytes of code 498
1 Answer 1
how one could solve that problem .... (add support for types that are larger than 128 bits)
Found magic number 1257966796 in Implementation of binary floating-point arithmetic on embedded integer processors. Might help with this goal or just coincidental.
or otherwise improve the function.
Just some minor stuff:
Documentation
Comments in code explaining the algorithm are warranted.
Use hex
Constants 75, 19195, 1257966796, 5402926248376769403 certainly look magical.
At least 0x4B, 0x4AFB, 0x4AFB0CCC, 0x4AFB0CCC06219B7B looks like a pattern.
Let x = 5402926248376769403 / 2^64 --> 0.29289321881345247556389585485981.
Notice x is very close to (2 − √2)/2 = 1 − 1/√2, so the next value may be
(2 − √2)/2 × 2^128
99666397752933951918340834954143154528.885... or rounded
99666397752933951918340834954143154529
0x4AFB0CCC06219B7BA682764C8AB54161
"Validating ulong and UInt128 completely is not feasible but I have yet to find any edge cases that fail." --> This also implies OP's 5402926248376769403 may be off-by-1.
Runs in constant time?
Does below run in constant time?
do {
i -= (T.One << 3);
y = (((y * y) >> mPlusOne) + z);
...
y = (((y * y) >> mPlusOne) + z);
} while (i != T.Zero);
Format uniformity
SquareRoot<T>() ...
lacks a preceding blank line.
Simplification?
var mMinusOne = (m - 1);
// var x = (T.One << mMinusOne);
// x += x;
var x = (T.One << m);
-
\$\begingroup\$ Notes: The constants look magical, but aren't (they're clearer in base 10 once one understands that they're a fixed-point representation of `sqrt(0.5)`). The entire function is essentially a loop that has been unrolled in order to allow the compiler to optimize for commonly supported word sizes, meaning that the final loop truly is constant time. \$\endgroup\$Kittoes0124– Kittoes01242023年01月15日 06:40:02 +00:00Commented Jan 15, 2023 at 6:40 -
1