What?
I have designed curvesort, an algorithm that adapts to "smoothness" of data. For example, if the data resembles a sine wave, it is likely that curvesort will sort it fast. This post is about a parallel implementation of the algorithm.
Code
ParallelCurvesort.java
package net.coderodde.util;
import java.util.Objects;
/**
 * A parallel implementation of curvesort. The requested range is split into
 * chunks; each chunk is scanned by its own thread into a sorted, doubly-linked
 * "frequency list" (one node per distinct integer with an occurrence count).
 * The per-chunk lists are then merged pairwise and the merged list is written
 * back into the array.
 *
 * Each insertion starts its search from the most recently touched node, so
 * input in which consecutive values are close to each other needs only short
 * walks along the list.
 */
public final class ParallelCurvesort {

    /**
     * Each thread should not handle less than this number of integers.
     */
    private static final int MINIMUM_INTS_PER_THREAD = 10_000;

    /**
     * This static inner class implements a node in the frequency list.
     */
    private static final class FrequencyListNode {
        /** The integer value this node stands for. */
        final int integer;

        /** How many times {@code integer} has been seen so far. */
        int count;

        FrequencyListNode prev;
        FrequencyListNode next;

        FrequencyListNode(int integer) {
            this.integer = integer;
            this.count = 1;
        }
    }

    /**
     * Scans the non-empty range {@code [fromIndex, toIndex)} of the array and
     * builds a frequency list, ordered by ascending integer, for that range.
     */
    private static final class ScannerThread extends Thread {
        /** The most recently touched node; every search starts from here. */
        private FrequencyListNode latest;

        /** The node holding the smallest integer seen so far. */
        private FrequencyListNode head;

        /** The node holding the largest integer seen so far. */
        private FrequencyListNode tail;

        private final int[] array;
        private final int fromIndex;
        private final int toIndex;

        /**
         * Seeds the list with {@code array[fromIndex]}; the range must
         * therefore contain at least one element.
         */
        ScannerThread(int[] array, int fromIndex, int toIndex) {
            this.array = array;
            this.fromIndex = fromIndex;
            this.toIndex = toIndex;

            FrequencyListNode initialNode =
                    new FrequencyListNode(array[fromIndex]);

            this.head = initialNode;
            this.tail = initialNode;
            this.latest = initialNode;
        }

        @Override
        public void run() {
            // The element at 'fromIndex' was consumed by the constructor.
            for (int i = fromIndex + 1; i < toIndex; ++i) {
                add(array[i]);
            }
        }

        FrequencyListNode getHead() {
            return head;
        }

        private void add(int integer) {
            if (integer < latest.integer) {
                findAndUpdateSmallerNode(integer);
            } else if (integer > latest.integer) {
                findAndUpdateLargerNode(integer);
            } else {
                latest.count++;
            }
        }

        private void findAndUpdateSmallerNode(int integer) {
            FrequencyListNode tmp = latest.prev;

            // Go down the node chain towards the nodes with smaller integers.
            while (tmp != null && tmp.integer > integer) {
                tmp = tmp.prev;
            }

            if (tmp == null) {
                // 'integer' is the new minimum. Create new head node and put
                // the integer in it.
                FrequencyListNode newNode = new FrequencyListNode(integer);
                newNode.next = head;
                head.prev = newNode;
                head = newNode;
                latest = newNode;
            } else if (tmp.integer == integer) {
                // 'integer' already in the list. Just update the count.
                tmp.count++;
                latest = tmp;
            } else {
                // Insert a new node between 'tmp' and 'tmp.next'. 'tmp.next'
                // is never null here: 'tmp' was reached through a 'prev' link.
                FrequencyListNode newNode = new FrequencyListNode(integer);
                newNode.prev = tmp;
                newNode.next = tmp.next;
                newNode.prev.next = newNode;
                newNode.next.prev = newNode;
                latest = newNode;
            }
        }

        private void findAndUpdateLargerNode(int integer) {
            FrequencyListNode tmp = latest.next;

            // Go up the chain towards the nodes with larger keys.
            while (tmp != null && tmp.integer < integer) {
                tmp = tmp.next;
            }

            if (tmp == null) {
                // 'integer' is the new maximum. Create new tail node and put
                // the integer in it.
                FrequencyListNode newNode = new FrequencyListNode(integer);
                newNode.prev = tail;
                tail.next = newNode;
                tail = newNode;
                latest = newNode;
            } else if (tmp.integer == integer) {
                // 'integer' already in the list. Just update the count.
                tmp.count++;
                latest = tmp;
            } else {
                // Insert a new node between 'tmp.prev' and 'tmp'. 'tmp.prev'
                // is never null here: 'tmp' was reached through a 'next' link.
                FrequencyListNode newNode = new FrequencyListNode(integer);
                newNode.prev = tmp.prev;
                newNode.next = tmp;
                tmp.prev.next = newNode;
                tmp.prev = newNode;
                latest = newNode;
            }
        }
    }

    private final int[] array;
    private final int fromIndex;
    private final int toIndex;

    private ParallelCurvesort(int[] array, int fromIndex, int toIndex) {
        this.array = array;
        this.fromIndex = fromIndex;
        this.toIndex = toIndex;
    }

    /**
     * Sorts the range: scans the chunks in parallel, merges the resulting
     * frequency lists and writes the merged list back into the array.
     */
    private void sort() {
        int rangeLength = toIndex - fromIndex;

        if (rangeLength < 2) {
            // Trivially sorted. This also protects ScannerThread, which reads
            // array[fromIndex] unconditionally and would otherwise touch an
            // element outside an empty range (or outside the array).
            return;
        }

        int numberOfThreads =
                Math.min(rangeLength / MINIMUM_INTS_PER_THREAD,
                         Runtime.getRuntime().availableProcessors());

        numberOfThreads = Math.max(numberOfThreads, 1);
        numberOfThreads = ceilToPowerOfTwo(numberOfThreads);

        // The calling thread handles the last chunk itself, so only
        // numberOfThreads - 1 extra threads are spawned.
        ScannerThread[] scannerThreads = new ScannerThread[numberOfThreads - 1];
        int threadRangeLength = rangeLength / numberOfThreads;
        int startIndex = fromIndex;

        for (int i = 0;
                i < numberOfThreads - 1;
                i++, startIndex += threadRangeLength) {
            scannerThreads[i] =
                    new ScannerThread(array,
                                      startIndex,
                                      startIndex + threadRangeLength);

            scannerThreads[i].start();
        }

        // The last chunk also absorbs the remainder of the integer division.
        ScannerThread thisThread = new ScannerThread(array,
                                                     startIndex,
                                                     toIndex);
        thisThread.run();

        for (ScannerThread thread : scannerThreads) {
            try {
                thread.join();
            } catch (InterruptedException ex) {
                // Restore the interrupt status so that code further up the
                // stack can still observe the interruption.
                Thread.currentThread().interrupt();
                throw new RuntimeException("A thread was interrupted.", ex);
            }
        }

        FrequencyListNode[] listHeads = new FrequencyListNode[numberOfThreads];

        for (int i = 0; i < scannerThreads.length; ++i) {
            listHeads[i] = scannerThreads[i].getHead();
        }

        listHeads[listHeads.length - 1] = thisThread.getHead();
        FrequencyListNode mergedListHead = mergeLists(listHeads);
        dump(mergedListHead, array, fromIndex);
    }

    /**
     * Returns the smallest power of two that is not less than {@code number}.
     */
    private static int ceilToPowerOfTwo(int number) {
        int ret = 1;

        while (ret < number) {
            ret <<= 1;
        }

        return ret;
    }

    /**
     * Writes every integer of the list, repeated {@code count} times, into
     * {@code array} starting at {@code fromIndex}. Only {@code next} links
     * are followed.
     */
    private static void dump(FrequencyListNode head,
                             int[] array,
                             int fromIndex) {
        for (FrequencyListNode node = head; node != null; node = node.next) {
            int integer = node.integer;
            int count = node.count;

            for (int i = 0; i != count; ++i) {
                array[fromIndex++] = integer;
            }
        }
    }

    private static FrequencyListNode mergeLists(FrequencyListNode[] heads) {
        return mergeLists(heads, 0, heads.length);
    }

    /**
     * Merges the lists {@code heads[fromIndex], ..., heads[toIndex - 1]} by
     * recursively merging the two halves of that index range.
     */
    private static FrequencyListNode mergeLists(FrequencyListNode[] heads,
                                                int fromIndex,
                                                int toIndex) {
        int lists = toIndex - fromIndex;

        if (lists == 1) {
            return heads[fromIndex];
        }

        if (lists == 2) {
            return mergeLists(heads[fromIndex], heads[fromIndex + 1]);
        }

        // The midpoint must be relative to 'fromIndex'; a bare 'lists / 2'
        // is only correct when fromIndex == 0 and otherwise produces an
        // inverted range that recurses without end.
        int middleIndex = fromIndex + lists / 2;

        return mergeLists(mergeLists(heads, fromIndex, middleIndex),
                          mergeLists(heads, middleIndex, toIndex));
    }

    /**
     * Merges two non-empty, ascending frequency lists into one and returns
     * its head. Nodes are reused; when both lists contain the same integer,
     * the node of the first list survives and absorbs the other count.
     * Only the {@code next} links of the result are maintained; the
     * {@code prev} links are left stale.
     */
    private static FrequencyListNode mergeLists(FrequencyListNode head1,
                                                FrequencyListNode head2) {
        FrequencyListNode initialNode;

        if (head1.integer < head2.integer) {
            initialNode = head1;
            head1 = head1.next;
        } else if (head1.integer > head2.integer) {
            initialNode = head2;
            head2 = head2.next;
        } else {
            initialNode = head1;
            initialNode.count += head2.count;
            head1 = head1.next;
            head2 = head2.next;
        }

        FrequencyListNode newHead = initialNode;
        FrequencyListNode newTail = initialNode;

        while (head1 != null && head2 != null) {
            if (head1.integer < head2.integer) {
                newTail.next = head1;
                newTail = head1;
                head1 = head1.next;
            } else if (head1.integer > head2.integer) {
                newTail.next = head2;
                newTail = head2;
                head2 = head2.next;
            } else {
                FrequencyListNode nextHead1 = head1.next;
                FrequencyListNode nextHead2 = head2.next;

                newTail.next = head1;
                newTail = head1;
                newTail.count += head2.count;

                head1 = nextHead1;
                head2 = nextHead2;
            }
        }

        // Append the whole remainder of whichever list is not exhausted. It
        // is already null-terminated, so it must not be cut after its first
        // node. If both lists are exhausted this terminates the merged list.
        newTail.next = (head1 != null) ? head1 : head2;
        return newHead;
    }

    /**
     * Sorts the entire array into ascending order.
     *
     * @param array the array to sort.
     * @throws NullPointerException if {@code array} is {@code null}.
     */
    public static void sort(int[] array) {
        Objects.requireNonNull(array, "The input array is null.");
        sort(array, 0, array.length);
    }

    /**
     * Sorts the range {@code array[fromIndex], ..., array[toIndex - 1]} into
     * ascending order. Elements outside the range are not modified.
     *
     * @param array     the array holding the range to sort.
     * @param fromIndex the first index of the range, inclusive.
     * @param toIndex   the last index of the range, exclusive.
     * @throws NullPointerException           if {@code array} is {@code null}.
     * @throws IllegalArgumentException       if {@code fromIndex > toIndex}.
     * @throws ArrayIndexOutOfBoundsException if {@code fromIndex < 0} or
     *                                        {@code toIndex > array.length}.
     */
    public static void sort(int[] array, int fromIndex, int toIndex) {
        Objects.requireNonNull(array, "The input array is null.");
        rangeCheck(array.length, fromIndex, toIndex);
        new ParallelCurvesort(array, fromIndex, toIndex).sort();
    }

    /**
     * Checks that {@code fromIndex} and {@code toIndex} are in the range and
     * throws an exception if they aren't.
     */
    private static void rangeCheck(int arrayLength, int fromIndex, int toIndex) {
        if (fromIndex > toIndex) {
            throw new IllegalArgumentException(
                    "fromIndex(" + fromIndex + ") > toIndex(" + toIndex + ")");
        }

        if (fromIndex < 0) {
            throw new ArrayIndexOutOfBoundsException(fromIndex);
        }

        if (toIndex > arrayLength) {
            throw new ArrayIndexOutOfBoundsException(toIndex);
        }
    }
}
Demo.java
package net.coderodde.util;
import java.util.Arrays;
/**
 * Benchmark driver: sorts the same sine-wave shaped array with
 * {@code Arrays.sort}, {@code Arrays.parallelSort} and
 * {@code ParallelCurvesort.sort}, prints the elapsed wall-clock time of each
 * and whether the three results agree.
 */
public final class Demo {

    private static final int ARRAY_LENGTH = 20_000_000;
    private static final int FROM_INDEX = 2;
    private static final int TO_INDEX = ARRAY_LENGTH - 3;
    private static final int PERIOD_LENGTH = 10_000;
    private static final int MINIMUM = -3_000;
    private static final int MAXIMUM = 3_000;

    public static void main(String[] args) {
        warmup();
        benchmark();
    }

    /** Runs the whole workload once without printing any timings. */
    private static void warmup() {
        System.out.println("Warming up...");
        perform(getWaveArray(ARRAY_LENGTH, MINIMUM, MAXIMUM, PERIOD_LENGTH),
                false);
        System.out.println("Warming up done!");
    }

    /** Runs the workload on a freshly generated array and prints timings. */
    private static void benchmark() {
        perform(getWaveArray(ARRAY_LENGTH, MINIMUM, MAXIMUM, PERIOD_LENGTH),
                true);
    }

    /**
     * Sorts three private copies of {@code array}, one per algorithm, over
     * the range {@code [FROM_INDEX, TO_INDEX)}.
     *
     * @param array  the input data; it is cloned and never modified.
     * @param output whether to print the timings and the agreement check.
     */
    private static void perform(int[] array, boolean output) {
        int[] sequentialCopy = array.clone();
        int[] parallelCopy = array.clone();
        int[] curvesortCopy = array.clone();

        long begin = System.currentTimeMillis();
        Arrays.sort(sequentialCopy, FROM_INDEX, TO_INDEX);
        long elapsed = System.currentTimeMillis() - begin;
        report(output, "Arrays.sort in ", elapsed);

        begin = System.currentTimeMillis();
        Arrays.parallelSort(parallelCopy, FROM_INDEX, TO_INDEX);
        elapsed = System.currentTimeMillis() - begin;
        report(output, "Arrays.parallelSort in ", elapsed);

        begin = System.currentTimeMillis();
        ParallelCurvesort.sort(curvesortCopy, FROM_INDEX, TO_INDEX);
        elapsed = System.currentTimeMillis() - begin;
        report(output, "ParallelCurvesort.sort in ", elapsed);

        if (!output) {
            return;
        }

        boolean agree = Arrays.equals(sequentialCopy, parallelCopy)
                && Arrays.equals(parallelCopy, curvesortCopy);
        System.out.println("Algorithms agree: " + agree);
    }

    /** Prints one timing line, but only when {@code enabled} is set. */
    private static void report(boolean enabled, String prefix, long millis) {
        if (enabled) {
            System.out.println(prefix + millis + " milliseconds.");
        }
    }

    /**
     * Creates an array of {@code length} samples of a sine wave whose period
     * is {@code periodLength} elements and whose amplitude is half of
     * {@code maximum - minimum + 1}.
     */
    private static int[] getWaveArray(int length,
                                      int minimum,
                                      int maximum,
                                      int periodLength) {
        final int halfAmplitude = (maximum - minimum + 1) / 2;
        int[] wave = new int[length];
        Arrays.setAll(wave,
                      index -> generateWaveInt(index,
                                               periodLength,
                                               halfAmplitude));
        return wave;
    }

    /** Computes the {@code i}-th sample of the wave, truncated to an int. */
    private static int generateWaveInt(int i,
                                       int periodLength,
                                       int halfAmplitude) {
        double phase = (2.0 * Math.PI * i) / periodLength;
        double sample = halfAmplitude * Math.sin(phase);
        return (int) sample;
    }
}
Performance on "good" data
I get the following results:
Warming up... Warming up done! Arrays.sort in 705 milliseconds. Arrays.parallelSort in 512 milliseconds. ParallelCurvesort.sort in 102 milliseconds. Algorithms agree: true
Critique request
Please tell me anything that comes to mind.
1 Answer 1
Not a full review.
The performance would be poor on real data since you create so many nodes. For sorting, performance is important and it's better to work with raw arrays all the time.
findAndUpdateSmallerNode
and findAndUpdateLargerNode
have very similar code, so you might want to try to extract some common code. But adding an extra layer of abstraction might make the code less readable.
last
should probably be named latest
instead since it sounds like it's the last element in the list (tail
).
-
\$\begingroup\$ True. The worst case running time is indeed quadratic. Even worse, the hidden constant factors are big since the sort relies on a linked list. However, please note that curve sort will adapt to "smooth" data very well. \$\endgroup\$coderodde– coderodde2018年03月23日 21:11:59 +00:00Commented Mar 23, 2018 at 21:11
Explore related questions
See similar questions with these tags.
mergeLists
, the middle index should be (fromIndex + toIndex) >>> 1
instead of (toIndex - fromIndex) / 2
. \$\endgroup\$mergeLists
takes the array of FrequencyListNode
s whose first index is zero, and the last index is heads.length
(exclusive). \$\endgroup\$middleIndex
won't be the middle index but middleIndex - fromIndex
(i.e.fromIndex=0
andtoIndex=6
:0,6->3,6->1,6->2,6->2,6->2,6->...
). \$\endgroup\$