I'm implementing a (hopefully) efficient MSB radix sort in Java. My current implementation isn't efficient enough for my use case (see the bottom for a brief performance comparison against Arrays.sort).
Please rip me to shreds over any general improvements you can suggest; I am currently most interested in suggestions for performance improvements, though.
I do use recursion, but I'm only copying references in each stack frame, so I don't think I'm taking much of a hit there. I allocate two buffers; each layer of the recursion tree uses one array as input and the other as output, and each node on a layer sorts a different bucket within the input, storing its results in the output before recursing.
// Sort "amount" integers from "input" and write them to "output".
protected void radicsSort(DataInputStream input, DataOutputStream output, int amount) throws IOException
{
    // Read our input into memory
    int[] inputBuffer = new int[amount];
    for (int x = 0; x < inputBuffer.length; ++x)
        inputBuffer[x] = input.readInt();
    // Allocate an output buffer as big as the input one
    int[] outputBuffer = new int[amount];
    //// Sort the input buffer
    radicsSort(inputBuffer, outputBuffer, 0, amount, Integer.SIZE - 1);
    // Input and output are swapped on each layer of recursion, and we recurse as many times
    // as there are bits in an integer, so ensure the final output is in outputBuffer.
    // Commented out at the moment as we get a warning, as technically we already know the result...
    //if (Integer.SIZE % 2 == 0)
    outputBuffer = inputBuffer;
    // Write the output to the file
    for (int x = 0; x < outputBuffer.length; ++x)
        output.writeInt(outputBuffer[x]);
}

// Sort the integers in "inBuffer" at indices between "start" inclusive and "end" exclusive,
// by the value of the bit at "bit", storing the results in "outBuffer" in the same
// interval as they were taken from.
protected void radicsSort(int[] inBuffer, int[] outBuffer, int start, int end, int bit)
{
    // TODO: Check negative numbers are sorted correctly with this
    int zeroIndex = start; // Counts up from the start
    int oneIndex = end;    // Counts down from the end
    for (int x = start; x < end; ++x)
    {
        if (((inBuffer[x] >> bit) & 1) == 0) // Bit is 0
            outBuffer[zeroIndex++] = inBuffer[x];
        else // Bit is 1
            outBuffer[--oneIndex] = inBuffer[x];
    }
    if (bit > 0)
    {
        // Recurse to sort the two sub-regions, moving one bit towards the LSB.
        // Swap the input/output buffers
        radicsSort(outBuffer, inBuffer, start, oneIndex, bit - 1);
        radicsSort(outBuffer, inBuffer, oneIndex, end, bit - 1);
    }
}
Performance comparison with Arrays.sort, times given in milliseconds:

Number of integers    Arrays.sort    Above implementation
1024                  1              8328
1048576               1221           19367
I understand that radix sort has a high constant overhead, but I thought it would have been overcome by the time we're sorting 4 MiB of data.
The things I can think of to tweak are:
- Use iteration instead of recursion (I don't think this will change much).
- Sort by, e.g., decimal digit instead of binary digit. Again, I don't know how much of a performance impact this will have.
- Switch to, e.g., insertion sort below a size threshold. I believe this is normally done to overcome the overhead of allocating new structures when working with small buckets. I think my implementation avoids that overhead anyway, by using a common output buffer and swapping input/output buffers, so I don't think this will have much of an impact.
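For the third tweak, here is a minimal sketch of what the cutoff could look like; the threshold value and all names are illustrative, not part of the code under review. Once a recursion range shrinks below the threshold, it would be finished with an in-place insertion sort instead of recursing further.

```java
// Sketch of the insertion-sort cutoff idea; INSERTION_THRESHOLD and the
// class/method names are illustrative, not part of the original code.
public final class InsertionCutoff
{
    static final int INSERTION_THRESHOLD = 32; // assumed value; tune by benchmarking

    // Sort buffer[start..end) in place with insertion sort.
    static void insertionSort(int[] buffer, int start, int end)
    {
        for (int i = start + 1; i < end; ++i)
        {
            int value = buffer[i];
            int j = i - 1;
            while (j >= start && buffer[j] > value)
            {
                buffer[j + 1] = buffer[j];
                --j;
            }
            buffer[j + 1] = value;
        }
    }
}
```

Note that with the alternating-buffer scheme, the small range also has to end up in whichever buffer the final result is expected in, so a copy may be needed before or after the insertion sort.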
2 Answers
- Post line-oriented timings from a profiling run, please. Where does it actually spend its time?
- Curly braces {} are good for you, even when your for or if statement is a one-liner. Someone will be maintaining this code later, and it might not be you.
- Please rename to radixSort(), from the Latin word for root.
- Recursion overhead is simply not an issue here.
I imagine you're testing with random integers, so stability or reversing would not matter there. But think about timsort (https://en.wikipedia.org/wiki/Timsort), which is stable. If there are runs of non-descending input values, you would really like to preserve them. That would let you verify and terminate early.
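The early-termination idea can be sketched as a pre-scan in the spirit of timsort's run detection; the class and method names here are illustrative, not from the code under review.

```java
// Sketch of an early-exit pre-scan: if the range is already non-descending,
// the sort can return immediately. Names are illustrative.
public final class RunCheck
{
    // Returns true if buffer[start..end) is already in non-descending order.
    static boolean isNonDescending(int[] buffer, int start, int end)
    {
        for (int x = start + 1; x < end; ++x)
            if (buffer[x - 1] > buffer[x])
                return false;
        return true;
    }
}
```

The check costs one linear pass, which is cheap next to 32 distribution passes, and it pays off enormously on already-sorted or nearly-sorted input.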
Arrays.sort() runs in \$O(n \log{n})\$ time. Your radix sort is "efficient" in the sense that it runs in \$O(n)\$ time, but the notation can hide impressively large constants, like bits per Integer in your case. A linear read scan tries to blow out memory bandwidth, and the L2 (last level) cache won't help you. You're not even working on cache-sized subproblems with a subsequent merge. Yes, I agree that switching to an \$n \log{n}\$ sort for limited-size subproblems would be helpful.

You mentioned decimal digits, and I see no reason to go there. But consider sorting into 4 buckets during a pass, or 8, or 16...
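To illustrate the multi-bucket suggestion, here is a hedged sketch of a single MSD distribution pass over 4 bits (16 buckets) instead of 1 bit; all names are illustrative. Like the original code, it would still need separate handling for negative values.

```java
import java.util.Arrays;

// Sketch of one MSD distribution pass over 4 bits (16 buckets) rather than 1 bit.
// "shift" is the position of the lowest of the four bits examined; all names are
// illustrative. Like the original code, negatives would need separate handling.
public final class MultiBitPass
{
    static final int BITS = 4;
    static final int BUCKETS = 1 << BITS; // 16
    static final int MASK = BUCKETS - 1;  // 0xF

    // Distribute in[start..end) into out[start..end) by the 4-bit digit at "shift".
    // Returns bucket boundaries: boundaries[b]..boundaries[b + 1] is bucket b's
    // range, so the caller can recurse into each sub-range with shift - BITS.
    static int[] distribute(int[] in, int[] out, int start, int end, int shift)
    {
        int[] boundaries = new int[BUCKETS + 1];
        for (int x = start; x < end; ++x)               // count each digit's occurrences
            ++boundaries[((in[x] >>> shift) & MASK) + 1];
        boundaries[0] = start;
        for (int b = 1; b <= BUCKETS; ++b)              // prefix sums -> bucket offsets
            boundaries[b] += boundaries[b - 1];
        int[] next = Arrays.copyOf(boundaries, BUCKETS); // running write positions
        for (int x = start; x < end; ++x)
            out[next[(in[x] >>> shift) & MASK]++] = in[x];
        return boundaries;
    }
}
```

Besides cutting 32 passes down to 8, the counting scatter writes every bucket in ascending order, so the distribution stays stable (unlike the descending write into the one-bucket in the original).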
- In the \$O(n)\$ speak, the \$n\$ can be very misleading indeed. – vnp, Aug 12, 2017 at 23:29
The actual specifier is MSD (most-significant digit) rather than MSB (most-significant bit, I suppose). As for performance: years ago I wrote an MSD radix sort that takes a parameter, call it \$d\$. Given such a \$d\,ドル the algorithm treats a \$d\$-bit sequence as a key. I benchmarked all the values \$d = 1, 2, \dots, 16\,ドル and \$d = 8\$ (one byte at a time) was the clear winner. The moral here is: process one byte at a time, not a single bit.
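One more note on the TODO about negative numbers in the question: extracting digits from raw two's-complement bits places negatives after positives. A common fix, sketched here with illustrative names, is to flip the sign bit before taking each digit, so that unsigned digit order agrees with signed integer order.

```java
// Sketch: flipping the sign bit maps ints onto an unsigned order that agrees
// with signed order, so a byte-at-a-time MSD pass sorts negatives correctly.
// Class and method names are illustrative.
public final class SignFlip
{
    // Extract the 8-bit digit at "shift" after flipping the sign bit.
    static int digit(int value, int shift)
    {
        return ((value ^ 0x80000000) >>> shift) & 0xFF;
    }
}
```

Only the pass covering bit 31 is actually changed by the flip, but applying the XOR uniformly keeps the digit extractor branch-free.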
- Arrays.sort takes 1000-fold time, while the Above implementation is just 2.5× longer. More datapoints would be seriously enlightening.