Intro
(See MostProbablePath.java.)
This time, I build on Computing most probable (reliable) path in a probabilistic graph (take II): instead of computing only the single most reliable path, I now return the \$k\$ most reliable paths for any positive integer \$k\$.
Code
io.github.coderodde.prob.KmostProbablePathsFinder.java:
package io.github.coderodde.prob;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Queue;
/**
 * Finds the {@code k} most probable (most reliable) simple paths between two
 * nodes of a probabilistic graph.
 * <p>
 * The probability of a path is the product of its edge probabilities. Each
 * path carries the cost {@code -ln(probability)}, so maximizing the product
 * becomes minimizing a sum of costs. A best-first search over partial paths,
 * ordered by that cost, then pops complete source-to-target paths in
 * non-increasing probability order. This holds as long as every edge
 * probability lies in {@code [0, 1]}, i.e. every edge cost is non-negative.
 * This class does not check that assumption.
 * <p>
 * Only simple (loop-free) paths are reported. Apart from the simple-path check
 * the search does no pruning. In the worst case it therefore enumerates every
 * simple path that starts at the source, which can be exponential in the size
 * of the graph. That happens, for example, when the target is unreachable or
 * has fewer than {@code k} paths.
 *
 * @author Rodion "rodde" Efremov
 * @version 1.0.0 (Sep 17, 2025)
 * @since 1.0.0 (Sep 17, 2025)
 */
public final class KmostProbablePathsFinder {

    /**
     * Returns up to {@code k} most probable simple paths from {@code source}
     * to {@code target}, most probable first.
     *
     * @param source the start node.
     * @param target the goal node.
     * @param k      the maximum number of paths to return; must be at least 1.
     * @return at most {@code k} results sorted by non-increasing probability;
     *         empty if no path reaches {@code target}. If
     *         {@code source.equals(target)}, the single-node path with
     *         probability 1 is the only result.
     * @throws NullPointerException     if {@code source} or {@code target} is
     *                                  {@code null}.
     * @throws IllegalArgumentException if {@code k < 1}.
     */
    public List<Result> findKmostProbablePaths(GraphNode source,
                                               GraphNode target,
                                               int k) {
        Objects.requireNonNull(source, "The source node is null");
        Objects.requireNonNull(target, "The target node is null");
        checkK(k);

        // Deliberately not presized with k. A caller may pass a huge k (e.g.
        // Integer.MAX_VALUE) to mean "all paths", and new ArrayList<>(k)
        // would then try to allocate a k-element backing array up front.
        List<LinkedPathNode> paths = new ArrayList<>();
        Queue<LinkedPathNode> openQueue = new PriorityQueue<>();
        openQueue.add(new LinkedPathNode(source));

        // paths.size() is exactly the number of times target has been popped.
        // Per-node pop counts for other nodes are not needed: with the
        // simple-path restriction they cannot be used for pruning anyway.
        while (!openQueue.isEmpty() && paths.size() < k) {
            LinkedPathNode currentPath = openQueue.remove();
            GraphNode endNode = currentPath.getTailNode();

            if (endNode.equals(target)) {
                // A complete path. It is never extended, so every reported
                // path ends at its first (and only) visit to target.
                paths.add(currentPath);
                continue;
            }

            for (GraphNode child : endNode.getNeighbours()) {
                // Skip nodes already on the path: only simple paths are kept.
                if (!currentPath.contains(child)) {
                    openQueue.add(currentPath.append(child));
                }
            }
        }

        return buildPathResults(paths);
    }

    /**
     * Validates the requested number of paths.
     *
     * @param k the requested number of paths.
     * @throws IllegalArgumentException if {@code k < 1}.
     */
    private static void checkK(int k) {
        if (k < 1) {
            throw new IllegalArgumentException(String.format("k(%d) < 1", k));
        }
    }

    /**
     * Converts the found linked paths into {@link Result}s, most probable
     * first.
     *
     * @param paths the complete source-to-target paths.
     * @return one result per path, sorted by non-increasing probability.
     */
    private static List<Result> buildPathResults(List<LinkedPathNode> paths) {
        List<Result> results = new ArrayList<>(paths.size());

        for (LinkedPathNode path : paths) {
            // Undo the -ln transform: probability = exp(-cost).
            results.add(new Result(path.toPath(), Math.exp(-path.getCost())));
        }

        // With non-negative edge costs the paths already arrive in this
        // order. The stable sort is a cheap safeguard that keeps the
        // documented ordering even if that assumption is violated.
        Collections.sort(results,
                         (a, b) -> Double.compare(b.getProbability(),
                                                  a.getProbability()));
        return results;
    }
}
io.github.coderodde.prob.LinkedPathNode.java:
package io.github.coderodde.prob;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
 * An immutable, reverse-linked path through the graph.
 * <p>
 * Each instance holds the path's last node ({@link #getTailNode()}) and a
 * reference to the path it extends. {@link #append(GraphNode)} therefore
 * shares the entire prefix instead of copying it, which matters because the
 * search may keep very many partial paths in its queue at once.
 * <p>
 * The cost of a path is the sum of {@code -ln(p)} over the probabilities
 * {@code p} of its edges, so a lower cost means a more probable path. Paths
 * are ordered by cost, with ties broken by the number of edges (fewer first).
 * This ordering is not consistent with {@code equals}, which is inherited
 * from {@code Object}.
 *
 * @author Rodion "rodde" Efremov
 * @version 1.0.0 (Sep 17, 2025)
 * @since 1.0.0 (Sep 17, 2025)
 */
final class LinkedPathNode implements Comparable<LinkedPathNode> {

    /** The last node of this path. */
    private final GraphNode node;

    /** The path this one extends, or {@code null} for a single-node path. */
    private final LinkedPathNode previousLinkedNodePath;

    /** Sum of {@code -ln(p)} over this path's edge probabilities. */
    private final double cost;

    /** Number of edges on this path; used only to break cost ties. */
    private final int edgeCount;

    /**
     * Creates the path consisting of {@code node} alone, with cost 0.
     *
     * @param node the only node of the path.
     */
    LinkedPathNode(GraphNode node) {
        this.node = node;
        this.previousLinkedNodePath = null;
        this.cost = 0.0;
        this.edgeCount = 0;
    }

    /**
     * Creates the path {@code path} followed by {@code node}. The new edge's
     * probability is obtained from {@code path}'s tail node via
     * {@code getEdgeProbabilityTo(node)}.
     *
     * @param path the path to extend.
     * @param node the node to append.
     */
    LinkedPathNode(LinkedPathNode path, GraphNode node) {
        this.node = node;
        this.previousLinkedNodePath = path;
        double p = path.node.getEdgeProbabilityTo(node);
        // -ln(p) turns a product of probabilities into a sum of costs.
        // NOTE(review): p == 0 yields +Infinity, and p > 1 would yield a
        // negative cost. This class does not validate p.
        this.cost = path.cost - Math.log(p);
        this.edgeCount = path.edgeCount + 1;
    }

    /**
     * Tells whether {@code node} occurs anywhere on this path.
     * <p>
     * Walks the predecessor chain. That takes time linear in the path length,
     * but no path needs its own copy of an on-path set.
     *
     * @param node the node to look for.
     * @return {@code true} if some node on this path equals {@code node}.
     */
    boolean contains(GraphNode node) {
        for (LinkedPathNode lpn = this;
                lpn != null;
                lpn = lpn.previousLinkedNodePath) {
            if (lpn.node == null ? node == null : lpn.node.equals(node)) {
                return true;
            }
        }

        return false;
    }

    /** @return the cost ({@code -ln(probability)}) of this path. */
    double getCost() {
        return cost;
    }

    /**
     * Returns a new path equal to this one extended by {@code node}; this
     * path is not modified.
     *
     * @param node the node to append.
     * @return the extended path.
     */
    LinkedPathNode append(GraphNode node) {
        return new LinkedPathNode(this, node);
    }

    /** @return the path this one extends, or {@code null} if none. */
    LinkedPathNode getPreviousLinkedPathNode() {
        return previousLinkedNodePath;
    }

    /** @return the last node of this path. */
    GraphNode getTailNode() {
        return node;
    }

    /**
     * Materializes this path as a list ordered from its first node to its
     * last.
     *
     * @return a new mutable list of the path's nodes.
     */
    List<GraphNode> toPath() {
        List<GraphNode> path = new ArrayList<>();

        for (LinkedPathNode lpn = this;
                lpn != null;
                lpn = lpn.getPreviousLinkedPathNode()) {
            path.add(lpn.node);
        }

        // The chain is walked tail-first, so reverse into source-first order.
        Collections.reverse(path);
        return path;
    }

    /**
     * Orders paths by cost (lower first, i.e. more probable first). Paths
     * with equal cost are ordered by edge count (fewer first).
     */
    @Override
    public int compareTo(LinkedPathNode o) {
        int byCost = Double.compare(cost, o.cost);
        return byCost != 0 ? byCost : Integer.compare(edgeCount, o.edgeCount);
    }
}
io.github.coderodde.prob.KmostProbablePathsFinderTest.java:
package io.github.coderodde.prob;
import java.util.List;
import org.junit.Test;
import static org.junit.Assert.*;
/**
 * Unit tests for {@link KmostProbablePathsFinder}.
 */
public class KmostProbablePathsFinderTest {

    /**
     * Tolerance for probability comparisons. Results are computed as
     * exp(-sum of -ln(p)), which differs from the direct product only by
     * rounding error.
     */
    private static final double DELTA = 1e-9;

    /**
     * Shared fixture: three source-to-target paths with probabilities
     * upper = 0.99 * 0.97 * 0.98, middle = 0.8 * 0.9 and
     * lower = 0.7 * 0.8.
     */
    private static final class ThreePathGraph {
        final GraphNode source = new GraphNode(0);
        final GraphNode target = new GraphNode(1);
        final GraphNode upper1 = new GraphNode(2);
        final GraphNode upper2 = new GraphNode(3);
        final GraphNode middle = new GraphNode(4);
        final GraphNode lower = new GraphNode(5);

        ThreePathGraph() {
            source.connectTo(upper1, 0.99);
            upper1.connectTo(upper2, 0.97);
            upper2.connectTo(target, 0.98);
            source.connectTo(middle, 0.8);
            middle.connectTo(target, 0.9);
            source.connectTo(lower, 0.7);
            lower.connectTo(target, 0.8);
        }

        List<Result> find(int k) {
            return new KmostProbablePathsFinder()
                    .findKmostProbablePaths(source, target, k);
        }
    }

    @Test
    public void findAllThreePathsWithK2() {
        // There are 3 paths; only the 2 most probable ones are returned:
        ThreePathGraph g = new ThreePathGraph();
        List<Result> results = g.find(2);

        assertEquals(2, results.size());
        assertResult(results.get(0), 0.99 * 0.97 * 0.98,
                     g.source, g.upper1, g.upper2, g.target);
        assertResult(results.get(1), 0.8 * 0.9,
                     g.source, g.middle, g.target);
    }

    @Test
    public void findAllThreePathsWithK3() {
        ThreePathGraph g = new ThreePathGraph();
        List<Result> results = g.find(3);

        assertEquals(3, results.size());
        assertResult(results.get(0), 0.99 * 0.97 * 0.98,
                     g.source, g.upper1, g.upper2, g.target);
        assertResult(results.get(1), 0.8 * 0.9,
                     g.source, g.middle, g.target);
        assertResult(results.get(2), 0.7 * 0.8,
                     g.source, g.lower, g.target);
    }

    @Test
    public void findAllThreePathsWithK4() {
        // Asking for more paths than exist returns all of them:
        ThreePathGraph g = new ThreePathGraph();
        List<Result> results = g.find(4);

        assertEquals(3, results.size());
        assertResult(results.get(0), 0.99 * 0.97 * 0.98,
                     g.source, g.upper1, g.upper2, g.target);
        assertResult(results.get(1), 0.8 * 0.9,
                     g.source, g.middle, g.target);
        assertResult(results.get(2), 0.7 * 0.8,
                     g.source, g.lower, g.target);
    }

    @Test
    public void overflowTargetNodeCount() {
        // Three equally probable paths; k = 2 must stop the search after two
        // of them. Which two is unspecified (PriorityQueue breaks ties
        // arbitrarily), so only the shape of each result is asserted.
        GraphNode source = new GraphNode(0);
        GraphNode target = new GraphNode(1);
        GraphNode upper = new GraphNode(2);
        GraphNode middle = new GraphNode(3);
        GraphNode lower = new GraphNode(4);

        source.connectTo(upper, 0.9);
        upper.connectTo(target, 0.9);
        source.connectTo(middle, 0.9);
        middle.connectTo(target, 0.9);
        source.connectTo(lower, 0.9);
        lower.connectTo(target, 0.9);

        List<Result> results =
                new KmostProbablePathsFinder()
                        .findKmostProbablePaths(source, target, 2);

        assertEquals(2, results.size());

        GraphNode via1 = assertTwoHopPath(results.get(0), source, target,
                                          0.9 * 0.9, upper, middle, lower);
        GraphNode via2 = assertTwoHopPath(results.get(1), source, target,
                                          0.9 * 0.9, upper, middle, lower);

        assertFalse("the two paths must differ", via1.equals(via2));
    }

    @Test
    public void sourceEqualsTargetYieldsTrivialPath() {
        GraphNode node = new GraphNode(0);
        GraphNode other = new GraphNode(1);
        node.connectTo(other, 0.5);

        List<Result> results =
                new KmostProbablePathsFinder()
                        .findKmostProbablePaths(node, node, 3);

        assertEquals(1, results.size());
        assertResult(results.get(0), 1.0, node);
    }

    @Test
    public void unreachableTargetYieldsNoPaths() {
        GraphNode source = new GraphNode(0);
        GraphNode middle = new GraphNode(1);
        GraphNode target = new GraphNode(2);
        source.connectTo(middle, 0.5);

        List<Result> results =
                new KmostProbablePathsFinder()
                        .findKmostProbablePaths(source, target, 2);

        assertTrue(results.isEmpty());
    }

    @Test
    public void cyclesAreNotFollowed() {
        GraphNode source = new GraphNode(0);
        GraphNode a = new GraphNode(1);
        GraphNode target = new GraphNode(2);
        source.connectTo(a, 0.9);
        a.connectTo(source, 0.9);
        a.connectTo(target, 0.5);

        List<Result> results =
                new KmostProbablePathsFinder()
                        .findKmostProbablePaths(source, target, 5);

        assertEquals(1, results.size());
        assertResult(results.get(0), 0.9 * 0.5, source, a, target);
    }

    @Test(expected = IllegalArgumentException.class)
    public void rejectsNonPositiveK() {
        GraphNode node = new GraphNode(0);
        new KmostProbablePathsFinder().findKmostProbablePaths(node, node, 0);
    }

    @Test(expected = NullPointerException.class)
    public void rejectsNullSource() {
        GraphNode node = new GraphNode(0);
        new KmostProbablePathsFinder().findKmostProbablePaths(null, node, 1);
    }

    /**
     * Asserts that {@code result} has exactly the nodes {@code expectedPath},
     * in order, and the given probability.
     */
    private static void assertResult(Result result,
                                     double expectedProbability,
                                     GraphNode... expectedPath) {
        List<GraphNode> path = result.getPath();
        assertEquals(expectedPath.length, path.size());

        for (int i = 0; i < expectedPath.length; i++) {
            assertEquals("node at index " + i, expectedPath[i], path.get(i));
        }

        assertEquals(expectedProbability, result.getProbability(), DELTA);
    }

    /**
     * Asserts that {@code result} is source -> x -> target, where x is one of
     * {@code allowedVia}, with the given probability.
     *
     * @return the intermediate node x.
     */
    private static GraphNode assertTwoHopPath(Result result,
                                              GraphNode source,
                                              GraphNode target,
                                              double expectedProbability,
                                              GraphNode... allowedVia) {
        List<GraphNode> path = result.getPath();
        assertEquals(3, path.size());
        assertEquals(source, path.get(0));
        assertEquals(target, path.get(2));
        assertTrue("unexpected intermediate node",
                   isOneOf(path.get(1), allowedVia));
        assertEquals(expectedProbability, result.getProbability(), DELTA);
        return path.get(1);
    }

    /** @return {@code true} if {@code node} equals one of the candidates. */
    private static boolean isOneOf(GraphNode node, GraphNode... candidates) {
        for (GraphNode candidate : candidates) {
            if (candidate.equals(node)) {
                return true;
            }
        }

        return false;
    }
}
Critique request
As always, I am eager to hear any constructive commentary on my work. Also, could you review my unit tests?