I wanted a data structure that allowed me to set the chances of randomly returning each of its elements.
For example, suppose I have a Human class. Every Human has an attribute called eyeColor. I don't know what the actual percentages are, but let's say 60% of people have brown eyes, 30% have blue eyes, and 10% have green eyes.
Using this class I set the chances, out of 100, of returning any given eye color.
To do this, I use a TreeMap and choose a random double between 0 and 100 (inclusive). Then I return the value using the TreeMap's ceilingEntry method, unless that would return null. In that case I return the value from the floorEntry method.
The restriction is that the sum of all the chances must equal 100, or nearly so, to return anything.
How can I make this data structure run faster, and make the code more elegant?
My Data Structure:
import java.util.TreeMap;
import java.util.concurrent.ThreadLocalRandom;
/**
* Objects of this class can have object or primitive types added to them and
* retrieved randomly with a different chance for different objects.
* <p>
* This is done by having the user include a "percentage chance" when adding new
* elements.
* <p>
* The percentage chance represents the percentage, out of 100, that the added
* element will be returned when the getRandomElement() method is called.
* <p>
* The sum of all the percentage chances should never be greater than 100 or the
* program will throw an IllegalArgumentException.
*/
public final class RandomTree<T> {
// holds the objects and primitives to be randomly returned
private TreeMap<Double, T> tree;
// keeps track of whether the RandomTree is full
private double sum;
public RandomTree() {
// contains the values to be randomly returned with getRandomElement()
this.tree = new TreeMap<>();
// keeps track of the sum of the percentages
this.sum = 0.0;
}
/**
* Adds a new object or primitive to the RandomTree. The percentChance
* argument represents the chances, out of 100, that the element will be
* returned when the getRandomElement() method is called.
* <p>
* The sum of all the percentage chances including the given percentChance
* argument must be less than or equal to 100. Also, the percentChance
* argument must be greater than zero. Otherwise, the program throws
* an IllegalArgumentException.
*
* @param object The object or primitive to add.
* @param percentChance The chance of returning the object argument.
*/
public void add(final T object, final double percentChance) {
this.sum += percentChance;
// do not allow negative percent chances
if (percentChance <= 0.0) {
throw new IllegalArgumentException("percentChance must be > 0.0");
// prevent unnecessary exception throwing over being slightly more than 100.0
} else if (Math.abs(this.sum - 100.0) < 0.1 && this.sum != 100.0) {
this.sum = 100.0;
this.tree.put(this.sum, object);
// do not allow values to be above 100.0
} else if (this.sum > 100.0) {
throw new IllegalArgumentException(this.sum + " is > 100.");
// prevent unnecessary exception throwing over being slightly less than 100.0
} else if (100.0 - this.sum < 0.1 && this.sum != 100.0) {
this.sum = 100.0;
this.tree.put(this.sum, object);
} else // add the key and value to this.tree
this.tree.put(this.sum, object);
}
/**
* Returns a random element from this.tree.
* Elements with higher associated percentage-chance values
* are more likely to be returned.
* <p>
* Throws an IllegalArgumentException if this.sum is not equal to 100.
*/
public T getRandomElement() {
// don't allow retrieval before this.sum == 100.0
if (this.sum != 100.0)
throw new IllegalArgumentException("sum == " + this.sum);
double choice = ThreadLocalRandom.current().nextDouble(Math.nextUp(100.0));
final T obj = this.tree.ceilingEntry(choice).getValue();
if (obj != null)
return obj;
else
return this.tree.floorEntry(choice).getValue();
}
/**
* Returns the sum of the percentage chances added to this RandomTree.
*
* @return The sum of the percentage chances added to this RandomTree.
*/
public double getSum() {
return this.sum;
}
}
Test Class:
public class RandomTreeTest {
public static void main(String[] args) {
RandomTree<String> randTree = new RandomTree<>();
randTree.add("ten", 10.0);
randTree.add("twenty", 20.0);
randTree.add("thirty", 30.0);
randTree.add("forty", 40.0);
int countTens = 0;
int countTwenties = 0;
int countThirties = 0;
int countForties = 0;
final double iterationNumber = 1000.0;
for(int i = 0; i < iterationNumber; i++) {
String num = randTree.getRandomElement();
switch(num) {
case "ten":
countTens++;
break;
case "twenty":
countTwenties++;
break;
case "thirty":
countThirties++;
break;
case "forty":
countForties++;
break;
}
}
double percentOfTens = (countTens / iterationNumber) * 100;
double percentOfTwenties = (countTwenties / iterationNumber) * 100;
double percentOfThirties = (countThirties / iterationNumber) * 100;
double percentOfForties = (countForties / iterationNumber) * 100;
String msg = "tens: " + percentOfTens + "%" + System.lineSeparator();
msg += "twenties: " + percentOfTwenties + "%" + System.lineSeparator();
msg += "thirties: " + percentOfThirties + "%" + System.lineSeparator();
msg += "forties: " + percentOfForties + "%" + System.lineSeparator();
System.out.println(msg);
}
}
2 Answers
Validate using Builder pattern
Building a RandomTree follows a very canonical pattern: you add data, validate, and then start using the tree. However, your object-oriented design fails to reflect this: a valid or invalid RandomTree has the same type. I recommend using the Builder pattern to remedy this: instantiate a RandomTree.Builder, add entries to the builder, and then call a build() method to validate and return a RandomTree. If done properly, this guarantees that all RandomTree objects are valid. It also leads to a clean separation of validation and sampling code.
Simplify approximate double equality
If you use Math.abs properly, you should be able to deal with the cases where sum is slightly too high and slightly too low at the same time.
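As a rough sketch of that idea, a single Math.abs comparison can cover both the slightly-high and slightly-low branches. The class name ToleranceCheck, the helper isCloseTo100 and the TOLERANCE constant are made up for illustration; the 0.1 value is simply taken from the question's add() method.

```java
final class ToleranceCheck {
    // Assumed tolerance; the question's add() accepts sums within 0.1 of 100.
    static final double TOLERANCE = 0.1;

    // True when sum is within TOLERANCE of 100.0, whether above or below.
    public static boolean isCloseTo100(final double sum) {
        return Math.abs(sum - 100.0) < TOLERANCE;
    }

    public static void main(String[] args) {
        System.out.println(isCloseTo100(99.95));  // slightly low
        System.out.println(isCloseTo100(100.05)); // slightly high
        System.out.println(isCloseTo100(90.0));   // too far off
    }
}
```

With a check like this, add() would need only one "close enough, snap to 100.0" branch instead of two.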
Use simpler data structure
Using your tree, each call to getRandomElement takes time O(log(n)); this is already very fast. However, building the tree takes time O(n log(n)).
We can do better using a simple array: store the same "cumulative probability" values you currently have in your tree in an array. Then binary search the array to find the ceiling entry. This too takes time O(log(n)) to search, but only takes O(n) to build. In practice, searching an array should also be a bit faster than searching a tree, but you'd have to do some profiling to confirm this.
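To make the array idea concrete, here is a minimal sketch of the O(n) build and the binary-search lookup. The class and method names (CumulativeSearch, cumulate, indexFor) are hypothetical. It assumes the cumulative values are strictly increasing and that choice is below the last one. Note that Arrays.binarySearch returns the matching index on an exact hit and -(insertionPoint) - 1 otherwise, so both cases need handling.

```java
import java.util.Arrays;

final class CumulativeSearch {
    // Builds the running-sum array in a single O(n) pass.
    public static double[] cumulate(final double[] probabilities) {
        final double[] cumulative = new double[probabilities.length];
        double sum = 0.0;
        for (int i = 0; i < probabilities.length; i++) {
            sum += probabilities[i];
            cumulative[i] = sum;
        }
        return cumulative;
    }

    // Returns the index of the first cumulative value strictly greater than choice.
    public static int indexFor(final double[] cumulative, final double choice) {
        final int result = Arrays.binarySearch(cumulative, choice);
        // Exact hit at index k: choice falls in the next bucket, k + 1.
        // Miss: undo the -(insertionPoint) - 1 encoding.
        return result >= 0 ? result + 1 : -(result + 1);
    }

    public static void main(String[] args) {
        // Values chosen to be exactly representable as doubles.
        final double[] cumulative = cumulate(new double[] {0.25, 0.25, 0.5});
        System.out.println(indexFor(cumulative, 0.1));  // prints 0
        System.out.println(indexFor(cumulative, 0.25)); // prints 1 (exact hit on the first boundary)
        System.out.println(indexFor(cumulative, 0.3));  // prints 1
        System.out.println(indexFor(cumulative, 0.75)); // prints 2
    }
}
```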
Randomness
Percents are kind of arbitrary; a much more natural way to represent probabilities is with numbers in [0,1]. You can always divide the input by 100 if need be.
In addition, you should let the caller pass in an instance of Random to getRandomElement. Since your class doesn't care how the randomness arises, this allows more flexibility.
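One benefit of taking the generator as a parameter is reproducibility: a test can pass a seeded java.util.Random, while production code can still pass ThreadLocalRandom.current(), which is a subclass of Random. The sketch below is illustrative only; InjectedRandomDemo, flip and flips are hypothetical names, and the 50/50 coin stands in for a real sampler.

```java
import java.util.Random;

final class InjectedRandomDemo {
    // Picks "heads" or "tails" with equal probability using the caller's generator.
    public static String flip(final Random rand) {
        return rand.nextDouble() < 0.5 ? "heads" : "tails";
    }

    // Records n flips (as 'h'/'t') from a generator created with the given seed.
    public static String flips(final long seed, final int n) {
        final Random rand = new Random(seed);
        final StringBuilder sb = new StringBuilder();
        for (int i = 0; i < n; i++) {
            sb.append(flip(rand).charAt(0));
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // Two generators with the same seed yield the same sequence of picks.
        System.out.println(flips(42L, 10).equals(flips(42L, 10)));
    }
}
```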
Changes
Here is my stab at the changes (comments omitted)
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
public class RandomSampler<T> {
private static final double PRECISION = 0.001;
public static class Builder<T> {
private List<T> items;
private List<Double> probabilities;
private Builder() {
this.items = new ArrayList<T>();
this.probabilities = new ArrayList<Double>();
}
public void add(final T item, final double probability) {
this.items.add(item);
this.probabilities.add(probability);
}
public RandomSampler<T> build() {
return new RandomSampler<T>(items, probabilities);
}
}
public static <T> Builder<T> builder() {
return new Builder<T>();
}
private final List<T> items;
private final double[] cumulativeProbabilities;
private RandomSampler(final List<T> items, final List<Double> probabilities) {
double cumulativeProbability = 0.0;
this.items = items;
this.cumulativeProbabilities = new double[items.size()];
Iterator<Double> it = probabilities.iterator();
for (int i = 0; i < items.size(); i++) {
cumulativeProbability += it.next();
this.cumulativeProbabilities[i] = cumulativeProbability;
}
if (Math.abs(cumulativeProbability - 1.0) > PRECISION) {
throw new IllegalStateException("probabilities do not sum to 1.0");
} else {
// fix last cumulative probability to 1.0
this.cumulativeProbabilities[items.size() - 1] = 1.0;
}
}
public T getRandomElement(Random rand) {
double choice = rand.nextDouble();
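// binarySearch returns the index on an exact match, or -(insertionPoint) - 1 otherwise;
// in both cases we want the first element whose cumulative probability is > choice
int searchResult = Arrays.binarySearch(this.cumulativeProbabilities, choice);
int i = searchResult >= 0 ? searchResult + 1 : -(searchResult + 1);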
return this.items.get(i);
}
}
import java.util.concurrent.ThreadLocalRandom;
public class RandomSamplerTest {
public static void main(String[] args) {
RandomSampler.Builder<String> builder = RandomSampler.builder();
builder.add("ten", 0.1);
builder.add("twenty", 0.2);
builder.add("thirty", 0.3);
builder.add("forty", 0.4);
RandomSampler<String> sampler = builder.build();
ThreadLocalRandom rand = ThreadLocalRandom.current();
int countTens = 0;
int countTwenties = 0;
int countThirties = 0;
int countForties = 0;
final double iterationNumber = 1000;
for(int i = 0; i < iterationNumber; i++) {
String num = sampler.getRandomElement(rand);
switch(num) {
case "ten":
countTens++;
break;
case "twenty":
countTwenties++;
break;
case "thirty":
countThirties++;
break;
case "forty":
countForties++;
break;
}
}
double percentOfTens = (countTens / iterationNumber) * 100;
double percentOfTwenties = (countTwenties / iterationNumber) * 100;
double percentOfThirties = (countThirties / iterationNumber) * 100;
double percentOfForties = (countForties / iterationNumber) * 100;
String msg = "tens: " + percentOfTens + "%" + System.lineSeparator();
msg += "twenties: " + percentOfTwenties + "%" + System.lineSeparator();
msg += "thirties: " + percentOfThirties + "%" + System.lineSeparator();
msg += "forties: " + percentOfForties + "%" + System.lineSeparator();
System.out.println(msg);
}
}
How much storage space and accuracy are you willing to trade for O(1)?
Initialize a 100 element array that contains all samples (60 brown eyes, 30 blue eyes, and 10 green eyes) and pick an index randomly.
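A minimal sketch of that lookup-table idea, assuming whole-number percentages that must total exactly 100. The class name LookupTableSampler and its methods are hypothetical, not taken from the question.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

final class LookupTableSampler<T> {
    // One slot per percentage point; full once it holds 100 entries.
    private final List<T> table = new ArrayList<>(100);

    // Adds `percent` copies of item, so it fills that share of the table.
    public void add(final T item, final int percent) {
        if (percent <= 0 || table.size() + percent > 100) {
            throw new IllegalArgumentException("percent must be > 0 and keep the total <= 100");
        }
        table.addAll(Collections.nCopies(percent, item));
    }

    // O(1) sampling: pick a uniformly random slot.
    public T getRandomElement(final Random rand) {
        if (table.size() != 100) {
            throw new IllegalStateException("percentages sum to " + table.size() + ", not 100");
        }
        return table.get(rand.nextInt(100));
    }

    public static void main(String[] args) {
        final LookupTableSampler<String> eyes = new LookupTableSampler<>();
        eyes.add("brown", 60);
        eyes.add("blue", 30);
        eyes.add("green", 10);
        System.out.println(eyes.getRandomElement(new Random()));
    }
}
```

The cost is 100 references of storage and a resolution of one percentage point; finer resolution needs a proportionally larger table.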
- This idea works quite well for integer (or even rational valued) percentages; but not for arbitrary doubles. – Benjamin Kuykendall, Aug 27, 2019 at 6:07
- That's right. The correct balance of accuracy, performance and storage must be selected to match the requirements. – TorbenPutkonen, Aug 27, 2019 at 6:13
RandomTree(int[] percentages, T[] items)? You can then enforce sum(percentages) == 100 much more effectively. From an API design perspective, it's problematic to support an add function where you have specific input restrictions across multiple calls. It's easy for clients to screw up the input.