3
\$\begingroup\$

I have implemented an integer square root function that is branch-free and runs in constant time, using the first variant found in this answer as a base. All possible values for the types byte, ushort, and uint have been exhaustively verified against the Math.Sqrt function. Validating ulong and UInt128 completely is not feasible but I have yet to find any edge cases that fail.

It would be nice to add support for types that are larger than 128 bits, but I was unable to come up with a way to calculate the required constant. I am curious whether anyone has ideas on how to solve that problem or otherwise improve the function.

C#

/// <summary>
/// Holds per-type constants for a binary integer type <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">The binary integer type the constants are computed for.</typeparam>
public static class BinaryIntegerConstants<T>
    where T : IBinaryInteger<T>
{
    // Computed once per closed generic type: the population count of the all-ones value.
    private static readonly T s_size = T.PopCount(T.AllBitsSet);

    /// <summary>
    /// Gets the number of set bits in <c>T.AllBitsSet</c>, expressed as a <typeparamref name="T"/>.
    /// For a fixed-width integer type this equals its width in bits.
    /// </summary>
    public static T Size => s_size;
}
/// <summary>
/// Converts a <see cref="bool"/> to <typeparamref name="T"/> by reinterpreting its storage as a byte
/// instead of selecting between two values.
/// </summary>
/// <typeparam name="T">The binary integer type to produce.</typeparam>
/// <param name="value">The flag to convert.</param>
/// <returns>
/// The byte behind <paramref name="value"/>, converted with <c>T.CreateTruncating</c>.
/// NOTE(review): this is 1 for <c>true</c> and 0 for <c>false</c> only if the bool holds the
/// canonical 0/1 encoding - confirm no caller passes a bool produced by unsafe code.
/// </returns>
private static T As<T>(this bool value) where T : IBinaryInteger<T>
{
    // Read the bool's single byte directly; no conditional is involved.
    byte raw = Unsafe.As<bool, byte>(ref value);
    return T.CreateTruncating(raw);
}
/// <summary>
/// Returns the bit width of <typeparamref name="T"/> minus the number of leading zero bits in
/// <paramref name="value"/>.
/// </summary>
/// <typeparam name="T">The binary integer type of the operand.</typeparam>
/// <param name="value">The value to inspect.</param>
/// <returns>
/// <c>BinaryIntegerConstants&lt;T&gt;.Size - T.LeadingZeroCount(value)</c>. For a fixed-width type this is
/// the 1-based position of the highest set bit, and 0 when every bit is clear.
/// </returns>
public static T MostSignificantBit<T>(this T value) where T : IBinaryInteger<T>
{
    var width = BinaryIntegerConstants<T>.Size;
    var leadingZeros = T.LeadingZeroCount(value);

    return width - leadingZeros;
}
/// <summary>
/// Integer square root of an unsigned binary integer, built only from shifts, multiplications,
/// additions, subtractions and comparisons that are folded into 0/1 values.
/// </summary>
/// <remarks>
/// Every <c>if</c> in the body tests <c>BinaryIntegerConstants&lt;T&gt;.Size</c>, which depends on
/// <typeparamref name="T"/> only and never on <paramref name="value"/>, so the amount of work is fixed
/// per type. According to the author, results were checked exhaustively against <c>Math.Sqrt</c> for
/// <c>byte</c>, <c>ushort</c> and <c>uint</c>; <c>ulong</c> and <c>UInt128</c> were only spot-checked.
/// NOTE(review): the arithmetic appears to rely on unchecked wraparound (for example <c>x += x</c> and
/// <c>value - (y * y)</c>) - confirm the project is not compiled in a checked context.
/// </remarks>
/// <typeparam name="T">An unsigned binary integer type whose bit width is 8, 16, 32, 64 or 128.</typeparam>
/// <param name="value">The number whose square root is requested.</param>
/// <returns>The computed root with its top bit cleared.</returns>
/// <exception cref="NotSupportedException">
/// The bit width of <typeparamref name="T"/> is not 8, 16, 32, 64 or 128. The check sits in the scaling
/// step, so it is reached only after the refinement iterations have already run.
/// </exception>
public static T SquareRoot<T>(this T value) where T : IBinaryInteger<T>, IUnsignedNumber<T> {
 // Bit length of value: Size - LeadingZeroCount(value), as computed by MostSignificantBit.
 var msb = int.CreateTruncating(value: value.MostSignificantBit());
 // 1 when the bit length is odd, 0 when it is even. Adjusts the initial shift and gates the scaling step below.
 var msbIsOdd = (msb & 1);
 // Half of the bit length, rounded up.
 var m = ((msb + 1) >> 1);
 var mMinusOne = (m - 1);
 var mPlusOne = (m + 1);
 // x = 1 << (m - 1).
 // NOTE(review): for value == 0, m is 0 and the shift count is -1; the result then depends on how T's
 // shift operator treats a negative count - confirm before adding new types.
 var x = (T.One << mMinusOne);
 // Starting term: x minus value shifted right by (m + 1), or by m when the bit length is odd.
 var y = (x - (value >> (mPlusOne - msbIsOdd)));
 // z keeps the starting term; every refinement step below adds it back in.
 var z = y;
 // Double x (1 << m for m >= 1). This must stay after y has been computed from the undoubled x.
 x += x;
 // Refinement steps, each y = ((y * y) >> (m + 1)) + z. The count grows with the width of T:
 // none for 8 bits, 2 for 16 bits, 6 for 32 bits, 14 for 64 bits, plus the loop below for wider types.
 if (BinaryIntegerConstants<T>.Size > T.CreateChecked(value: 8UL)) {
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 }
 if (BinaryIntegerConstants<T>.Size > T.CreateChecked(value: 16UL)) {
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 }
 if (BinaryIntegerConstants<T>.Size > T.CreateChecked(value: 32UL)) {
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 }
 if (BinaryIntegerConstants<T>.Size > T.CreateChecked(value: 64UL)) {
 // i starts at Size / 8 and drops by 8 per pass; each pass performs 8 refinement steps.
 // For a 128-bit T that is i = 16, i.e. two passes (16 further steps). The trip count depends on T only.
 // NOTE(review): the loop ends only when i hits exactly zero, so Size / 8 must be a multiple of 8.
 var i = (BinaryIntegerConstants<T>.Size >> 3);
 do {
 i -= (T.One << 3);
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 y = (((y * y) >> mPlusOne) + z);
 } while (i != T.Zero);
 }
 // Turn the accumulated term into the estimate: doubled x minus y.
 y = (x - y);
 // Reuse x as a 0/1 multiplier: 1 when the bit length was odd, 0 otherwise.
 x = T.CreateTruncating(value: msbIsOdd);
 // Scaling step: subtract (y * C) >> (Size / 2) from y, but only when the bit length is odd
 // (x is 0 otherwise, so the product vanishes). C is a per-width constant with
 // C / 2^(Size / 2) = 0.3125 for 8 bits and about 0.2929 for the wider types.
 // NOTE(review): 0.2929 matches 1 - sqrt(0.5); whether each constant should be truncated or rounded
 // to nearest is not established here - confirm, especially for the 128-bit entry.
 // Only the matching switch arm is evaluated, so the large literals are never converted to a narrow T.
 y -= uint.CreateChecked(value: BinaryIntegerConstants<T>.Size) switch {
 8U => (x * ((y * T.CreateChecked(value: 5UL)) >> 4)),
 16U => (x * ((y * T.CreateChecked(value: 75UL)) >> 8)),
 32U => (x * ((y * T.CreateChecked(value: 19195UL)) >> 16)),
 64U => (x * ((y * T.CreateChecked(value: 1257966796UL)) >> 32)),
 128U => (x * ((y * T.CreateChecked(value: 5402926248376769403UL)) >> 64)),
 _ => throw new NotSupportedException(), // TODO: Research a way to calculate the proper constant at runtime.
 };
 // x becomes the top bit of T: 1 << (Size - 1).
 x = (T.One << (int.CreateTruncating(value: (BinaryIntegerConstants<T>.Size - T.One))));
 // Correction steps: each subtracts 1 from y when value - y * y is greater than the top-bit value.
 // NOTE(review): with unsigned wraparound that difference becomes huge when y * y exceeds value, so
 // these read as "step down while the estimate is too high" - confirm.
 // Count by width: 1 for 8 bits, 3 for 16 and 32 bits, 6 for anything wider.
 y -= ((value - (y * y)) > x).As<T>();
 if (BinaryIntegerConstants<T>.Size > T.CreateChecked(value: 8UL)) {
 y -= ((value - (y * y)) > x).As<T>();
 y -= ((value - (y * y)) > x).As<T>();
 }
 if (BinaryIntegerConstants<T>.Size > T.CreateChecked(value: 32UL)) {
 y -= ((value - (y * y)) > x).As<T>();
 y -= ((value - (y * y)) > x).As<T>();
 y -= ((value - (y * y)) > x).As<T>();
 }
 // Clear the top bit of the result before returning it.
 // NOTE(review): the value == 0 path appears to depend on this mask to come out as 0 - confirm.
 return (y & (T.AllBitsSet >> 1));
}

32-Bit Asm | .NET 7.0.0 (7.0.22.51805), X64 RyuJIT AVX2

; SquareRoot[[System.UInt32, System.Private.CoreLib]](UInt32)
; JIT output for T = uint. Immediates are hexadecimal. After the call to MostSignificantBit the
; body contains no jumps: the width tests of the C# source do not appear as instructions.
 ; Prologue: save rsi and reserve 0x20 bytes of stack.
 push rsi
 sub rsp,20
 ; Keep value in esi across the call; ecx carries it to MostSignificantBit.
 mov esi,ecx
 mov ecx,esi
 call qword ptr [MostSignificantBit[[System.UInt32, System.Private.CoreLib]](UInt32)]
 ; eax = msb. edx = msbIsOdd = msb & 1.
 mov edx,eax
 and edx,1
 ; eax = m = (msb + 1) >> 1; ecx = m - 1; then eax = m + 1.
 inc eax
 shr eax,1
 lea ecx,[rax-1]
 inc eax
 ; ecx = x = 1 << (m - 1).
 mov r8d,1
 shlx ecx,r8d,ecx
 ; r8d = value >> ((m + 1) - msbIsOdd); r9d = x - r8d, the starting term (y and z).
 mov r8d,eax
 sub r8d,edx
 shrx r8d,esi,r8d
 mov r9d,ecx
 sub r9d,r8d
 ; x += x.
 add ecx,ecx
 ; Six refinement steps: r8d = ((r8d * r8d) >> (m + 1)) + r9d.
 ; The shift count in eax is masked to 5 bits once and then reused.
 mov r8d,r9d
 imul r8d,r9d
 and eax,1F
 shrx r8d,r8d,eax
 add r8d,r9d
 imul r8d,r8d
 shrx r8d,r8d,eax
 add r8d,r9d
 imul r8d,r8d
 shrx r8d,r8d,eax
 add r8d,r9d
 imul r8d,r8d
 shrx r8d,r8d,eax
 add r8d,r9d
 imul r8d,r8d
 shrx r8d,r8d,eax
 add r8d,r9d
 imul r8d,r8d
 shrx r8d,r8d,eax
 add r8d,r9d
 ; y = x - y, kept in r8d.
 mov eax,ecx
 sub eax,r8d
 mov r8d,eax
 ; Scaling step: y -= msbIsOdd * ((y * 0x4AFB) >> 0x10); 0x4AFB is 19195.
 imul eax,r8d,4AFB
 shr eax,10
 imul eax,edx
 sub r8d,eax
 ; Correction 1 of 3: y -= ((value - y * y) > 0x80000000), via an unsigned compare and seta.
 mov eax,r8d
 imul eax,r8d
 mov edx,esi
 sub edx,eax
 xor eax,eax
 cmp edx,80000000
 seta al
 sub r8d,eax
 ; Correction 2 of 3.
 mov eax,r8d
 imul eax,r8d
 mov edx,esi
 sub edx,eax
 xor eax,eax
 cmp edx,80000000
 seta al
 sub r8d,eax
 ; Correction 3 of 3 (value in esi is consumed here; it is not needed afterwards).
 mov eax,r8d
 imul eax,r8d
 sub esi,eax
 xor eax,eax
 cmp esi,80000000
 seta al
 sub r8d,eax
 ; Return y with the top bit cleared.
 mov eax,r8d
 and eax,7FFFFFFF
 ; Epilogue.
 add rsp,20
 pop rsi
 ret
; Total bytes of code 248

64-Bit Asm | .NET 7.0.0 (7.0.22.51805), X64 RyuJIT AVX2

; SquareRoot[[System.UInt64, System.Private.CoreLib]](UInt64)
; JIT output for T = ulong. Immediates are hexadecimal. After the call to MostSignificantBit the
; body contains no jumps: the width tests of the C# source do not appear as instructions.
 ; Prologue: save rsi and reserve 0x20 bytes of stack.
 push rsi
 sub rsp,20
 ; Keep value in rsi across the call; rcx carries it to MostSignificantBit.
 mov rsi,rcx
 mov rcx,rsi
 call qword ptr [MostSignificantBit[[System.UInt64, System.Private.CoreLib]](UInt64)]
 ; rax = msb. rdx = msbIsOdd = msb & 1.
 mov rdx,rax
 and rdx,1
 ; rax = m = (msb + 1) >> 1; rcx = m - 1; then rax = m + 1.
 inc rax
 shr rax,1
 lea rcx,[rax-1]
 inc rax
 ; rcx = x = 1 << (m - 1).
 mov r8d,1
 shlx rcx,r8,rcx
 ; r8 = value >> ((m + 1) - msbIsOdd); r9 = x - r8, the starting term (y and z).
 mov r8d,eax
 sub r8d,edx
 shrx r8,rsi,r8
 mov r9,rcx
 sub r9,r8
 ; x += x.
 add rcx,rcx
 ; Fourteen refinement steps: r8 = ((r8 * r8) >> (m + 1)) + r9.
 ; The shift count in eax is masked to 6 bits once and then reused.
 mov r8,r9
 imul r8,r9
 and eax,3F
 shrx r8,r8,rax
 add r8,r9
 imul r8,r8
 shrx r8,r8,rax
 add r8,r9
 imul r8,r8
 shrx r8,r8,rax
 add r8,r9
 imul r8,r8
 shrx r8,r8,rax
 add r8,r9
 imul r8,r8
 shrx r8,r8,rax
 add r8,r9
 imul r8,r8
 shrx r8,r8,rax
 add r8,r9
 imul r8,r8
 shrx r8,r8,rax
 add r8,r9
 imul r8,r8
 shrx r8,r8,rax
 add r8,r9
 imul r8,r8
 shrx r8,r8,rax
 add r8,r9
 imul r8,r8
 shrx r8,r8,rax
 add r8,r9
 imul r8,r8
 shrx r8,r8,rax
 add r8,r9
 imul r8,r8
 shrx r8,r8,rax
 add r8,r9
 imul r8,r8
 shrx r8,r8,rax
 add r8,r9
 imul r8,r8
 shrx r8,r8,rax
 add r8,r9
 ; y = x - y, kept in r8.
 mov rax,rcx
 sub rax,r8
 mov r8,rax
 ; Scaling step: y -= msbIsOdd * ((y * 0x4AFB0CCC) >> 0x20); 0x4AFB0CCC is 1257966796.
 movsxd rax,edx
 imul rdx,r8,4AFB0CCC
 shr rdx,20
 imul rax,rdx
 sub r8,rax
 ; Correction 1 of 6: y -= ((value - y * y) > 0x8000000000000000), via an unsigned compare and seta.
 mov rax,r8
 imul rax,r8
 mov rdx,rsi
 sub rdx,rax
 mov rax,8000000000000000
 cmp rdx,rax
 seta al
 movzx eax,al
 sub r8,rax
 ; Correction 2 of 6.
 mov rax,r8
 imul rax,r8
 mov rdx,rsi
 sub rdx,rax
 mov rax,8000000000000000
 cmp rdx,rax
 seta al
 movzx eax,al
 sub r8,rax
 ; Correction 3 of 6.
 mov rax,r8
 imul rax,r8
 mov rdx,rsi
 sub rdx,rax
 mov rax,8000000000000000
 cmp rdx,rax
 seta al
 movzx eax,al
 sub r8,rax
 ; Correction 4 of 6.
 mov rax,r8
 imul rax,r8
 mov rdx,rsi
 sub rdx,rax
 mov rax,8000000000000000
 cmp rdx,rax
 seta al
 movzx eax,al
 sub r8,rax
 ; Correction 5 of 6.
 mov rax,r8
 imul rax,r8
 mov rdx,rsi
 sub rdx,rax
 mov rax,8000000000000000
 cmp rdx,rax
 seta al
 movzx eax,al
 sub r8,rax
 ; Correction 6 of 6 (value in rsi is consumed here; it is not needed afterwards).
 mov rax,r8
 imul rax,r8
 sub rsi,rax
 mov rax,8000000000000000
 cmp rsi,rax
 seta al
 movzx eax,al
 sub r8,rax
 ; Return y with the top bit cleared.
 mov rax,7FFFFFFFFFFFFFFF
 and rax,r8
 ; Epilogue.
 add rsp,20
 pop rsi
 ret
; Total bytes of code 498
asked Jan 3, 2023 at 20:13
\$\endgroup\$

1 Answer 1

1
\$\begingroup\$

how one could solve that problem .... (add support for types that are larger than 128 bits)

Found magic number 1257966796 in Implementation of binary floating-point arithmetic on embedded integer processors. Might help with this goal or just coincidental.


or otherwise improve the function.

Just some minor stuff:

Documentation

Comments in code explaining the algorithm are warranted.

Use hex

Constants 75, 19195, 1257966796, 5402926248376769403 certainly look magical.

At least 0x4B, 0x4AFB, 0x4AFB0CCC, 0x4AFB0CCC06219B7B looks like a pattern.

Let $x = 5402926248376769403 / 2^{64}$ --> 0.29289321881345247556389585485981.

Notice x is very close to $(2 - \sqrt{2})/2 = 1 - \sqrt{0.5}$, so the next value may be

$(2 - \sqrt{2})/2 \times 2^{128}$
99666397752933951918340834954143154528.885... or rounded
99666397752933951918340834954143154529
0x4AFB0CCC06219B7BA682764C8AB54161

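"Validating ulong and UInt128 completely is not feasible but I have yet to find any edge cases that fail." --> The 128-bit value above continues past OP's 64-bit constant 0x4AFB0CCC06219B7B with the digits A682…, a discarded fraction greater than one half. Rounding to nearest would therefore give 0x4AFB0CCC06219B7C (5402926248376769404), so OP's 5402926248376769403 is the truncated value and may be off-by-1. Because the UInt128 path was not validated exhaustively, testing alone would not necessarily have exposed this.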

Runs in constant time?

Does below run in constant time?

 do {
 i -= (T.One << 3);
 y = (((y * y) >> mPlusOne) + z);
 ...
 y = (((y * y) >> mPlusOne) + z);
 } while (i != T.Zero);

Format uniformity

SquareRoot<T>() ... lacks a preceding blank line.

Simplification?

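var mMinusOne = (m - 1);
// var x = (T.One << mMinusOne);
// x += x;
var x = (T.One << m);
// Caveat: in the original, `y = x - (value >> ...)` sits between the two commented-out lines and reads x
// before it is doubled, so that line would also have to change (e.g. use 1 << (m - 1) there) for this to be equivalent.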
answered Jan 14, 2023 at 14:01
\$\endgroup\$
2
  • \$\begingroup\$ Notes: The constants look magical, but aren't (they're clearer in base 10 once one understands that they're a fixed-point representation of sqrt(0.5)). The entire function is essentially a loop that has been unrolled in order to allow the compiler to optimize for commonly supported word sizes, meaning that the final loop truly is constant time. \$\endgroup\$ Commented Jan 15, 2023 at 6:40
  • 1
    \$\begingroup\$ @Kittoes0124 comment would have been useful as comments in code. Further, it does not look like a fixed-point representation of sqrt(0.5), but a fixed-point representation of 1 - sqrt(0.5), i.e. (2 - sqrt(2))/2. \$\endgroup\$ Commented Jan 15, 2023 at 6:50

Your Answer

Draft saved
Draft discarded

Sign up or log in

Sign up using Google
Sign up using Email and Password

Post as a guest

Required, but never shown

Post as a guest

Required, but never shown

By clicking "Post Your Answer", you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.