package net.coderodde.stat;stat.support;
import java.util.HashSet;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import net.coderodde.stat.AbstractProbabilityDistribution;
/**
* This class implements an abstract base class for probability distributions.
* Elements are added with strictly positive weights and whenever asking this
* data structure for a random element, their respective weights are taken into
* account. For example, ifprobability thisdistribution datarelying structureon containsan threearray differentof
* elements (<tt>a</tt>, <tt>b</tt>, <tt>c</tt> with respective weights
* <tt>1.0</tt>, <tt>1.0</tt>, <tt>3.0</tt>), whenever asking for a random
* element, there is 20 percent chance of obtaining <tt>a</tt>, 20 percent
* chance of obtaining <tt>b</tt>, and 60 percent chanceThe ofrunning obtainingtimes are *as <tt>c</tt>.follows:
*
* <table>
* <tr><td>Method</td> <td>Complexity</td></tr>
* <tr><td><tt>addElement </tt> </td> <td>amortized constant time,</td></tr>
* <tr><td><tt>sampleElement</tt> </td> <td><tt>O(n)</tt>,</td></tr>
* <tr><td><tt>removeElement</tt> </td> <td><tt>O(n)</tt>.</td></tr>
* </table>
*
* @param <E> the actual type of the elements stored in this probability
* distribution.
*
* @author Rodion "rodde" Efremov
* @version 1.6 (Jun 11, 2016)
*/
public abstract class AbstractProbabilityDistribution<E> {
/**
* The amount of elements in this probability distribution.
ArrayProbabilityDistribution<E> */
protectedextends intAbstractProbabilityDistribution<E> size;{
/**
* The sum of all weights.
*/
private static final protectedint doubleDEFAULT_STORAGE_ARRAYS_CAPACITY totalWeight;= 8;
/**
*private TheObject[] randomobjectStorageArray;
number generator of thisprivate probabilitydouble[] distribution.weightStorageArray;
*/
private final Set<E> protectedfilterSet final= Randomnew random;HashSet<>();
/**
* Constructs this probability distribution.
*/
protectedpublic AbstractProbabilityDistributionArrayProbabilityDistribution() {
this(new Random());
}
public ArrayProbabilityDistribution(final Random random) {
super(random);
this.objectStorageArray = new Object[DEFAULT_STORAGE_ARRAYS_CAPACITY];
this.weightStorageArray = new double[DEFAULT_STORAGE_ARRAYS_CAPACITY];
}
/**
* Constructs this probability distribution using the input random number
* generator.
*
* @param random the random{@inheritDoc number} generator.
*/
protected AbstractProbabilityDistribution(final Random random) {
this.random = @Override
public boolean Objects.requireNonNulladdElement(random,
final E element, final double weight) {
"The random number generator is null."checkWeight(weight);
}
public boolean isEmpty if (filterSet.contains(element)) {
return // 'element' is already present in this probability distribution.size == 0; return false;
}
public int size ensureCapacity(this.size + 1);
{ objectStorageArray[this.size] = element;
returnweightStorageArray[this.size] = weight;
this.size;totalWeight += weight;
this.size++;
this.filterSet.add(element);
return true;
}
/**
* Adds the element {@code element} to this probability distribution, and
* assigns {@code@inheritDoc weight} as its weight.
*
* @param element the element to add.
* @param weight the weight of the new element.
*/
@Override
public abstractE booleansampleElement() addElement{
checkNotEmpty(final);
E element, final double weightvalue = this.random.nextDouble(); * this.totalWeight;
/**
*for Returns(int ai randomly= chosen0; elementi from< this probability.size; distribution++i) {
* taking the weights into account if (value < this.weightStorageArray[i]) {
* * @return a randomly chosen elementreturn (E) this.objectStorageArray[i];
*/
public abstract E sampleElement();}
/**
* Returns {@code true}value if-= this probability distribution contains the
* element {@code element}.
*
* @param element the element to query.
* @return {@code true} if the input element is in this probability weightStorageArray[i];
* distribution; {@code false} otherwise.
*/
public abstract boolean contains(final E element);
throw new IllegalStateException("Should not get here.");
}
/**
* Removes the element {@code element} from this probability distribution.
*
* @param element the element to remove.
* @return {@code@inheritDoc true} if the element was present in this probability
* distribution and was successfully removed.
*/
public@Override
abstract public boolean removeElement(final E element); {
if (!this.filterSet.contains(element)) {
return false;
}
final int index = indexOf(element);
this.totalWeight -= this.weightStorageArray[index];
for (int j = index + 1; j < this.size; ++j) {
objectStorageArray[j - 1] = objectStorageArray[j];
weightStorageArray[j - 1] = weightStorageArray[j];
}
objectStorageArray[--this.size] = null;
return true;
}
/**
* Removes all elements from this{@inheritDoc probability} distribution.
*/
public@Override
abstract public void clear(); {
for (int i = 0; i < this.size; ++i) {
objectStorageArray[i] = null;
}
this.size = 0;
this.totalWeight = 0.0;
}
/**
* Checks that the element weight is valid. The weight must not be a
* <tt>NaN</tt> and must be positive, but not a positive infinity.
*
* @param weight the weight{@inheritDoc to} validate.
*/
protected void checkWeight(final double weight) {@Override
public ifboolean (Double.isNaNcontains(weight)E element) {
throw newreturn IllegalArgumentExceptionthis.filterSet.contains("The element weight is NaN.");
}
private int indexOf(final E element) if{
for (weightint <=i 0= 0; i < this.0size; ++i) {
throw newif IllegalArgumentException(Objects.equals(element, this.objectStorageArray[i])) {
return i;
"The element weight must be positive. Received " + weight);}
}
if (Double.isInfinite(weight)) {
// Once here, 'weight' is positive infinity.
throw new IllegalArgumentException(
"The element weight is infinite.");
return }-1;
}
/**private void ensureCapacity(final int requestedCapacity) {
* Checks that thisif probability(requestedCapacity distribution> containsobjectStorageArray.length) at{
least one element final int newCapacity = Math.max(requestedCapacity,
2 */ objectStorageArray.length);
protected void checkNotEmpty() { final Object[] newObjectStorageArray = new Object[newCapacity];
if final double[] newWeightStorageArray = new double[newCapacity];
System.arraycopy(sizethis.objectStorageArray, == 0), {
throw new IllegalStateException newObjectStorageArray,
0,
this.size);
System.arraycopy(this.weightStorageArray,
"This probability distribution is empty 0,
newWeightStorageArray,
0,
this."size);
this.objectStorageArray = newObjectStorageArray;
this.weightStorageArray = newWeightStorageArray;
}
}
}
package net.coderodde.stat;
import java.util.Objects;
import java.util.Random;
/**
 * Common base class for weighted probability distributions. Elements are
 * stored together with a strictly positive weight, and whenever a random
 * element is requested the weights are taken into account. For example, if
 * the structure holds three different elements <tt>a</tt>, <tt>b</tt>,
 * <tt>c</tt> with respective weights <tt>1.0</tt>, <tt>1.0</tt>, <tt>3.0</tt>,
 * a request yields <tt>a</tt> with 20 percent chance, <tt>b</tt> with 20
 * percent chance, and <tt>c</tt> with 60 percent chance.
 *
 * @param <E> the type of the elements stored in this probability distribution.
 *
 * @author Rodion "rodde" Efremov
 * @version 1.6 (Jun 11, 2016)
 */
public abstract class AbstractProbabilityDistribution<E> {

    /**
     * The number of elements in this probability distribution. This class
     * only reads the field; it is never modified here.
     */
    protected int size;

    /**
     * The sum of all weights. This class never modifies the field.
     */
    protected double totalWeight;

    /**
     * The random number generator of this probability distribution.
     */
    protected final Random random;

    /**
     * Creates a distribution backed by a freshly constructed {@link Random}.
     */
    protected AbstractProbabilityDistribution() {
        this(new Random());
    }

    /**
     * Creates a distribution backed by the supplied random number generator.
     *
     * @param random the random number generator; must not be {@code null}.
     * @throws NullPointerException if {@code random} is {@code null}.
     */
    protected AbstractProbabilityDistribution(final Random random) {
        this.random = Objects.requireNonNull(
                random, "The random number generator is null.");
    }

    /**
     * Tells whether the {@code size} field is zero.
     *
     * @return {@code true} if no elements are counted; {@code false} otherwise.
     */
    public boolean isEmpty() {
        return size == 0;
    }

    /**
     * Returns the current value of the {@code size} field.
     *
     * @return the number of elements.
     */
    public int size() {
        return size;
    }

    /**
     * Adds {@code element} to this probability distribution with
     * {@code weight} as its weight.
     *
     * @param element the element to add.
     * @param weight  the weight of the new element.
     * @return {@code true} if the element was added; {@code false} if the
     *         implementation rejected it.
     */
    public abstract boolean addElement(final E element, final double weight);

    /**
     * Returns a randomly chosen element, taking the weights into account.
     *
     * @return a randomly chosen element.
     */
    public abstract E sampleElement();

    /**
     * Tells whether {@code element} is in this probability distribution.
     *
     * @param element the element to query.
     * @return {@code true} if the element is present; {@code false} otherwise.
     */
    public abstract boolean contains(final E element);

    /**
     * Removes {@code element} from this probability distribution.
     *
     * @param element the element to remove.
     * @return {@code true} if the element was present and was removed.
     */
    public abstract boolean removeElement(final E element);

    /**
     * Removes all elements from this probability distribution.
     */
    public abstract void clear();

    /**
     * Validates an element weight: it must not be <tt>NaN</tt>, must be
     * strictly positive, and must not be positive infinity.
     *
     * @param weight the weight to validate.
     * @throws IllegalArgumentException if the weight is invalid.
     */
    protected void checkWeight(final double weight) {
        if (Double.isNaN(weight)) {
            throw new IllegalArgumentException("The element weight is NaN.");
        }

        // NaN is excluded above, so this is exactly "weight <= 0.0".
        if (!(weight > 0.0)) {
            throw new IllegalArgumentException(
                    "The element weight must be positive. Received " + weight);
        }

        // Only positive infinity can still be infinite at this point.
        if (weight == Double.POSITIVE_INFINITY) {
            throw new IllegalArgumentException(
                    "The element weight is infinite.");
        }
    }

    /**
     * Verifies that this probability distribution holds at least one element.
     *
     * @throws IllegalStateException if {@code size} is zero.
     */
    protected void checkNotEmpty() {
        if (size != 0) {
            return;
        }

        throw new IllegalStateException(
                "This probability distribution is empty.");
    }
}
package net.coderodde.stat.support;
import java.util.HashSet;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import net.coderodde.stat.AbstractProbabilityDistribution;
/**
* This class implements a probability distribution relying on an array of
* elements. The running times are as follows:
*
* <table>
* <tr><td>Method</td> <td>Complexity</td></tr>
* <tr><td><tt>addElement </tt> </td> <td>amortized constant time,</td></tr>
* <tr><td><tt>sampleElement</tt> </td> <td><tt>O(n)</tt>,</td></tr>
* <tr><td><tt>removeElement</tt> </td> <td><tt>O(n)</tt>.</td></tr>
* </table>
*
* @param <E> the actual type of the elements stored in this probability
* distribution.
*
* @author Rodion "rodde" Efremov
* @version 1.6 (Jun 11, 2016)
*/
public class ArrayProbabilityDistribution<E>
        extends AbstractProbabilityDistribution<E> {

    /**
     * Initial length of both storage arrays.
     */
    private static final int DEFAULT_STORAGE_ARRAYS_CAPACITY = 8;

    /**
     * Holds the stored elements in the slots {@code 0 .. size - 1}.
     */
    private Object[] objectStorageArray;

    /**
     * Holds the weights; entry {@code i} belongs to the element stored at
     * {@code objectStorageArray[i]}.
     */
    private double[] weightStorageArray;

    /**
     * Contains exactly the elements currently stored in the arrays, so that
     * membership can be answered without scanning them.
     */
    private final Set<E> filterSet = new HashSet<>();

    /**
     * Constructs an empty distribution using a new {@link Random}.
     */
    public ArrayProbabilityDistribution() {
        this(new Random());
    }

    /**
     * Constructs an empty distribution using the given random number
     * generator.
     *
     * @param random the random number generator to use.
     */
    public ArrayProbabilityDistribution(final Random random) {
        super(random);
        this.objectStorageArray = new Object[DEFAULT_STORAGE_ARRAYS_CAPACITY];
        this.weightStorageArray = new double[DEFAULT_STORAGE_ARRAYS_CAPACITY];
    }

    /**
     * {@inheritDoc }
     *
     * @return {@code true} if the element was added; {@code false} if it is
     *         already present, in which case nothing changes.
     */
    @Override
    public boolean addElement(final E element, final double weight) {
        checkWeight(weight);

        // Set.add reports false for an element that is already present; such
        // a duplicate must not be appended to the arrays a second time.
        if (!this.filterSet.add(element)) {
            return false;
        }

        ensureCapacity(this.size + 1);
        this.objectStorageArray[this.size] = element;
        this.weightStorageArray[this.size] = weight;
        this.totalWeight += weight;
        this.size++;
        return true;
    }

    /**
     * {@inheritDoc }
     *
     * @throws IllegalStateException if this distribution is empty.
     */
    @Override
    @SuppressWarnings("unchecked") // Only values of type E are ever stored.
    public E sampleElement() {
        checkNotEmpty();
        double value = this.random.nextDouble() * this.totalWeight;
        final int lastIndex = this.size - 1;

        for (int i = 0; i < lastIndex; ++i) {
            final double weight = this.weightStorageArray[i];

            if (value < weight) {
                return (E) this.objectStorageArray[i];
            }

            value -= weight;
        }

        // Whatever remains belongs to the last element. Returning it
        // unconditionally also covers the case where 'totalWeight' has drifted
        // slightly above the real sum because of floating-point rounding.
        return (E) this.objectStorageArray[lastIndex];
    }

    /**
     * {@inheritDoc }
     */
    @Override
    public boolean removeElement(final E element) {
        // Keep the filter set in sync with the arrays; otherwise the element
        // would still be reported as present and could never be re-added.
        if (!this.filterSet.remove(element)) {
            return false;
        }

        final int index = indexOf(element);
        this.totalWeight -= this.weightStorageArray[index];

        // Close the gap by shifting the tail one slot to the left.
        final int tailLength = this.size - index - 1;
        System.arraycopy(this.objectStorageArray, index + 1,
                         this.objectStorageArray, index, tailLength);
        System.arraycopy(this.weightStorageArray, index + 1,
                         this.weightStorageArray, index, tailLength);

        this.objectStorageArray[--this.size] = null;

        if (this.size == 0) {
            // Discard any accumulated rounding error.
            this.totalWeight = 0.0;
        }

        return true;
    }

    /**
     * {@inheritDoc }
     */
    @Override
    public void clear() {
        // Null out the used slots so the elements can be garbage collected.
        for (int i = 0; i < this.size; ++i) {
            this.objectStorageArray[i] = null;
        }

        this.filterSet.clear();
        this.size = 0;
        this.totalWeight = 0.0;
    }

    /**
     * {@inheritDoc }
     */
    @Override
    public boolean contains(final E element) {
        return this.filterSet.contains(element);
    }

    /**
     * Finds the array slot holding {@code element}.
     *
     * @param element the element to look for.
     * @return the index of the element, or {@code -1} if it is not stored.
     */
    private int indexOf(final E element) {
        for (int i = 0; i < this.size; ++i) {
            if (Objects.equals(element, this.objectStorageArray[i])) {
                return i;
            }
        }

        return -1;
    }

    /**
     * Grows both storage arrays if they cannot hold
     * {@code requestedCapacity} entries. The new length is the larger of the
     * requested capacity and twice the current length.
     *
     * @param requestedCapacity the number of slots that must be available.
     */
    private void ensureCapacity(final int requestedCapacity) {
        if (requestedCapacity <= this.objectStorageArray.length) {
            return;
        }

        final int newCapacity =
                Math.max(requestedCapacity,
                         2 * this.objectStorageArray.length);

        final Object[] newObjectStorageArray = new Object[newCapacity];
        final double[] newWeightStorageArray = new double[newCapacity];

        System.arraycopy(this.objectStorageArray, 0,
                         newObjectStorageArray, 0, this.size);
        System.arraycopy(this.weightStorageArray, 0,
                         newWeightStorageArray, 0, this.size);

        this.objectStorageArray = newObjectStorageArray;
        this.weightStorageArray = newWeightStorageArray;
    }
}
Comparing three data structures for dealing with probability distributions in Java
Introduction
Suppose you are given three elements \$a, b, c\$ with respective weights \$1, 1, 3\$. A probability distribution data structure will then return, upon request, \$a\$ with probability 20%, \$b\$ with probability 20%, and \$c\$ with probability 60%.
The API for my probability distribution data structures is defined by the following abstract class:
package net.coderodde.stat;
import java.util.Objects;
import java.util.Random;
/**
* This class implements an abstract base class for probability distributions.
* Elements are added with strictly positive weights and whenever asking this
* data structure for a random element, their respective weights are taken into
* account. For example, if this data structure contains three different
* elements (<tt>a</tt>, <tt>b</tt>, <tt>c</tt> with respective weights
* <tt>1.0</tt>, <tt>1.0</tt>, <tt>3.0</tt>), whenever asking for a random
* element, there is 20 percent chance of obtaining <tt>a</tt>, 20 percent
* chance of obtaining <tt>b</tt>, and 60 percent chance of obtaining
* <tt>c</tt>.
*
* @author Rodion "rodde" Efremov
* @version 1.6 (Jun 11, 2016)
*/
public abstract class AbstractProbabilityDistribution<E> {

    /**
     * How many elements this probability distribution holds. The field is
     * only read in this class.
     */
    protected int size;

    /**
     * The sum of all weights. The field is not modified in this class.
     */
    protected double totalWeight;

    /**
     * The random number generator used by this probability distribution.
     */
    protected final Random random;

    /**
     * Constructs this probability distribution with a new {@link Random}.
     */
    protected AbstractProbabilityDistribution() {
        this(new Random());
    }

    /**
     * Constructs this probability distribution with the given random number
     * generator.
     *
     * @param random the random number generator; must not be {@code null}.
     * @throws NullPointerException if {@code random} is {@code null}.
     */
    protected AbstractProbabilityDistribution(final Random random) {
        final String message = "The random number generator is null.";
        this.random = Objects.requireNonNull(random, message);
    }

    /**
     * Reports whether the element count is zero.
     *
     * @return {@code true} if {@code size} is zero.
     */
    public boolean isEmpty() {
        return 0 == size;
    }

    /**
     * Reports the element count.
     *
     * @return the value of {@code size}.
     */
    public int size() {
        return size;
    }

    /**
     * Adds {@code element} to this probability distribution and assigns
     * {@code weight} as its weight.
     *
     * @param element the element to add.
     * @param weight  the weight of the new element.
     *
     * @return {@code true} only if the input element did not reside in this
     *         structure and was successfully added.
     */
    public abstract boolean addElement(final E element, final double weight);

    /**
     * Picks an element at random, honouring the weights.
     *
     * @return a randomly chosen element.
     */
    public abstract E sampleElement();

    /**
     * Queries whether {@code element} is stored in this probability
     * distribution.
     *
     * @param element the element to query.
     * @return {@code true} if the element is stored; {@code false} otherwise.
     */
    public abstract boolean contains(final E element);

    /**
     * Removes {@code element} from this probability distribution.
     *
     * @param element the element to remove.
     * @return {@code true} if the element was present and was successfully
     *         removed.
     */
    public abstract boolean removeElement(final E element);

    /**
     * Removes every element from this probability distribution.
     */
    public abstract void clear();

    /**
     * Checks that an element weight is acceptable: not <tt>NaN</tt>, strictly
     * positive, and finite.
     *
     * @param weight the weight to validate.
     * @throws IllegalArgumentException if the weight is not acceptable.
     */
    protected void checkWeight(final double weight) {
        final String problem;

        if (Double.isNaN(weight)) {
            problem = "The element weight is NaN.";
        } else if (weight <= 0.0) {
            problem = "The element weight must be positive. Received " + weight;
        } else if (Double.isInfinite(weight)) {
            // Negative infinity was caught by the previous branch.
            problem = "The element weight is infinite.";
        } else {
            return;
        }

        throw new IllegalArgumentException(problem);
    }

    /**
     * Checks that at least one element is stored.
     *
     * @throws IllegalStateException if {@code size} is zero.
     */
    protected void checkNotEmpty() {
        if (size != 0) {
            return;
        }

        throw new IllegalStateException(
                "This probability distribution is empty.");
    }
}
Implementations
The first probability distribution data structure relies on arrays. It has the following running times:
- element addition in amortized constant time,
- element removal in worst-case linear time,
- element sampling in worst-case linear time.
The data structure follows:
ArrayProbabilityDistribution.java:
package net.coderodde.stat;
import java.util.Objects;
import java.util.Random;
/**
* This class implements an abstract base class for probability distributions.
* Elements are added with strictly positive weights and whenever asking this
* data structure for a random element, their respective weights are taken into
* account. For example, if this data structure contains three different
* elements (<tt>a</tt>, <tt>b</tt>, <tt>c</tt> with respective weights
* <tt>1.0</tt>, <tt>1.0</tt>, <tt>3.0</tt>), whenever asking for a random
* element, there is 20 percent chance of obtaining <tt>a</tt>, 20 percent
* chance of obtaining <tt>b</tt>, and 60 percent chance of obtaining
* <tt>c</tt>.
*
* @author Rodion "rodde" Efremov
* @version 1.6 (Jun 11, 2016)
*/
public abstract class AbstractProbabilityDistribution<E> {

    /**
     * Element count of this probability distribution; only read here.
     */
    protected int size;

    /**
     * Sum of all weights; never written by this class.
     */
    protected double totalWeight;

    /**
     * Random number generator of this probability distribution.
     */
    protected final Random random;

    /**
     * Builds this probability distribution around a new {@link Random}.
     */
    protected AbstractProbabilityDistribution() {
        this(new Random());
    }

    /**
     * Builds this probability distribution around the input random number
     * generator.
     *
     * @param random the random number generator; must not be {@code null}.
     * @throws NullPointerException if {@code random} is {@code null}.
     */
    protected AbstractProbabilityDistribution(final Random random) {
        this.random = Objects.requireNonNull(
                random,
                "The random number generator is null.");
    }

    /**
     * Tells whether no elements are counted.
     *
     * @return {@code true} if {@code size} equals zero.
     */
    public boolean isEmpty() {
        return size == 0;
    }

    /**
     * Gives the element count.
     *
     * @return the value of {@code size}.
     */
    public int size() {
        return size;
    }

    /**
     * Adds {@code element} to this probability distribution with weight
     * {@code weight}.
     *
     * @param element the element to add.
     * @param weight  the weight of the new element.
     * @return {@code true} if the element was added; {@code false} if the
     *         implementation rejected it.
     */
    public abstract boolean addElement(final E element, final double weight);

    /**
     * Draws a random element, taking the weights into account.
     *
     * @return a randomly chosen element.
     */
    public abstract E sampleElement();

    /**
     * Tells whether this probability distribution contains {@code element}.
     *
     * @param element the element to query.
     * @return {@code true} if the input element is in this probability
     *         distribution; {@code false} otherwise.
     */
    public abstract boolean contains(final E element);

    /**
     * Removes {@code element} from this probability distribution.
     *
     * @param element the element to remove.
     * @return {@code true} if the element was present and has been removed.
     */
    public abstract boolean removeElement(final E element);

    /**
     * Empties this probability distribution.
     */
    public abstract void clear();

    /**
     * Validates a weight. A valid weight is not <tt>NaN</tt>, is strictly
     * positive, and is not positive infinity.
     *
     * @param weight the weight to validate.
     * @throws IllegalArgumentException if the weight is invalid.
     */
    protected void checkWeight(final double weight) {
        if (Double.isNaN(weight)) {
            throw new IllegalArgumentException("The element weight is NaN.");
        }

        final boolean positive = weight > 0.0;

        if (!positive) {
            throw new IllegalArgumentException(
                    "The element weight must be positive. Received " + weight);
        }

        // A positive, non-NaN value is infinite only if it is +Infinity.
        if (Double.isInfinite(weight)) {
            throw new IllegalArgumentException(
                    "The element weight is infinite.");
        }
    }

    /**
     * Ensures that at least one element is counted.
     *
     * @throws IllegalStateException if {@code size} is zero.
     */
    protected void checkNotEmpty() {
        final boolean empty = size == 0;

        if (empty) {
            throw new IllegalStateException(
                    "This probability distribution is empty.");
        }
    }
}
The second probability distribution data structure relies on a linked list, and provides the following operations:
- element addition in amortized constant time,
- element removal in constant time,
- element sampling in worst-case linear time.
The data structure follows:
LinkedListProbabilityDistribution.java:
package net.coderodde.stat.support;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import net.coderodde.stat.AbstractProbabilityDistribution;
/**
* This class implements a probability distribution relying on a linked list.
* The running times of the main methods are as follows:
*
* <table>
* <tr><td>Method</td> <td>Complexity</td></tr>
* <tr><td><tt>addElement </tt></td>
* <td><tt>amortized constant time</tt>,</td></tr>
* <tr><td><tt>sampleElement</tt> </td> <td><tt>O(n)</tt>,</td></tr>
* <tr><td><tt>removeElement</tt> </td> <td><tt>O(1)</tt>.</td></tr>
* </table>
*
* This probability distribution class is best used whenever it is modified
* frequently compared to the number of queries made.
*
* @param <E> the actual type of the elements stored in this probability
* distribution.
*
* @author Rodion "rodde" Efremov
* @version 1.6 (Jun 11, 2016)
*/
public class LinkedListProbabilityDistribution<E>
        extends AbstractProbabilityDistribution<E> {

    /**
     * A node of the doubly linked list. It pairs one element with its weight.
     *
     * @param <E> the element type.
     */
    private static final class LinkedListNode<E> {

        private final E element;
        private final double weight;
        private LinkedListNode<E> prev;
        private LinkedListNode<E> next;

        LinkedListNode(final E element, final double weight) {
            this.element = element;
            this.weight = weight;
        }

        E getElement() {
            return this.element;
        }

        double getWeight() {
            return this.weight;
        }

        LinkedListNode<E> getPreviousLinkedListNode() {
            return this.prev;
        }

        LinkedListNode<E> getNextLinkedListNode() {
            return this.next;
        }

        void setPreviousLinkedListNode(final LinkedListNode<E> node) {
            this.prev = node;
        }

        void setNextLinkedListNode(final LinkedListNode<E> node) {
            this.next = node;
        }
    }

    /**
     * This map maps the elements to their respective linked list nodes.
     */
    private final Map<E, LinkedListNode<E>> map = new HashMap<>();

    /**
     * Stores the very first linked list node in this probability distribution.
     */
    private LinkedListNode<E> linkedListHead;

    /**
     * Stores the very last linked list node in this probability distribution.
     */
    private LinkedListNode<E> linkedListTail;

    /**
     * Constructs a new probability distribution.
     */
    public LinkedListProbabilityDistribution() {
        super();
    }

    /**
     * Constructs a new probability distribution using the input random number
     * generator.
     *
     * @param random the random number generator to use.
     */
    public LinkedListProbabilityDistribution(final Random random) {
        super(random);
    }

    /**
     * {@inheritDoc }
     *
     * @return {@code true} if the element was appended; {@code false} if it
     *         is already present, in which case nothing changes.
     */
    @Override
    public boolean addElement(final E element, final double weight) {
        checkWeight(weight);

        if (this.map.containsKey(element)) {
            return false;
        }

        final LinkedListNode<E> newnode = new LinkedListNode<>(element, weight);

        if (this.linkedListHead == null) {
            // First element: it is both the head and the tail.
            this.linkedListHead = newnode;
        } else {
            this.linkedListTail.setNextLinkedListNode(newnode);
            newnode.setPreviousLinkedListNode(this.linkedListTail);
        }

        this.linkedListTail = newnode;
        this.map.put(element, newnode);
        this.size++;
        this.totalWeight += weight;
        return true;
    }

    /**
     * {@inheritDoc }
     *
     * @throws IllegalStateException if this distribution is empty.
     */
    @Override
    public E sampleElement() {
        checkNotEmpty();
        double value = this.random.nextDouble() * this.totalWeight;

        for (LinkedListNode<E> node = this.linkedListHead;
                node != this.linkedListTail;
                node = node.getNextLinkedListNode()) {
            final double weight = node.getWeight();

            if (value < weight) {
                return node.getElement();
            }

            value -= weight;
        }

        // Whatever remains belongs to the tail. Returning it unconditionally
        // also covers the case where repeated '+=' / '-=' updates have left
        // 'totalWeight' slightly above the real sum of the stored weights.
        return this.linkedListTail.getElement();
    }

    /**
     * {@inheritDoc }
     */
    @Override
    public boolean contains(final E element) {
        return this.map.containsKey(element);
    }

    /**
     * {@inheritDoc }
     */
    @Override
    public boolean removeElement(final E element) {
        // The map never stores null nodes, so null means "not present".
        final LinkedListNode<E> node = this.map.remove(element);

        if (node == null) {
            return false;
        }

        unlink(node);
        this.size--;
        this.totalWeight -= node.getWeight();

        if (this.size == 0) {
            // Discard any accumulated rounding error.
            this.totalWeight = 0.0;
        }

        return true;
    }

    /**
     * {@inheritDoc }
     */
    @Override
    public void clear() {
        this.size = 0;
        this.totalWeight = 0.0;
        this.map.clear();
        this.linkedListHead = null;
        this.linkedListTail = null;
    }

    /**
     * Detaches {@code node} from the list, updating the head and/or tail
     * reference when the node sits at an end of the list.
     *
     * @param node the node to detach.
     */
    private void unlink(final LinkedListNode<E> node) {
        final LinkedListNode<E> left = node.getPreviousLinkedListNode();
        final LinkedListNode<E> right = node.getNextLinkedListNode();

        if (left != null) {
            left.setNextLinkedListNode(right);
        } else {
            this.linkedListHead = right;
        }

        if (right != null) {
            right.setPreviousLinkedListNode(left);
        } else {
            this.linkedListTail = left;
        }
    }
}
The third data structure relies on a binary tree and runs all the three main methods in worst-case logarithmic time. It looks like this:
Above, the red nodes are the leaf nodes containing the actual elements. The white nodes are called relay nodes in the code. The integer in each relay node denotes how many leaf nodes its subtree contains, and the real number denotes the sum of the weights of all the leaves in that subtree.
The data structure follows:
BinaryTreeProbabilityDistribution.java:
package net.coderodde.stat.support;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Random;
import net.coderodde.stat.AbstractProbabilityDistribution;
/**
* This class implements a probability distribution relying on a binary tree
* structure. It allows <tt>O(log n)</tt> worst case time for adding, removing
* and sampling an element.
*
* @author Rodion "rodde" Efremov
* @version 1.6 (Jun 11, 2016)
* @param <E> the actual type of the elements stored in this distribution.
*/
public class BinaryTreeProbabilityDistribution<E>
extends AbstractProbabilityDistribution<E> {
private static final class Node<E> {

    /**
     * The element of a leaf node; {@code null} in nodes created through the
     * no-argument constructor.
     */
    private final E element;

    /**
     * The weight of this node, read and written through
     * {@link #getWeight()} and {@link #setWeight(double)}.
     */
    private double weight;

    /**
     * {@code true} for nodes created through the no-argument constructor.
     */
    private boolean isRelayNode;

    /**
     * The left child node.
     */
    private Node<E> leftChild;

    /**
     * The right child node.
     */
    private Node<E> rightChild;

    /**
     * The parent node.
     */
    private Node<E> parent;

    /**
     * The cached leaf count for this node; starts at one for a leaf.
     */
    private int numberOfLeafNodes;

    /**
     * Creates a leaf node holding {@code element} with the given weight and
     * a leaf count of one.
     */
    Node(final E element, final double weight) {
        this.element = element;
        this.weight = weight;
        this.numberOfLeafNodes = 1;
    }

    /**
     * Creates a relay node; it carries no element.
     */
    Node() {
        this.element = null;
        this.isRelayNode = true;
    }

    /**
     * Renders a relay node as {@code [weight : leafCount]} and a leaf node as
     * {@code (weight : element)}, with the weight shown to three decimals.
     */
    @Override
    public String toString() {
        final String formattedWeight = String.format("%.3f", getWeight());

        return isRelayNode
                ? "[" + formattedWeight + " : " + numberOfLeafNodes + "]"
                : "(" + formattedWeight + " : " + element + ")";
    }

    E getElement() {
        return element;
    }

    double getWeight() {
        return weight;
    }

    void setWeight(final double weight) {
        this.weight = weight;
    }

    int getNumberOfLeaves() {
        return numberOfLeafNodes;
    }

    void setNumberOfLeaves(final int numberOfLeaves) {
        numberOfLeafNodes = numberOfLeaves;
    }

    Node<E> getLeftChild() {
        return leftChild;
    }

    void setLeftChild(final Node<E> block) {
        leftChild = block;
    }

    Node<E> getRightChild() {
        return rightChild;
    }

    void setRightChild(final Node<E> block) {
        rightChild = block;
    }

    Node<E> getParent() {
        return parent;
    }

    void setParent(final Node<E> block) {
        parent = block;
    }

    boolean isRelayNode() {
        return isRelayNode;
    }

    boolean isLeafNode() {
        return !isRelayNode;
    }
}
/**
* Maps each element to the leaf node that holds it.
*/
private final Map<E, Node<E>> map = new HashMap<>();
/**
* The root node of this distribution tree.
*/
private Node<E> root;
/**
 * Creates an empty distribution that draws its randomness from a newly
 * constructed {@link Random}.
 */
public BinaryTreeProbabilityDistribution() {
    this(new Random());
}
/**
 * Creates an empty distribution that draws its randomness from the supplied
 * generator, which is handed to the superclass constructor.
 *
 * @param random the random number generator to use.
 */
public BinaryTreeProbabilityDistribution(final Random random) {
    super(random);
}
/**
 * {@inheritDoc }
 *
 * @return {@code true} if a new leaf was created for the element;
 *         {@code false} if the element is already mapped.
 */
@Override
public boolean addElement(E element, double weight) {
    checkWeight(weight);

    if (map.containsKey(element)) {
        return false;
    }

    final Node<E> leaf = new Node<>(element, weight);
    insert(leaf);

    // Bookkeeping: element lookup, total weight and element count.
    map.put(element, leaf);
    totalWeight += weight;
    size++;
    return true;
}
/**
 * {@inheritDoc }
 *
 * The answer comes from the element-to-node map.
 */
@Override
public boolean contains(E element) {
    final boolean present = map.containsKey(element);
    return present;
}
/**
 * {@inheritDoc }
 *
 * Walks from the root down to a leaf: at every relay node the remaining value
 * is compared with the weight of the left child to decide which way to go.
 *
 * @throws IllegalStateException if this distribution is empty.
 */
@Override
public E sampleElement() {
    checkNotEmpty();

    double remaining = this.totalWeight * this.random.nextDouble();
    Node<E> current = root;

    while (!current.isLeafNode()) {
        final Node<E> left = current.getLeftChild();
        final double leftWeight = left.getWeight();

        if (remaining < leftWeight) {
            current = left;
        } else {
            // Skip the whole left subtree and continue on the right.
            remaining -= leftWeight;
            current = current.getRightChild();
        }
    }

    return current.getElement();
}
/**
 * {@inheritDoc }
 *
 * The element is also dropped from the element-to-node map, so that
 * {@code contains} reports it as absent afterwards and it can be added again.
 */
@Override
public boolean removeElement(final E element) {
    // Map.remove returns the previous mapping or null. The map never stores
    // null nodes, so null means the element is not present.
    final Node<E> node = this.map.remove(element);

    if (node == null) {
        return false;
    }

    delete(node);

    // 'delete' leaves the parent links of the detached nodes in place, so the
    // walk still starts at the leaf's former parent and goes up from there.
    updateMetadata(node.getParent(), -node.getWeight(), -1);
    this.size--;
    this.totalWeight -= node.getWeight();
    return true;
}
/**
 * {@inheritDoc }
 *
 * Besides dropping the tree, this also empties the element-to-node map;
 * otherwise removed elements would still be reported as present and could
 * not be added again.
 */
@Override
public void clear() {
    this.map.clear();
    this.root = null;
    this.size = 0;
    this.totalWeight = 0.0;
}
/**
 * Puts a new relay node in the place of {@code leafNodeToBypass} and makes
 * {@code leafNodeToBypass} its left child and {@code newNode} its right child.
 * The relay node starts out with the metadata of the bypassed node; the
 * contribution of {@code newNode} is then added to the relay node and to all
 * of its ancestors.
 *
 * @param leafNodeToBypass the leaf node to bypass.
 * @param newNode          the new node to add.
 */
private void bypassLeafNode(final Node<E> leafNodeToBypass,
                            final Node<E> newNode) {
    final Node<E> formerParent = leafNodeToBypass.getParent();
    final Node<E> relay = new Node<>();

    relay.setLeftChild(leafNodeToBypass);
    relay.setRightChild(newNode);
    relay.setWeight(leafNodeToBypass.getWeight());
    relay.setNumberOfLeaves(1);

    leafNodeToBypass.setParent(relay);
    newNode.setParent(relay);

    if (formerParent == null) {
        // The bypassed node was the root.
        this.root = relay;
    } else {
        relay.setParent(formerParent);

        if (formerParent.getLeftChild() == leafNodeToBypass) {
            formerParent.setLeftChild(relay);
        } else {
            formerParent.setRightChild(relay);
        }
    }

    updateMetadata(relay, newNode.getWeight(), 1);
}
/**
 * Adds {@code node} to the tree. An empty tree simply takes {@code node} as
 * its root. Otherwise the tree is descended from the root, at each relay node
 * choosing the left child if it has strictly fewer leaves than the right
 * child and the right child otherwise, until a non-relay node is reached;
 * that node is then bypassed so that it and {@code node} become siblings.
 *
 * @param node the node to insert.
 */
private void insert(final Node<E> node) {
    if (root == null) {
        root = node;
        return;
    }

    Node<E> cursor = root;

    while (cursor.isRelayNode()) {
        final Node<E> left = cursor.getLeftChild();
        final Node<E> right = cursor.getRightChild();
        cursor = left.getNumberOfLeaves() < right.getNumberOfLeaves()
                ? left
                : right;
    }

    bypassLeafNode(cursor, node);
}
/**
 * Unlinks {@code leafToDelete} from the tree. If the node has no parent, the
 * tree becomes empty. Otherwise the node's parent relay is dropped as well
 * and the node's sibling takes the relay's place: it becomes the root when
 * the relay had no parent, or the corresponding child of the relay's parent
 * otherwise. The parent link of {@code leafToDelete} itself is left
 * untouched, and no weights or leaf counts are adjusted here.
 *
 * @param leafToDelete the node to unlink.
 */
private void delete(final Node<E> leafToDelete) {
    final Node<E> relay = leafToDelete.getParent();

    if (relay == null) {
        this.root = null;
        return;
    }

    final Node<E> grandparent = relay.getParent();
    final Node<E> sibling;

    if (relay.getLeftChild() == leafToDelete) {
        sibling = relay.getRightChild();
    } else {
        sibling = relay.getLeftChild();
    }

    if (grandparent == null) {
        this.root = sibling;
    } else if (grandparent.getLeftChild() == relay) {
        grandparent.setLeftChild(sibling);
    } else {
        grandparent.setRightChild(sibling);
    }

    // grandparent is null exactly when the sibling has just become the root.
    sibling.setParent(grandparent);
}
/**
 * Adjusts the bookkeeping along a path towards the root: starting at
 * {@code node} and following parent links until there is no parent, adds
 * {@code nodeDelta} to each node's leaf count and {@code weightDelta} to each
 * node's weight. Passing {@code null} does nothing.
 *
 * @param node        the first node to update; may be {@code null}.
 * @param weightDelta the amount to add to the weight of every visited node.
 * @param nodeDelta   the amount to add to the leaf count of every visited
 *                    node.
 */
private void updateMetadata(Node<E> node,
                            final double weightDelta,
                            final int nodeDelta) {
    for (Node<E> current = node;
            current != null;
            current = current.getParent()) {
        current.setNumberOfLeaves(current.getNumberOfLeaves() + nodeDelta);
        current.setWeight(current.getWeight() + weightDelta);
    }
}
/**
 * Renders the tree level by level for debugging. An empty tree yields
 * {@code "empty"}. Otherwise one text line is produced per level, from the
 * root down to the deepest level; every slot of a level is printed via the
 * node's {@code toString()} (or as {@code "null"} for a missing node),
 * followed by a single space.
 *
 * @return the textual dump of the tree.
 */
public String debugToString() {
    if (root == null) {
        return "empty";
    }

    final int levels = getTreeHeight(root) + 1;
    final StringBuilder out = new StringBuilder();
    // The queue must hold null placeholders for missing nodes, which
    // LinkedList permits.
    final Deque<Node<E>> pending = new LinkedList<>();
    pending.addLast(root);

    for (int level = 0; level < levels; ++level) {
        // Only the nodes queued before this level starts belong to it.
        for (int remaining = pending.size(); remaining > 0; --remaining) {
            final Node<E> current = pending.removeFirst();
            addChildren(current, pending);
            out.append(String.valueOf(current)).append(" ");
        }

        out.append("\n");
    }

    return out.toString();
}
/**
 * Appends the two child slots of {@code node} to the end of {@code queue},
 * left child first. For a {@code null} node two {@code null} placeholders
 * are appended instead.
 *
 * @param node  the node whose children to enqueue; may be {@code null}.
 * @param queue the queue receiving the two entries.
 */
private void addChildren(final Node<E> node, final Deque<Node<E>> queue) {
    final Node<E> left = (node == null) ? null : node.getLeftChild();
    final Node<E> right = (node == null) ? null : node.getRightChild();
    queue.addLast(left);
    queue.addLast(right);
}
/**
 * Computes the height of the subtree rooted at {@code node} recursively.
 *
 * @param node the root of the subtree; may be {@code null}.
 * @return {@code -1} for a {@code null} node, otherwise one more than the
 *         larger of the heights of the two child subtrees (so a node without
 *         children has height {@code 0}).
 */
private int getTreeHeight(final Node<E> node) {
    if (node == null) {
        return -1;
    }

    final int leftHeight = getTreeHeight(node.getLeftChild());
    final int rightHeight = getTreeHeight(node.getRightChild());
    return 1 + Math.max(leftHeight, rightHeight);
}
}
Finally, the demo (which I don't want reviewed) is as follows:
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import net.coderodde.stat.AbstractProbabilityDistribution;
import net.coderodde.stat.support.ArrayProbabilityDistribution;
import net.coderodde.stat.support.BinaryTreeProbabilityDistribution;
import net.coderodde.stat.support.LinkedListProbabilityDistribution;
/**
 * Demonstration and rough benchmark of the probability distribution
 * implementations. It first prints a small sampling demo for the binary tree
 * implementation, then runs a warm-up phase over all three implementations,
 * and finally times adding, sampling and removing
 * {@link #DISTRIBUTION_SIZE} elements for each implementation.
 */
public class Demo {

    /**
     * The number of elements added to, sampled from and removed from each
     * profiled distribution.
     */
    private static final int DISTRIBUTION_SIZE = 20_000;

    /**
     * Runs the demo, the warm-up and the three profiling passes.
     *
     * @param args ignored.
     */
    public static void main(final String[] args) {
        System.out.println("[DEMO] BinaryTreeProbabilityDistribution:");
        binaryTreeProbabilityDistributionDemo();

        System.out.println("[STATUS] Warming up...");
        warmup();
        System.out.println("[STATUS] Warming up done!");
        System.out.println();

        AbstractProbabilityDistribution<Integer> arraypd =
                new ArrayProbabilityDistribution<>();
        AbstractProbabilityDistribution<Integer> listpd =
                new LinkedListProbabilityDistribution<>();
        AbstractProbabilityDistribution<Integer> treepd =
                new BinaryTreeProbabilityDistribution<>();

        profile(arraypd);
        profile(listpd);
        profile(treepd);
    }

    /**
     * Samples 100 times from a four-element tree distribution with weights
     * 1, 1, 1 and 3, printing every sampled element and finally how often
     * each element was drawn.
     */
    private static void binaryTreeProbabilityDistributionDemo() {
        BinaryTreeProbabilityDistribution<Integer> pd =
                new BinaryTreeProbabilityDistribution<>();

        pd.addElement(0, 1.0);
        pd.addElement(1, 1.0);
        pd.addElement(2, 1.0);
        pd.addElement(3, 3.0);

        int[] counts = new int[4];

        for (int i = 0; i < 100; ++i) {
            Integer myint = pd.sampleElement();
            counts[myint]++;
            System.out.println(myint);
        }

        System.out.println(Arrays.toString(counts));
    }

    /**
     * Times three phases on {@code pd} and prints the duration of each phase
     * and their sum in milliseconds: adding {@link #DISTRIBUTION_SIZE}
     * elements with random weights, sampling the same number of times, and
     * removing all elements in shuffled order.
     *
     * @param pd the distribution to profile.
     */
    private static void
            profile(final AbstractProbabilityDistribution<Integer> pd) {
        final Random random = new Random();

        System.out.println("[" + pd.getClass().getSimpleName() + "]:");

        long totalDuration = 0L;

        long startTime = System.currentTimeMillis();

        for (int i = 0; i < DISTRIBUTION_SIZE; ++i) {
            pd.addElement(i, 10.0 * random.nextDouble());
        }

        long endTime = System.currentTimeMillis();

        System.out.println("addElement() in " + (endTime - startTime) +
                           " milliseconds.");

        totalDuration += (endTime - startTime);

        startTime = System.currentTimeMillis();

        for (int i = 0; i < DISTRIBUTION_SIZE; ++i) {
            pd.sampleElement();
        }

        endTime = System.currentTimeMillis();

        System.out.println("sampleElement() in " + (endTime - startTime) +
                           " milliseconds.");

        totalDuration += (endTime - startTime);

        // The removal order is prepared outside the timed section.
        final List<Integer> contents = new ArrayList<>(DISTRIBUTION_SIZE);

        for (int i = 0; i < DISTRIBUTION_SIZE; ++i) {
            contents.add(i);
        }

        shuffle(contents);

        startTime = System.currentTimeMillis();

        for (Integer i : contents) {
            pd.removeElement(i);
        }

        endTime = System.currentTimeMillis();

        System.out.println("removeElement() in " + (endTime - startTime) +
                           " milliseconds.");

        totalDuration += (endTime - startTime);

        System.out.println("Total duration: " + totalDuration +
                           " milliseconds.");
        System.out.println();
    }

    /**
     * Shuffles {@code list} in place with the Fisher-Yates algorithm:
     * position {@code i} is swapped with a position drawn uniformly from
     * {@code [i, size)}, which makes every permutation equally likely.
     *
     * @param list the list to shuffle.
     */
    private static void shuffle(final List<Integer> list) {
        final Random random = new Random();
        final int size = list.size();

        // The last position has nothing left to be swapped with.
        for (int i = 0; i < size - 1; ++i) {
            final int index = i + random.nextInt(size - i);
            final Integer integer = list.get(index);
            list.set(index, list.get(i));
            list.set(i, integer);
        }
    }

    /**
     * Exercises all three implementations with the same pseudo-random
     * sequence of 100,000 operations: roughly 30 percent additions, 20 percent
     * removals and 50 percent samplings (the latter two only while the
     * distributions are non-empty).
     */
    private static void warmup() {
        // Fixed seed so that a warm-up run can be reproduced.
        final long seed = 35214717058750L;

        final Random inputRandom1 = new Random(seed);
        final Random inputRandom2 = new Random(seed);
        final Random inputRandom3 = new Random(seed);

        final AbstractProbabilityDistribution<Integer> pd1 =
                new ArrayProbabilityDistribution<>(inputRandom1);
        final AbstractProbabilityDistribution<Integer> pd2 =
                new LinkedListProbabilityDistribution<>(inputRandom2);
        final AbstractProbabilityDistribution<Integer> pd3 =
                new BinaryTreeProbabilityDistribution<>(inputRandom3);

        final Random random = new Random(seed);
        final List<Integer> content = new ArrayList<>();

        System.out.println("Seed = " + seed);

        for (int iteration = 0; iteration < 100_000; ++iteration) {
            final double coin = random.nextDouble();

            if (coin < 0.3) {
                // Add a new element.
                final Integer element = random.nextInt();
                final double weight = 30.0 * random.nextDouble();

                content.add(element);
                pd1.addElement(element, weight);
                pd2.addElement(element, weight);
                pd3.addElement(element, weight);
            } else if (coin < 0.5) {
                // Remove an element.
                if (!pd1.isEmpty()) {
                    final Integer element = choose(content, random);

                    pd1.removeElement(element);
                    pd2.removeElement(element);
                    pd3.removeElement(element);
                    // 'element' is an Integer, so this is remove(Object),
                    // not remove(int index).
                    content.remove(element);
                }
            } else if (!pd1.isEmpty()) {
                // Sample elements:
                pd1.sampleElement();
                pd2.sampleElement();
                pd3.sampleElement();
            }
        }
    }

    /**
     * Returns an element of {@code list} at an index drawn uniformly at
     * random.
     *
     * @param list   the non-empty list to choose from.
     * @param random the source of randomness.
     * @return the chosen element.
     */
    private static Integer choose(final List<Integer> list,
                                  final Random random) {
        return list.get(random.nextInt(list.size()));
    }
}
The performance figures are as follows:
[STATUS] Warming up...
Seed = 35214717058750
[STATUS] Warming up done!

[ArrayProbabilityDistribution]:
addElement() in 8 milliseconds.
sampleElement() in 321 milliseconds.
removeElement() in 500 milliseconds.
Total duration: 829 milliseconds.

[LinkedListProbabilityDistribution]:
addElement() in 7 milliseconds.
sampleElement() in 1184 milliseconds.
removeElement() in 9 milliseconds.
Total duration: 1200 milliseconds.

[BinaryTreeProbabilityDistribution]:
addElement() in 24 milliseconds.
sampleElement() in 15 milliseconds.
removeElement() in 16 milliseconds.
Total duration: 55 milliseconds.
Critique request
I would like to hear comments regarding the following:
- API design,
- naming conventions,
- coding conventions,
- performance.