Intro

I have this implementation of a parallel MSD (most significant digit) radix sort. It runs in $$\mathcal{O}\Bigg( \bigg(\frac{N}{P} + PB \bigg) \log_B \sigma\Bigg),$$ where \$N\$ is the length of the long range to sort, \$P\$ is the number of cores sorting, \$B\$ is the number of buckets for distributing data, and, finally, \$\sigma\$ is the cardinality of the key universe. (For sorting long keys, we have \$\sigma = 2^{64}\$, and so \$\log_B \sigma = 8\$, which implies that the algorithm in question considers the data one byte at a time in the recursion tree.)
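
For concreteness, substituting the \$B = 256\$ buckets used below and 64-bit keys gives $$\log_{256} 2^{64} = \frac{64}{\log_2 256} = \frac{64}{8} = 8$$ passes over the data, one per byte, which matches the eight recursion depths (0 through 7) in the code.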

As for the actual algorithm design, it adheres to the divide-and-conquer paradigm: the sort distributes the elements into buckets (the divide step) and then recurses on each bucket in order to sort them individually (the conquer step).
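
Before the full listing, the same divide and conquer steps in a minimal single-threaded form might look like the sketch below (an illustration only, not the implementation that follows: it allocates a fresh buffer at every recursion level and flips the sign bit up front instead of special-casing the topmost byte as my code does):

    // Minimal single-threaded MSD radix sort sketch for long[] keys.
    final class MsdSketch {

        static void msdSketch(long[] a) {
            // Flipping the sign bit makes unsigned byte-wise order agree
            // with the signed long order:
            for (int i = 0; i < a.length; i++) a[i] ^= 0x8000_0000_0000_0000L;
            sortByte(a, 0, a.length, 56);
            for (int i = 0; i < a.length; i++) a[i] ^= 0x8000_0000_0000_0000L;
        }

        static void sortByte(long[] a, int from, int to, int shift) {
            if (to - from < 2 || shift < 0) {
                return;
            }
            // Divide: count the bucket sizes and turn them into bucket offsets:
            int[] count = new int[257];
            for (int i = from; i < to; i++) {
                count[(int) ((a[i] >>> shift) & 0xff) + 1]++;
            }
            for (int b = 0; b < 256; b++) {
                count[b + 1] += count[b];
            }
            // Scatter each element into its bucket via an auxiliary buffer:
            long[] buffer = new long[to - from];
            int[] next = count.clone();
            for (int i = from; i < to; i++) {
                int b = (int) ((a[i] >>> shift) & 0xff);
                buffer[next[b]++] = a[i];
            }
            System.arraycopy(buffer, 0, a, from, to - from);
            // Conquer: recurse on each bucket with the next byte:
            for (int b = 0; b < 256; b++) {
                sortByte(a, from + count[b], from + count[b + 1], shift - 8);
            }
        }
    }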


Code

package io.github.coderodde.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

/**
 * This class provides the method for parallel sorting of {@code long} arrays.
 * The underlying algorithm is a parallel MSD (most significant digit) radix
 * sort. At each iteration, only a single byte is considered, so that the number 
 * of buckets is 256. This implementation honours the sign bit, so that the 
 * result of parallel radix sorting is the same as that of 
 * {@link java.util.Arrays#parallelSort(long[])}.
 * 
 * @author Rodion "rodde" Efremov
 * @version 1.1 (Jul 1, 2025)
 * @since 1.0 (Jun 3, 2023)
 */
public final class ParallelRadixSort {
 /**
 * The number of sort buckets.
 */
 private static final int BUCKETS = 256;
 
 /**
 * The deepest recursion depth; at this depth, the least significant byte
 * of the key is processed.
 */
 private static final int DEEPEST_RECURSION_DEPTH = 7;
 
 /**
 * The mask for extracting the sign bit.
 */
 private static final long SIGN_BIT_MASK = 0x8000_0000_0000_0000L;
 
 /**
 * The number of bits per byte.
 */
 private static final int BITS_PER_BYTE = Byte.SIZE;
 
 /**
 * The number of bits to shift to the right at the topmost recursion depth.
 */
 private static final int TOPMOST_SHIFT_RIGHT = 56;
 
 /**
 * The mask for extracting a byte.
 */
 private static final long EXTRACT_BYTE_MASK = 0xff;
 
 /**
 * The minimum workload for a thread.
 */
 private static final int DEFAULT_THREAD_THRESHOLD = 65536;
 
 /**
 * The hard lower bound for the minimum thread workload.
 */
 private static final int MINIMUM_THREAD_WORKLOAD = 14047;
 
 /**
 * The current actual minimum thread workload in elements.
 */
 private static volatile int minimumThreadWorkload = 
 DEFAULT_THREAD_THRESHOLD;
 
 /**
 * Sets the current minimum thread workload.
 * 
 * @param newMinimumThreadWorkload the new minimum thread workload.
 */
 public static void setMinimumThreadWorkload(int newMinimumThreadWorkload) {
 minimumThreadWorkload = 
 Math.max(
 MINIMUM_THREAD_WORKLOAD,
 newMinimumThreadWorkload);
 }
 
 /**
 * Sorts the entire input array into non-decreasing order.
 * 
 * @param array the array to sort.
 */
 public static void parallelSort(long[] array) {
 parallelSort(array, 
 0,
 array.length);
 }
 
 /**
 * Sorts the range {@code array[fromIndex], ..., array[toIndex - 1]}.
 * 
 * @param array the array holding the target range to sort.
 * @param fromIndex the starting, inclusive index of the range to sort.
 * @param toIndex the ending, exclusive index of the range to sort.
 */
 public static void parallelSort(long[] array, 
 int fromIndex,
 int toIndex) {
 
 rangeCheck(array.length, fromIndex, toIndex);
 
 int rangeLength = toIndex - fromIndex;
 
 if (rangeLength < BUCKETS) {
 Arrays.sort(array,
 fromIndex,
 toIndex);
 return;
 }
 
 long[] buffer = new long[rangeLength]; 
 // Use as many threads as there are available processors, yet at most
 // one thread per 'minimumThreadWorkload' elements, and at least one:
 int threads = 
 Math.min(
 Runtime.getRuntime().availableProcessors(), 
 rangeLength / minimumThreadWorkload);
 
 threads = Math.max(threads, 1);
 
 if (threads == 1) {
 radixSortTopImpl(array, 
 buffer,
 fromIndex, 
 0,
 rangeLength);
 } else {
 parallelRadixSortTopImpl(array, 
 buffer, 
 fromIndex, 
 0, 
 rangeLength,
 threads);
 }
 }
 
 private static void parallelRadixSortTopImpl(
 long[] source, 
 long[] target,
 int sourceFromIndex,
 int targetFromIndex,
 int rangeLength,
 int threads) {
 
 if (rangeLength < BUCKETS) {
 Arrays.sort(source,
 sourceFromIndex,
 sourceFromIndex + rangeLength);
 return;
 }
 
 int startIndex = sourceFromIndex;
 int subrangeLength = rangeLength / threads;
 
 BucketSizeCounterTopThread[] bucketSizeCounterThreads = 
 new BucketSizeCounterTopThread[threads];
 
 // Spawn all but the rightmost bucket size counter thread. The rightmost
 // thread will be run in this thread as a mild optimization:
 for (int i = 0; i != bucketSizeCounterThreads.length - 1; i++) {
 BucketSizeCounterTopThread bucketSizeCounterThread =
 new BucketSizeCounterTopThread(
 source,
 startIndex,
 startIndex += subrangeLength);
 
 bucketSizeCounterThread.start();
 bucketSizeCounterThreads[i] = bucketSizeCounterThread;
 }
 
 // Construct the last bucket size counter thread:
 BucketSizeCounterTopThread lastBucketSizeCounterThread =
 new BucketSizeCounterTopThread(
 source, 
 startIndex, 
 sourceFromIndex + rangeLength);
 
 // Run the last bucket size counter thread in this thread:
 lastBucketSizeCounterThread.run(); 
 bucketSizeCounterThreads[threads - 1] = lastBucketSizeCounterThread;
 
 // Join all the spawned bucket size counter threads:
 for (int i = 0; i != threads - 1; i++) {
 BucketSizeCounterTopThread bucketSizeCounterThread = 
 bucketSizeCounterThreads[i];
 
 try {
 bucketSizeCounterThread.join();
 } catch (InterruptedException ex) {
 throw new RuntimeException(
 "Could not join a bucket size counter thread.",
 ex);
 }
 }
 
 // Build the global bucket size map:
 int[] globalBucketSizeMap = new int[BUCKETS];
 
 for (int i = 0; i != threads; i++) {
 int[] localBucketSizeMap = 
 bucketSizeCounterThreads[i].getLocalBucketSizeMap();
 
 for (int j = 0; j != BUCKETS; j++) {
 globalBucketSizeMap[j] += localBucketSizeMap[j];
 }
 }
 
 int numberOfNonemptyBuckets = 0;
 
 for (int i = 0; i != BUCKETS; i++) {
 if (globalBucketSizeMap[i] != 0) {
 numberOfNonemptyBuckets++;
 }
 }
 
 int spawnDegree = Math.min(numberOfNonemptyBuckets,
 threads);
 
 int[] startIndexMap = new int[BUCKETS];
 startIndexMap[0] = targetFromIndex;
 
 for (int i = 1; i != BUCKETS; i++) {
 startIndexMap[i] = startIndexMap[i - 1] 
 + globalBucketSizeMap[i - 1];
 }
 
 int[][] processedMaps = new int[spawnDegree][BUCKETS];
 
 // Compute the per-thread write offsets so that each inserter thread
 // targets a disjoint slice of every bucket:
 for (int i = 1; i != spawnDegree; i++) {
 int[] partialBucketSizeMap =
 bucketSizeCounterThreads[i - 1].getLocalBucketSizeMap();
 
 for (int j = 0; j != BUCKETS; j++) {
 processedMaps[i][j] = processedMaps[i - 1][j]
 + partialBucketSizeMap[j];
 }
 }
 
 int sourceStartIndex = sourceFromIndex;
 
 BucketInserterTopThread[] bucketInserterThreads = 
 new BucketInserterTopThread[spawnDegree];
 
 // Spawn all but the rightmost bucket inserter thread. The rightmost
 // thread will be run in this thread as a mild optimization:
 for (int i = 0; i != spawnDegree - 1; i++) {
 BucketInserterTopThread bucketInserterThread = 
 new BucketInserterTopThread(source,
 target,
 sourceStartIndex,
 startIndexMap,
 processedMaps[i],
 subrangeLength);
 
 sourceStartIndex += subrangeLength;
 bucketInserterThread.start();
 bucketInserterThreads[i] = bucketInserterThread;
 }
 
 BucketInserterTopThread lastBucketInserterThread =
 new BucketInserterTopThread(source,
 target,
 sourceStartIndex,
 startIndexMap,
 processedMaps[spawnDegree - 1],
 rangeLength - (spawnDegree - 1) 
 * subrangeLength);
 
 // Run the last, rightmost bucket inserter thread in this thread:
 lastBucketInserterThread.run();
 bucketInserterThreads[spawnDegree - 1] = lastBucketInserterThread;
 
 // Join all the spawned bucket inserter threads:
 for (int i = 0; i != spawnDegree - 1; i++) {
 BucketInserterTopThread bucketInserterThread = 
 bucketInserterThreads[i];
 
 try {
 bucketInserterThread.join();
 } catch (InterruptedException ex) {
 throw new RuntimeException(
 "Could not join a bucket inserter thread.",
 ex);
 }
 }
 
 ListOfBucketKeyLists bucketIndexListArray =
 new ListOfBucketKeyLists(spawnDegree);
 
 for (int i = 0; i != spawnDegree; i++) {
 BucketKeyList bucketKeyList = 
 new BucketKeyList(numberOfNonemptyBuckets);
 
 bucketIndexListArray.addBucketKeyList(bucketKeyList);
 }
 
 // Map each sorter task list to the number of threads it may use:
 int[] threadCountMap = new int[spawnDegree];
 
 // ... basic thread counts...
 for (int i = 0; i != spawnDegree; i++) {
 threadCountMap[i] = threads / spawnDegree;
 }
 
 // ... make sure all threads are in use:
 for (int i = 0; i != threads % spawnDegree; i++) {
 threadCountMap[i]++;
 }
 
 // Contains all the keys of all the non-empty buckets:
 BucketKeyList nonEmptyBucketIndices = 
 new BucketKeyList(numberOfNonemptyBuckets);
 
 for (int bucketKey = 0; bucketKey != BUCKETS; bucketKey++) {
 if (globalBucketSizeMap[bucketKey] != 0) {
 nonEmptyBucketIndices.addBucketKey(bucketKey);
 }
 }
 
 // Shuffle the bucket keys so that large buckets are spread more
 // evenly over the sorter task lists:
 nonEmptyBucketIndices.shuffle(new Random());
 
 // Distribute the buckets over the sorter task lists:
 int frontIndex = 0;
 int cursorIndex = 0;
 int listIndex = 0;
 int optimalSubrangeLength = rangeLength / spawnDegree;
 int packed = 0;
 int numberOfNonEmptyBuckets = nonEmptyBucketIndices.size();
 
 while (cursorIndex != numberOfNonEmptyBuckets) {
 int bucketKey = nonEmptyBucketIndices.getBucketKey(cursorIndex++);
 int tmp = globalBucketSizeMap[bucketKey];
 packed += tmp;
 
 if (packed >= optimalSubrangeLength 
 || cursorIndex == numberOfNonEmptyBuckets) {
 
 packed = 0;
 
 for (int i = frontIndex; i != cursorIndex; i++) {
 int bucketKey2 = nonEmptyBucketIndices.getBucketKey(i);
 
 BucketKeyList bucketKeyList = 
 bucketIndexListArray.getBucketKeyList(listIndex);
 
 bucketKeyList.addBucketKey(bucketKey2);
 }
 
 listIndex++;
 frontIndex = cursorIndex;
 }
 }
 
 sourceStartIndex = sourceFromIndex;
 
 List<List<SorterTask>> arrayOfTaskArrays = 
 new ArrayList<>(spawnDegree);
 
 for (int i = 0; i != spawnDegree; i++) {
 List<SorterTask> taskArray = 
 new ArrayList<>(BUCKETS);
 
 BucketKeyList bucketKeyList = 
 bucketIndexListArray.getBucketKeyList(i);
 
 int size = bucketKeyList.size();
 
 for (int idx = 0; idx != size; idx++) {
 int bucketKey = bucketKeyList.getBucketKey(idx);
 
 SorterTask sorterTask =
 new SorterTask(
 target,
 source,
 startIndexMap[bucketKey],
 startIndexMap[bucketKey] - 
 targetFromIndex + 
 sourceFromIndex,
 
 globalBucketSizeMap[bucketKey],
 1,
 threadCountMap[i]);
 
 taskArray.add(sorterTask);
 }
 
 arrayOfTaskArrays.add(taskArray);
 }
 
 SorterThread[] sorterThreads = new SorterThread[spawnDegree - 1];
 
 // Recurse to the deeper depth via multithreading:
 for (int i = 0; i != sorterThreads.length; i++) {
 SorterThread sorterThread = 
 new SorterThread(
 arrayOfTaskArrays.get(i));
 
 sorterThread.start();
 sorterThreads[i] = sorterThread;
 }
 
 // Run the rightmost sorter thread in this thread:
 new SorterThread(
 arrayOfTaskArrays.get(spawnDegree - 1)).run();
 
 // Join all the actually spawned sorter threads:
 for (SorterThread sorterThread : sorterThreads) {
 try {
 sorterThread.join();
 } catch (InterruptedException ex) {
 throw new RuntimeException(
 "Could not join a sorter thread.",
 ex);
 }
 }
 }
 
 private static void parallelRadixSortImpl(long[] source, 
 long[] target,
 int sourceFromIndex,
 int targetFromIndex,
 int rangeLength,
 int recursionDepth,
 int threads) {
 
 int startIndex = sourceFromIndex;
 int subrangeLength = rangeLength / threads;
 
 BucketSizeCounterThread[] bucketSizeCounterThreads = 
 new BucketSizeCounterThread[threads];
 
 // Spawn all but the rightmost bucket size counter thread. The rightmost
 // thread will be run in this thread as a mild optimization:
 for (int i = 0; i != bucketSizeCounterThreads.length - 1; i++) {
 BucketSizeCounterThread bucketSizeCounterThread = 
 new BucketSizeCounterThread(
 source,
 startIndex,
 startIndex += subrangeLength, 
 recursionDepth);
 
 bucketSizeCounterThread.start();
 bucketSizeCounterThreads[i] = bucketSizeCounterThread;
 }
 
 // Construct the last bucket size counter thread:
 BucketSizeCounterThread lastBucketSizeCounterThread =
 new BucketSizeCounterThread(
 source, 
 startIndex, 
 sourceFromIndex + rangeLength, 
 recursionDepth);
 
 // Run the last bucket size counter thread in this thread:
 lastBucketSizeCounterThread.run(); 
 bucketSizeCounterThreads[threads - 1] = lastBucketSizeCounterThread;
 
 // Join all the spawned bucket size counter threads:
 for (int i = 0; i != threads - 1; i++) {
 BucketSizeCounterThread bucketSizeCounterThread = 
 bucketSizeCounterThreads[i];
 
 try {
 bucketSizeCounterThread.join();
 } catch (InterruptedException ex) {
 throw new RuntimeException(
 "Could not join a bucket size counter thread.",
 ex);
 }
 }
 
 // Build the global bucket size map:
 int[] globalBucketSizeMap = new int[BUCKETS];
 
 for (int i = 0; i != threads; i++) {
 int[] localBucketSizeMap = 
 bucketSizeCounterThreads[i].getLocalBucketSizeMap();
 
 for (int j = 0; j != BUCKETS; j++) {
 globalBucketSizeMap[j] += localBucketSizeMap[j];
 }
 }
 
 int numberOfNonemptyBuckets = 0;
 
 for (int i = 0; i != BUCKETS; i++) {
 if (globalBucketSizeMap[i] != 0) {
 numberOfNonemptyBuckets++;
 }
 }
 
 int spawnDegree = Math.min(numberOfNonemptyBuckets, threads);
 int[] startIndexMap = new int[BUCKETS];
 startIndexMap[0] = targetFromIndex;
 
 for (int i = 1; i != BUCKETS; i++) {
 startIndexMap[i] = startIndexMap[i - 1] 
 + globalBucketSizeMap[i - 1];
 }
 
 int[][] processedMaps = new int[spawnDegree][BUCKETS];
 
 // Compute the per-thread write offsets so that each inserter thread
 // targets a disjoint slice of every bucket:
 for (int i = 1; i != spawnDegree; i++) {
 int[] partialBucketSizeMap =
 bucketSizeCounterThreads[i - 1].getLocalBucketSizeMap();
 
 for (int j = 0; j != BUCKETS; j++) {
 processedMaps[i][j] = processedMaps[i - 1][j]
 + partialBucketSizeMap[j];
 }
 }
 
 int sourceStartIndex = sourceFromIndex;
 
 BucketInserterThread[] bucketInserterThreads = 
 new BucketInserterThread[spawnDegree];
 
 // Spawn all but the rightmost bucket inserter thread. The rightmost
 // thread will be run in this thread as a mild optimization:
 for (int i = 0; i != spawnDegree - 1; i++) {
 BucketInserterThread bucketInserterThread = 
 new BucketInserterThread(
 source,
 target,
 sourceStartIndex,
 startIndexMap,
 processedMaps[i],
 subrangeLength,
 recursionDepth);
 
 sourceStartIndex += subrangeLength;
 
 bucketInserterThread.start();
 bucketInserterThreads[i] = bucketInserterThread;
 }
 
 BucketInserterThread lastBucketInserterThread =
 new BucketInserterThread(
 source,
 target,
 sourceStartIndex,
 startIndexMap,
 processedMaps[spawnDegree - 1],
 rangeLength - (spawnDegree - 1) * subrangeLength,
 recursionDepth);
 
 // Run the last, rightmost bucket inserter thread in this thread:
 lastBucketInserterThread.run();
 bucketInserterThreads[spawnDegree - 1] = lastBucketInserterThread;
 
 // Join all the spawned bucket inserter threads:
 for (int i = 0; i != spawnDegree - 1; i++) {
 BucketInserterThread bucketInserterThread = 
 bucketInserterThreads[i];
 
 try {
 bucketInserterThread.join();
 } catch (InterruptedException ex) {
 throw new RuntimeException(
 "Could not join a bucket inserter thread.",
 ex);
 }
 }
 
 if (recursionDepth == DEEPEST_RECURSION_DEPTH) {
 // Nowhere to recur, all bytes are processed. Return.
 return;
 }
 
 ListOfBucketKeyLists bucketIndexListArray =
 new ListOfBucketKeyLists(spawnDegree);
 
 for (int i = 0; i != spawnDegree; i++) {
 BucketKeyList bucketKeyList = 
 new BucketKeyList(numberOfNonemptyBuckets);
 
 bucketIndexListArray.addBucketKeyList(bucketKeyList);
 }
 
 // Map each sorter task list to the number of threads it may use:
 int[] threadCountMap = new int[spawnDegree];
 
 // ... basic thread counts...
 for (int i = 0; i != spawnDegree; i++) {
 threadCountMap[i] = threads / spawnDegree;
 }
 
 // ... make sure all threads are in use:
 for (int i = 0; i != threads % spawnDegree; i++) {
 threadCountMap[i]++;
 }
 
 // Contains all the keys of all the non-empty buckets:
 BucketKeyList nonEmptyBucketIndices = 
 new BucketKeyList(numberOfNonemptyBuckets);
 
 for (int bucketKey = 0; bucketKey != BUCKETS; bucketKey++) {
 if (globalBucketSizeMap[bucketKey] != 0) {
 nonEmptyBucketIndices.addBucketKey(bucketKey);
 }
 }
 
 // Shuffle the bucket keys so that large buckets are spread more
 // evenly over the sorter task lists:
 nonEmptyBucketIndices.shuffle(new Random());
 
 // Distribute the buckets over the sorter task lists:
 int frontIndex = 0;
 int cursorIndex = 0;
 int listIndex = 0;
 int optimalSubrangeLength = rangeLength / spawnDegree;
 int packed = 0;
 int numberOfNonEmptyBuckets = nonEmptyBucketIndices.size();
 
 while (cursorIndex != numberOfNonEmptyBuckets) {
 int bucketKey = nonEmptyBucketIndices.getBucketKey(cursorIndex++);
 int tmp = globalBucketSizeMap[bucketKey];
 packed += tmp;
 
 if (packed >= optimalSubrangeLength 
 || cursorIndex == numberOfNonEmptyBuckets) {
 
 packed = 0;
 
 for (int i = frontIndex; i != cursorIndex; i++) {
 int bucketKey2 = nonEmptyBucketIndices.getBucketKey(i);
 
 BucketKeyList bucketKeyList = 
 bucketIndexListArray.getBucketKeyList(listIndex);
 
 bucketKeyList.addBucketKey(bucketKey2);
 }
 
 listIndex++;
 frontIndex = cursorIndex;
 }
 }
 
 sourceStartIndex = sourceFromIndex;
 
 List<List<SorterTask>> arrayOfTaskArrays = 
 new ArrayList<>(spawnDegree);
 
 for (int i = 0; i != spawnDegree; i++) {
 List<SorterTask> taskArray = 
 new ArrayList<>(BUCKETS);
 
 BucketKeyList bucketKeyList = 
 bucketIndexListArray.getBucketKeyList(i);
 
 int size = bucketKeyList.size();
 
 for (int idx = 0; idx != size; idx++) {
 int bucketKey = bucketKeyList.getBucketKey(idx);
 
 SorterTask sorterTask =
 new SorterTask(
 target,
 source,
 startIndexMap[bucketKey],
 startIndexMap[bucketKey] - 
 targetFromIndex + 
 sourceFromIndex,
 
 globalBucketSizeMap[bucketKey],
 recursionDepth + 1,
 threadCountMap[i]);
 
 taskArray.add(sorterTask);
 }
 
 arrayOfTaskArrays.add(taskArray);
 }
 
 SorterThread[] sorterThreads = new SorterThread[spawnDegree - 1];
 
 // Recurse to the deeper depth via multithreading:
 for (int i = 0; i != sorterThreads.length; i++) {
 SorterThread sorterThread = 
 new SorterThread(
 arrayOfTaskArrays.get(i));
 
 sorterThread.start();
 sorterThreads[i] = sorterThread;
 }
 
 // Run the rightmost sorter thread in this thread:
 new SorterThread(
 arrayOfTaskArrays.get(spawnDegree - 1)).run();
 
 // Join all the actually spawned sorter threads:
 for (SorterThread sorterThread : sorterThreads) {
 try {
 sorterThread.join();
 } catch (InterruptedException ex) {
 throw new RuntimeException(
 "Could not join a sorter thread.",
 ex);
 }
 }
 }
 
 private static void rangeCheck(
 int arrayLength, 
 int fromIndex, 
 int toIndex) {
 
 if (fromIndex > toIndex) {
 throw new IllegalArgumentException(
 "fromIndex(" + fromIndex + ") > toIndex(" + toIndex + ")");
 }
 
 if (fromIndex < 0) {
 throw new ArrayIndexOutOfBoundsException(fromIndex);
 }
 
 if (toIndex > arrayLength) {
 throw new ArrayIndexOutOfBoundsException(toIndex);
 }
 }
 
 /**
 * Sorts the topmost range 
 * {@code source[sourceFromIndex], ..., source[sourceFromIndex + rangeLength - 1]}
 * and stores the result in 
 * {@code target[targetFromIndex], ..., target[targetFromIndex + rangeLength - 1]}.
 * 
 * @param source the source array.
 * @param target the target array.
 * @param sourceFromIndex the starting index of the range to sort in 
 * {@code source}.
 * @param targetFromIndex the starting index of the range to put the result
 * in.
 * @param rangeLength the length of the range to sort.
 */
 private static void radixSortTopImpl(long[] source,
 long[] target,
 int sourceFromIndex,
 int targetFromIndex,
 int rangeLength) {
 
 if (rangeLength < BUCKETS) {
 Arrays.sort(source, 
 sourceFromIndex,
 sourceFromIndex + rangeLength);
 
 return;
 }
 
 int[] bucketSizeMap = new int[BUCKETS];
 int[] startIndexMap = new int[BUCKETS];
 int[] processedMap = new int[BUCKETS];
 
 int sourceToIndex = sourceFromIndex + rangeLength;
 
 // Find out the size of each bucket:
 for (int i = sourceFromIndex; 
 i != sourceToIndex; 
 i++) {
 
 long datum = source[i];
 int bucketIndex = getBucketIndexTop(datum);
 bucketSizeMap[bucketIndex]++;
 }
 
 startIndexMap[0] = targetFromIndex;
 
 // Compute starting indices for buckets in the target array. This is 
 // actually just an accumulated array of bucketSizeMap, such that
 // startIndexMap[0] = 0, startIndexMap[1] = bucketSizeMap[0], ...,
 // startIndexMap[BUCKETS - 1] = bucketSizeMap[0] + bucketSizeMap[1] +
 // ... + bucketSizeMap[BUCKETS - 2].
 for (int i = 1; i != BUCKETS; i++) {
 startIndexMap[i] = startIndexMap[i - 1] 
 + bucketSizeMap[i - 1];
 }
 
 // Insert each element to its own bucket:
 for (int i = sourceFromIndex; i != sourceToIndex; i++) {
 long datum = source[i];
 int bucketKey = getBucketIndexTop(datum);
 
 target[startIndexMap[bucketKey] + 
 processedMap[bucketKey]++] = datum;
 }
 
 for (int i = 0; i != BUCKETS; i++) {
 if (bucketSizeMap[i] != 0) {
 // Sort from 'target' to 'source':
 radixSortImpl(
 target,
 source,
 startIndexMap[i],
 startIndexMap[i] - targetFromIndex + sourceFromIndex,
 bucketSizeMap[i],
 1);
 }
 }
 }
 
 /**
 * Sorts the range 
 * {@code source[sourceFromIndex], ..., source[sourceFromIndex + rangeLength - 1]}
 * and stores the result in 
 * {@code target[targetFromIndex], ..., target[targetFromIndex + rangeLength - 1]}.
 * 
 * @param source the source array.
 * @param target the target array.
 * @param sourceFromIndex the starting index of the range to sort in 
 * {@code source}.
 * @param targetFromIndex the starting index of the range to put the result
 * in.
 * @param rangeLength the length of the range to sort.
 * @param recursionDepth the recursion depth.
 */
 private static void radixSortImpl(long[] source,
 long[] target,
 int sourceFromIndex,
 int targetFromIndex,
 int rangeLength,
 int recursionDepth) {
 
 if (rangeLength < BUCKETS) {
 Arrays.sort(source,
 sourceFromIndex,
 sourceFromIndex + rangeLength);
 
 // At odd recursion depths, 'source' is the auxiliary buffer, so the
 // sorted run must be copied back into the original array ('target'):
 if (recursionDepth % 2 == 1) {
 System.arraycopy(source, 
 sourceFromIndex, 
 target, 
 targetFromIndex, 
 rangeLength);
 }
 
 return;
 }
 
 int[] bucketSizeMap = new int[BUCKETS];
 int[] startIndexMap = new int[BUCKETS];
 int[] processedMap = new int[BUCKETS];
 
 int sourceToIndex = sourceFromIndex + rangeLength;
 
 // Find out the size of each bucket:
 for (int i = sourceFromIndex; 
 i != sourceToIndex; 
 i++) {
 
 long datum = source[i];
 int bucketIndex = getBucketIndex(datum, 
 recursionDepth);
 bucketSizeMap[bucketIndex]++;
 }
 
 startIndexMap[0] = targetFromIndex;
 
 // Compute starting indices for buckets in the target array. This is 
 // actually just an accumulated array of bucketSizeMap, such that
 // startIndexMap[0] = 0, startIndexMap[1] = bucketSizeMap[0], ...,
 // startIndexMap[BUCKETS - 1] = bucketSizeMap[0] + bucketSizeMap[1] +
 // ... + bucketSizeMap[BUCKETS - 2].
 for (int i = 1; i != BUCKETS; i++) {
 startIndexMap[i] = startIndexMap[i - 1] 
 + bucketSizeMap[i - 1];
 }
 
 // Insert each element to its own bucket:
 for (int i = sourceFromIndex; i != sourceToIndex; i++) {
 long datum = source[i];
 int bucketKey = getBucketIndex(datum, recursionDepth);
 
 target[startIndexMap[bucketKey] + 
 processedMap[bucketKey]++] = datum;
 }
 
 if (recursionDepth == DEEPEST_RECURSION_DEPTH) {
 System.arraycopy(
 target, 
 targetFromIndex, 
 source, 
 sourceFromIndex,
 rangeLength);
 
 return;
 }
 
 for (int i = 0; i != BUCKETS; i++) {
 if (bucketSizeMap[i] != 0) {
 // Sort from 'target' to 'source':
 radixSortImpl(
 target,
 source,
 startIndexMap[i],
 startIndexMap[i] - targetFromIndex + sourceFromIndex,
 bucketSizeMap[i],
 recursionDepth + 1);
 }
 }
 }
 
 /**
 * Returns the bucket index of the {@code element} at the topmost recursion
 * level. The sign bit is flipped so that the resulting unsigned bucket
 * order agrees with the signed {@code long} order.
 * 
 * @param element the target element.
 * 
 * @return the index of the bucket to which the element belongs.
 */
 static int getBucketIndexTop(long element) {
 return (int)((element ^ SIGN_BIT_MASK) >>> TOPMOST_SHIFT_RIGHT);
 }
 
 /**
 * Returns the bucket index of the {@code element} at the 
 * {@code recursionDepth}th recursion level. The higher the value of
 * {@code recursionDepth}, the deeper in the recursion tree we are.
 * 
 * @param element the element whose bucket index to compute.
 * @param recursionDepth the recursion depth.
 * 
 * @return the index of the bucket to which the element belongs.
 */
 static int getBucketIndex(long element,
 int recursionDepth) {
 return (int)(element >>> 
 ((DEEPEST_RECURSION_DEPTH - recursionDepth) 
 * BITS_PER_BYTE) 
 & EXTRACT_BYTE_MASK);
 }
 
 private static final class BucketSizeCounterTopThread extends Thread {
 
 private final int[] localBucketSizeMap = new int[BUCKETS];
 private final long[] array;
 private final int fromIndex;
 private final int toIndex;
 
 BucketSizeCounterTopThread(long[] array,
 int fromIndex,
 int toIndex) {
 
 this.array = array;
 this.fromIndex = fromIndex;
 this.toIndex = toIndex;
 }
 
 @Override
 public void run() {
 for (int i = fromIndex; i != toIndex; i++) {
 localBucketSizeMap[getBucketIndexTop(array[i])]++;
 }
 }
 
 int[] getLocalBucketSizeMap() {
 return localBucketSizeMap;
 }
 }
 
 private static final class BucketSizeCounterThread extends Thread {
 
 private final int[] localBucketSizeMap = new int[BUCKETS];
 private final long[] array;
 private final int fromIndex;
 private final int toIndex;
 private final int recursionDepth;
 
 BucketSizeCounterThread(long[] array,
 int fromIndex,
 int toIndex,
 int recursionDepth) {
 
 this.array = array;
 this.fromIndex = fromIndex;
 this.toIndex = toIndex;
 this.recursionDepth = recursionDepth;
 }
 
 @Override
 public void run() {
 for (int i = fromIndex; i != toIndex; i++) {
 localBucketSizeMap[getBucketIndex(array[i], 
 recursionDepth)]++;
 }
 }
 
 int[] getLocalBucketSizeMap() {
 return localBucketSizeMap;
 }
 }
 
 private static final class BucketInserterTopThread extends Thread {
 
 private final long[] source;
 private final long[] target;
 private final int sourceFromIndex;
 private final int[] startIndexMap;
 private final int[] processedMap;
 private final int rangeLength;
 
 BucketInserterTopThread(long[] source,
 long[] target,
 int sourceFromIndex,
 int[] startIndexMap,
 int[] processedMap,
 int rangeLength) {
 
 this.source = source;
 this.target = target;
 this.sourceFromIndex = sourceFromIndex;
 this.startIndexMap = startIndexMap;
 this.processedMap = processedMap;
 this.rangeLength = rangeLength;
 }
 
 @Override
 public void run() {
 int sourceToIndex = sourceFromIndex + rangeLength;
 
 for (int i = sourceFromIndex; i != sourceToIndex; i++) {
 long datum = source[i];
 int bucketKey = getBucketIndexTop(datum);
 
 target[startIndexMap[bucketKey] + 
 processedMap[bucketKey]++] = datum;
 }
 }
 }
 
 private static final class BucketInserterThread extends Thread {
 
 private final long[] source;
 private final long[] target;
 private final int sourceFromIndex;
 private final int[] startIndexMap;
 private final int[] processedMap;
 private final int rangeLength;
 private final int recursionDepth;
 
 BucketInserterThread(long[] source,
 long[] target,
 int sourceFromIndex,
 int[] startIndexMap,
 int[] processedMap,
 int rangeLength,
 int recursionDepth) {
 
 this.source = source;
 this.target = target;
 this.sourceFromIndex = sourceFromIndex;
 this.startIndexMap = startIndexMap;
 this.processedMap = processedMap;
 this.rangeLength = rangeLength;
 this.recursionDepth = recursionDepth;
 }
 
 @Override
 public void run() {
 int sourceToIndex = sourceFromIndex + rangeLength;
 
 for (int i = sourceFromIndex; i != sourceToIndex; i++) {
 
 long datum = source[i];
 int bucketKey = getBucketIndex(datum, 
 recursionDepth);
 
 target[startIndexMap[bucketKey] + 
 processedMap[bucketKey]++] = datum;
 }
 }
 }
 
 private static final class SorterThread extends Thread {
 
 private final List<SorterTask> sorterTasks;
 
 SorterThread(List<SorterTask> sorterTasks) {
 this.sorterTasks = sorterTasks;
 }
 
 @Override
 public void run() {
 for (SorterTask sorterTask : sorterTasks) {
 if (sorterTask.threads > 1) {
 parallelRadixSortImpl(sorterTask.source,
 sorterTask.target,
 sorterTask.sourceStartOffset,
 sorterTask.targetStartOffset,
 sorterTask.rangeLength,
 sorterTask.recursionDepth,
 sorterTask.threads);
 } else {
 radixSortImpl(sorterTask.source,
 sorterTask.target,
 sorterTask.sourceStartOffset,
 sorterTask.targetStartOffset,
 sorterTask.rangeLength,
 sorterTask.recursionDepth);
 }
 }
 }
 }
 
 private static final class SorterTask {
 
 final long[] source;
 final long[] target;
 final int sourceStartOffset;
 final int targetStartOffset;
 final int rangeLength;
 final int recursionDepth;
 final int threads;
 
 SorterTask(long[] source,
 long[] target,
 int sourceStartOffset,
 int targetStartOffset,
 int rangeLength,
 int recursionDepth,
 int threads) {
 
 this.source = source;
 this.target = target;
 this.sourceStartOffset = sourceStartOffset;
 this.targetStartOffset = targetStartOffset;
 this.rangeLength = rangeLength;
 this.recursionDepth = recursionDepth;
 this.threads = threads;
 }
 }
 
 private static final class BucketKeyList {
 private final int[] bucketKeys;
 private int size;
 
 BucketKeyList(int capacity) {
 this.bucketKeys = new int[capacity];
 }
 
 void addBucketKey(int bucketKey) {
 this.bucketKeys[size++] = bucketKey;
 }
 
 int getBucketKey(int index) {
 return this.bucketKeys[index];
 }
 
 int size() {
 return size;
 }
 
 void shuffle(Random random) {
 for (int i = 0; i != size - 1; i++) {
 int j = i + random.nextInt(size - i);
 int temp = bucketKeys[i];
 bucketKeys[i] = bucketKeys[j];
 bucketKeys[j] = temp;
 }
 }
 }
 
 private static final class ListOfBucketKeyLists {
 private final BucketKeyList[] lists;
 private int size;
 
 ListOfBucketKeyLists(int capacity) {
 this.lists = new BucketKeyList[capacity];
 }
 
 void addBucketKeyList(BucketKeyList bucketKeyList) {
 this.lists[this.size++] = bucketKeyList;
 }
 
 BucketKeyList getBucketKeyList(int index) {
 return this.lists[index];
 }
 }
}
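
For reference, invoking the sort mirrors java.util.Arrays; the driver class below is just a hypothetical example:

    import io.github.coderodde.util.ParallelRadixSort;
    import java.util.Random;

    public final class Demo {
        public static void main(String[] args) {
            // One million pseudorandom keys from a fixed seed:
            long[] data = new Random(13L).longs(1_000_000).toArray();

            ParallelRadixSort.parallelSort(data);          // the entire array
            ParallelRadixSort.parallelSort(data, 10, 500); // the subrange [10, 500)
        }
    }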

Typical output

On my PC, I once got the following output. As you can see, my sort is inferior on already sorted data, yet on random data it is faster than Arrays.parallelSort by a factor of \$\approx\$ 2.

<<< Warming on sorted data >>>
<<< Warming on random data >>>
<<< Benchmarking on sorted data >>>
Arrays.parallelSort: 242 ms, ParallelRadixSort.parallelSort: 1764 ms, agreed: true
Arrays.parallelSort: 246 ms, ParallelRadixSort.parallelSort: 1730 ms, agreed: true
Arrays.parallelSort: 235 ms, ParallelRadixSort.parallelSort: 1833 ms, agreed: true
Arrays.parallelSort: 246 ms, ParallelRadixSort.parallelSort: 1779 ms, agreed: true
Arrays.parallelSort: 244 ms, ParallelRadixSort.parallelSort: 1913 ms, agreed: true
Arrays.parallelSort: 238 ms, ParallelRadixSort.parallelSort: 1750 ms, agreed: true
Arrays.parallelSort: 250 ms, ParallelRadixSort.parallelSort: 1748 ms, agreed: true
Arrays.parallelSort: 221 ms, ParallelRadixSort.parallelSort: 1622 ms, agreed: true
Arrays.parallelSort: 247 ms, ParallelRadixSort.parallelSort: 1852 ms, agreed: true
Arrays.parallelSort: 254 ms, ParallelRadixSort.parallelSort: 1802 ms, agreed: true
Arrays.parallelSort: 244 ms, ParallelRadixSort.parallelSort: 1721 ms, agreed: true
Arrays.parallelSort: 248 ms, ParallelRadixSort.parallelSort: 1754 ms, agreed: true
Arrays.parallelSort: 237 ms, ParallelRadixSort.parallelSort: 1697 ms, agreed: true
Arrays.parallelSort: 242 ms, ParallelRadixSort.parallelSort: 1780 ms, agreed: true
Arrays.parallelSort: 240 ms, ParallelRadixSort.parallelSort: 1728 ms, agreed: true
Arrays.parallelSort: 239 ms, ParallelRadixSort.parallelSort: 1736 ms, agreed: true
Arrays.parallelSort: 237 ms, ParallelRadixSort.parallelSort: 1713 ms, agreed: true
Arrays.parallelSort: 245 ms, ParallelRadixSort.parallelSort: 1765 ms, agreed: true
Arrays.parallelSort: 238 ms, ParallelRadixSort.parallelSort: 1715 ms, agreed: true
Arrays.parallelSort: 228 ms, ParallelRadixSort.parallelSort: 1706 ms, agreed: true
Total Arrays.parallelSort duration: 4821, total ParallelRadixSort.parallelSort: 35108
<<< Benchmarking on random data >>>
Arrays.parallelSort: 1023 ms, ParallelRadixSort.parallelSort: 476 ms, agreed: true
Arrays.parallelSort: 1038 ms, ParallelRadixSort.parallelSort: 486 ms, agreed: true
Arrays.parallelSort: 964 ms, ParallelRadixSort.parallelSort: 460 ms, agreed: true
Arrays.parallelSort: 964 ms, ParallelRadixSort.parallelSort: 483 ms, agreed: true
Arrays.parallelSort: 1105 ms, ParallelRadixSort.parallelSort: 549 ms, agreed: true
Arrays.parallelSort: 1004 ms, ParallelRadixSort.parallelSort: 464 ms, agreed: true
Arrays.parallelSort: 985 ms, ParallelRadixSort.parallelSort: 499 ms, agreed: true
Arrays.parallelSort: 978 ms, ParallelRadixSort.parallelSort: 490 ms, agreed: true
Arrays.parallelSort: 1007 ms, ParallelRadixSort.parallelSort: 472 ms, agreed: true
Arrays.parallelSort: 1049 ms, ParallelRadixSort.parallelSort: 484 ms, agreed: true
Arrays.parallelSort: 976 ms, ParallelRadixSort.parallelSort: 475 ms, agreed: true
Arrays.parallelSort: 1018 ms, ParallelRadixSort.parallelSort: 474 ms, agreed: true
Arrays.parallelSort: 994 ms, ParallelRadixSort.parallelSort: 453 ms, agreed: true
Arrays.parallelSort: 1109 ms, ParallelRadixSort.parallelSort: 558 ms, agreed: true
Arrays.parallelSort: 1012 ms, ParallelRadixSort.parallelSort: 496 ms, agreed: true
Arrays.parallelSort: 971 ms, ParallelRadixSort.parallelSort: 507 ms, agreed: true
Arrays.parallelSort: 1048 ms, ParallelRadixSort.parallelSort: 516 ms, agreed: true
Arrays.parallelSort: 1346 ms, ParallelRadixSort.parallelSort: 543 ms, agreed: true
Arrays.parallelSort: 972 ms, ParallelRadixSort.parallelSort: 455 ms, agreed: true
Arrays.parallelSort: 1068 ms, ParallelRadixSort.parallelSort: 512 ms, agreed: true
Total Arrays.parallelSort duration: 20631, total ParallelRadixSort.parallelSort: 9852
Benchmark done!
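
The harness itself is not shown above, so here is a minimal sketch of the kind of measurement round that could produce such lines (all names in it are hypothetical, and System.currentTimeMillis() is only a rough timer):

    // One benchmark round: sort two copies of the same input and compare.
    static void benchmarkOnce(long[] input) {
        long[] a = input.clone();
        long[] b = input.clone();

        long t0 = System.currentTimeMillis();
        java.util.Arrays.parallelSort(a);
        long t1 = System.currentTimeMillis();
        ParallelRadixSort.parallelSort(b);
        long t2 = System.currentTimeMillis();

        System.out.println("Arrays.parallelSort: " + (t1 - t0)
                + " ms, ParallelRadixSort.parallelSort: " + (t2 - t1)
                + " ms, agreed: " + java.util.Arrays.equals(a, b));
    }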

Critique request

As always, I am eager to hear any constructive commentary on my work.
