What
Beam search is a best-first search algorithm that does not necessarily find an optimal path, yet has a smaller memory footprint. In this program, I attempted to answer how it compares to A* and whether bidirectional beam search provides any improvement over the unidirectional variant when it comes to running time and the optimality of the resulting path.
Code
BeamSearchPathfinder.java
package net.coderodde.graph.pathfinding.beamsearch;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import net.coderodde.graph.AbstractGraph;
/**
 * A best-first pathfinder that expands nodes in order of {@code f = g + h}
 * (like A*), but at every expansion only admits the {@code beamWidth} most
 * promising successors of the expanded node, ranked by
 * {@code g(current) + w(current, child) + h(child, target)}.
 *
 * <p>With the default, unbounded beam width nothing is pruned and the search
 * is plain A*. With a narrower beam the returned path may be suboptimal, and
 * {@link PathNotFoundException} may be thrown even though a path exists.
 */
public final class BeamSearchPathfinder implements Pathfinder {

    /**
     * The default width of the beam ({@code Integer.MAX_VALUE}: no pruning).
     */
    private static final int DEFAULT_BEAM_WIDTH = Integer.MAX_VALUE;

    /**
     * The minimum allowed beam width.
     */
    private static final int MINIMUM_BEAM_WIDTH = 1;

    /**
     * The current beam width.
     */
    private int beamWidth = DEFAULT_BEAM_WIDTH;

    /**
     * Returns the maximum number of successors admitted per expansion.
     *
     * @return the current beam width.
     */
    public int getBeamWidth() {
        return beamWidth;
    }

    /**
     * Sets the maximum number of successors admitted per expansion. Values
     * below 1 are raised to 1.
     *
     * @param beamWidth the requested beam width.
     */
    public void setBeamWidth(int beamWidth) {
        this.beamWidth = Math.max(beamWidth, MINIMUM_BEAM_WIDTH);
    }

    /**
     * Searches for a path from {@code sourceNode} to {@code targetNode}.
     *
     * @param graph             the graph to search in.
     * @param sourceNode        the start node.
     * @param targetNode        the goal node.
     * @param heuristicFunction estimates the remaining distance to the target.
     * @return the node sequence from {@code sourceNode} to {@code targetNode},
     *         both inclusive.
     * @throws NullPointerException     if any argument is {@code null}.
     * @throws IllegalArgumentException if either node is not in the graph.
     * @throws PathNotFoundException    if the search exhausts its frontier.
     */
    @Override
    public List<Integer> search(AbstractGraph graph,
                                Integer sourceNode,
                                Integer targetNode,
                                HeuristicFunction<Integer> heuristicFunction) {
        Objects.requireNonNull(graph, "The input graph is null.");
        Objects.requireNonNull(sourceNode, "The source node is null.");
        Objects.requireNonNull(targetNode, "The target node is null.");
        Objects.requireNonNull(heuristicFunction,
                               "The heuristic function is null.");
        checkNodes(graph, sourceNode, targetNode);

        Queue<HeapNode> open = new PriorityQueue<>();
        Set<Integer> closed = new HashSet<>();
        Map<Integer, Integer> parents = new HashMap<>();
        Map<Integer, Double> distances = new HashMap<>();

        open.add(new HeapNode(sourceNode, 0.0));
        parents.put(sourceNode, null);
        distances.put(sourceNode, 0.0);

        while (!open.isEmpty()) {
            Integer currentNode = open.remove().node;

            if (currentNode.equals(targetNode)) {
                return tracebackPath(targetNode, parents);
            }

            // The heap may hold stale duplicates of a node (entries are never
            // removed on improvement); skip nodes that were already expanded.
            if (!closed.add(currentNode)) {
                continue;
            }

            // currentNode is closed, so its distance cannot change below.
            double currentDistance = distances.get(currentNode);
            List<Integer> successorNodes = getSuccessors(graph,
                                                         currentNode,
                                                         targetNode,
                                                         distances,
                                                         heuristicFunction,
                                                         beamWidth);

            for (Integer childNode : successorNodes) {
                // Closed children still occupy beam slots; they are only
                // discarded here, after the beam has been cut.
                if (closed.contains(childNode)) {
                    continue;
                }

                double tentativeDistance =
                        currentDistance
                        + graph.getEdgeWeight(currentNode, childNode);
                Double knownDistance = distances.get(childNode);

                if (knownDistance == null
                        || knownDistance > tentativeDistance) {
                    distances.put(childNode, tentativeDistance);
                    parents.put(childNode, currentNode);
                    open.add(new HeapNode(
                            childNode,
                            tentativeDistance
                            + heuristicFunction.estimate(childNode,
                                                         targetNode)));
                }
            }
        }

        throw new PathNotFoundException(
                "Path from " + sourceNode + " to " + targetNode +
                " not found.");
    }

    /**
     * Returns at most {@code beamWidth} children of {@code currentNode},
     * sorted by {@code g(current) + w(current, child) + h(child, target)}.
     *
     * @param graph             the graph being searched.
     * @param currentNode       the node being expanded.
     * @param targetNode        the goal node used by the heuristic.
     * @param distances         best known distances from the source.
     * @param heuristicFunction the heuristic.
     * @param beamWidth         maximum number of successors to return.
     * @return the most promising successors, best first.
     */
    private static List<Integer>
        getSuccessors(AbstractGraph graph,
                      Integer currentNode,
                      Integer targetNode,
                      Map<Integer, Double> distances,
                      HeuristicFunction<Integer> heuristicFunction,
                      int beamWidth) {
        double currentDistance = distances.get(currentNode);
        List<Integer> successors = new ArrayList<>();
        Map<Integer, Double> costMap = new HashMap<>();

        for (Integer successor : graph.getChildrenOf(currentNode)) {
            successors.add(successor);
            costMap.put(successor,
                        currentDistance
                        + graph.getEdgeWeight(currentNode, successor)
                        + heuristicFunction.estimate(successor, targetNode));
        }

        Collections.sort(successors,
                (a, b) -> Double.compare(costMap.get(a), costMap.get(b)));
        return successors.subList(0, Math.min(successors.size(), beamWidth));
    }

    /**
     * A priority-queue entry: a node and the f-score it was pushed with.
     */
    private static final class HeapNode implements Comparable<HeapNode> {

        final Integer node;
        final double fScore;

        HeapNode(Integer node, double fScore) {
            this.node = node;
            this.fScore = fScore;
        }

        @Override
        public int compareTo(HeapNode o) {
            return Double.compare(fScore, o.fScore);
        }
    }
}
BidirectionalBeamSearchPathfinder.java
package net.coderodde.graph.pathfinding.beamsearch;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import net.coderodde.graph.AbstractGraph;
/**
 * Bidirectional variant of beam search. A forward search from the source and
 * a backward search (over incoming arcs) from the target run alternately; the
 * smaller of the two searches is expanded next. At every expansion only the
 * {@code beamWidth} most promising neighbours are admitted.
 *
 * <p>With a narrow beam the returned path may be suboptimal, and
 * {@link PathNotFoundException} may be thrown even though a path exists.
 */
public final class BidirectionalBeamSearchPathfinder implements Pathfinder {

    /**
     * The default width of the beam ({@code Integer.MAX_VALUE}: no pruning).
     */
    private static final int DEFAULT_BEAM_WIDTH = Integer.MAX_VALUE;

    /**
     * The minimum allowed beam width.
     */
    private static final int MINIMUM_BEAM_WIDTH = 1;

    /**
     * The current beam width.
     */
    private int beamWidth = DEFAULT_BEAM_WIDTH;

    /**
     * Returns the maximum number of neighbours admitted per expansion.
     *
     * @return the current beam width.
     */
    public int getBeamWidth() {
        return beamWidth;
    }

    /**
     * Sets the maximum number of neighbours admitted per expansion. Values
     * below 1 are raised to 1.
     *
     * @param beamWidth the requested beam width.
     */
    public void setBeamWidth(int beamWidth) {
        this.beamWidth = Math.max(beamWidth, MINIMUM_BEAM_WIDTH);
    }

    /**
     * Searches for a path from {@code sourceNode} to {@code targetNode}.
     *
     * @param graph             the graph to search in.
     * @param sourceNode        the start node.
     * @param targetNode        the goal node.
     * @param heuristicFunction estimates the distance between two nodes.
     * @return the node sequence from {@code sourceNode} to {@code targetNode},
     *         both inclusive.
     * @throws NullPointerException     if any argument is {@code null}.
     * @throws IllegalArgumentException if either node is not in the graph.
     * @throws PathNotFoundException    if either search exhausts its frontier
     *                                  before a path is settled.
     */
    @Override
    public List<Integer> search(AbstractGraph graph,
                                Integer sourceNode,
                                Integer targetNode,
                                HeuristicFunction<Integer> heuristicFunction) {
        Objects.requireNonNull(graph, "The input graph is null.");
        Objects.requireNonNull(sourceNode, "The source node is null.");
        Objects.requireNonNull(targetNode, "The target node is null.");
        Objects.requireNonNull(heuristicFunction,
                               "The heuristic function is null.");
        checkNodes(graph, sourceNode, targetNode);

        // The meeting checks below only fire when an arc is relaxed, so the
        // zero-length path from a node to itself would otherwise never be
        // reported (a cycle or PathNotFoundException would result instead).
        if (sourceNode.equals(targetNode)) {
            return new ArrayList<>(Collections.singletonList(sourceNode));
        }

        Queue<HeapNode> openForward = new PriorityQueue<>();
        Queue<HeapNode> openBackward = new PriorityQueue<>();
        Set<Integer> closedForward = new HashSet<>();
        Set<Integer> closedBackward = new HashSet<>();
        Map<Integer, Integer> parentsForward = new HashMap<>();
        Map<Integer, Integer> parentsBackward = new HashMap<>();
        Map<Integer, Double> distancesForward = new HashMap<>();
        Map<Integer, Double> distancesBackward = new HashMap<>();

        // Length of the shortest source-target path seen so far, and the node
        // at which its forward and backward halves join.
        double bestPathLength = Double.POSITIVE_INFINITY;
        Integer touchNode = null;

        openForward.add(new HeapNode(sourceNode, 0.0));
        openBackward.add(new HeapNode(targetNode, 0.0));
        parentsForward.put(sourceNode, null);
        parentsBackward.put(targetNode, null);
        distancesForward.put(sourceNode, 0.0);
        distancesBackward.put(targetNode, 0.0);

        while (!openForward.isEmpty() && !openBackward.isEmpty()) {
            if (touchNode != null) {
                Integer minA = openForward.peek().node;
                Integer minB = openBackward.peek().node;
                double distanceA = distancesForward.get(minA) +
                                   heuristicFunction.estimate(minA, targetNode);
                double distanceB = distancesBackward.get(minB) +
                                   heuristicFunction.estimate(minB, sourceNode);

                // Bidirectional A* termination test; with beam pruning the
                // result is not guaranteed to be optimal.
                if (bestPathLength <= Math.max(distanceA, distanceB)) {
                    return tracebackPath(touchNode,
                                         parentsForward,
                                         parentsBackward);
                }
            }

            // Expand whichever search has touched fewer nodes.
            if (openForward.size() + closedForward.size() <
                    openBackward.size() + closedBackward.size()) {
                Integer currentNode = openForward.remove().node;

                if (!closedForward.add(currentNode)) {
                    continue; // stale heap entry
                }

                double currentDistance = distancesForward.get(currentNode);
                List<Integer> successors =
                        getForwardSuccessors(graph,
                                             openBackward.peek().node,
                                             currentNode,
                                             targetNode,
                                             distancesForward,
                                             heuristicFunction,
                                             beamWidth);

                for (Integer childNode : successors) {
                    if (closedForward.contains(childNode)) {
                        continue;
                    }

                    double tentativeScore =
                            currentDistance
                            + graph.getEdgeWeight(currentNode, childNode);
                    Double knownDistance = distancesForward.get(childNode);

                    if (knownDistance == null
                            || knownDistance > tentativeScore) {
                        distancesForward.put(childNode, tentativeScore);
                        parentsForward.put(childNode, currentNode);
                        openForward.add(new HeapNode(
                                childNode,
                                tentativeScore
                                + heuristicFunction.estimate(childNode,
                                                             targetNode)));

                        // Any node already labelled by the backward search
                        // (open or closed) joins the two half-paths.
                        Double backwardDistance =
                                distancesBackward.get(childNode);

                        if (backwardDistance != null) {
                            double pathLength =
                                    tentativeScore + backwardDistance;

                            if (pathLength < bestPathLength) {
                                bestPathLength = pathLength;
                                touchNode = childNode;
                            }
                        }
                    }
                }
            } else {
                Integer currentNode = openBackward.remove().node;

                if (!closedBackward.add(currentNode)) {
                    continue; // stale heap entry
                }

                double currentDistance = distancesBackward.get(currentNode);
                List<Integer> predecessors =
                        getBackwardSuccessors(graph,
                                              openForward.peek().node,
                                              currentNode,
                                              sourceNode,
                                              distancesBackward,
                                              heuristicFunction,
                                              beamWidth);

                for (Integer parentNode : predecessors) {
                    if (closedBackward.contains(parentNode)) {
                        continue;
                    }

                    double tentativeScore =
                            currentDistance
                            + graph.getEdgeWeight(parentNode, currentNode);
                    Double knownDistance = distancesBackward.get(parentNode);

                    if (knownDistance == null
                            || knownDistance > tentativeScore) {
                        distancesBackward.put(parentNode, tentativeScore);
                        parentsBackward.put(parentNode, currentNode);
                        openBackward.add(new HeapNode(
                                parentNode,
                                tentativeScore
                                + heuristicFunction.estimate(parentNode,
                                                             sourceNode)));

                        // Any node already labelled by the forward search
                        // (open or closed) joins the two half-paths.
                        Double forwardDistance =
                                distancesForward.get(parentNode);

                        if (forwardDistance != null) {
                            double pathLength =
                                    forwardDistance + tentativeScore;

                            if (pathLength < bestPathLength) {
                                bestPathLength = pathLength;
                                touchNode = parentNode;
                            }
                        }
                    }
                }
            }
        }

        throw new PathNotFoundException(
                "Target node " + targetNode + " is not reachable from " +
                sourceNode);
    }

    /**
     * Returns at most {@code beamWidth} children of {@code currentNode},
     * sorted by {@code g(current) + w(current, child) + h(child, backwardTop)}.
     * The ranking is guided towards the node at the top of the backward heap,
     * not towards the target itself.
     */
    private static List<Integer>
        getForwardSuccessors(AbstractGraph graph,
                             Integer backwardTop,
                             Integer currentNode,
                             Integer targetNode,
                             Map<Integer, Double> distances,
                             HeuristicFunction<Integer> heuristicFunction,
                             int beamWidth) {
        double currentDistance = distances.get(currentNode);
        List<Integer> successors = new ArrayList<>();
        Map<Integer, Double> costMap = new HashMap<>();

        for (Integer successor : graph.getChildrenOf(currentNode)) {
            successors.add(successor);
            costMap.put(successor,
                        currentDistance
                        + graph.getEdgeWeight(currentNode, successor)
                        + heuristicFunction.estimate(successor, backwardTop));
        }

        Collections.sort(successors,
                (a, b) -> Double.compare(costMap.get(a), costMap.get(b)));
        return successors.subList(0, Math.min(successors.size(), beamWidth));
    }

    /**
     * Returns at most {@code beamWidth} parents of {@code currentNode},
     * sorted by {@code g(current) + w(parent, current) + h(parent, forwardTop)}.
     * The ranking is guided towards the node at the top of the forward heap,
     * not towards the source itself.
     */
    private static List<Integer>
        getBackwardSuccessors(AbstractGraph graph,
                              Integer forwardTop,
                              Integer currentNode,
                              Integer sourceNode,
                              Map<Integer, Double> distances,
                              HeuristicFunction<Integer> heuristicFunction,
                              int beamWidth) {
        double currentDistance = distances.get(currentNode);
        List<Integer> successors = new ArrayList<>();
        Map<Integer, Double> costMap = new HashMap<>();

        for (Integer successor : graph.getParentsOf(currentNode)) {
            successors.add(successor);
            costMap.put(successor,
                        currentDistance
                        + graph.getEdgeWeight(successor, currentNode)
                        + heuristicFunction.estimate(successor, forwardTop));
        }

        Collections.sort(successors,
                (a, b) -> Double.compare(costMap.get(a), costMap.get(b)));
        return successors.subList(0, Math.min(successors.size(), beamWidth));
    }

    /**
     * A priority-queue entry: a node and the f-score it was pushed with.
     */
    private static final class HeapNode implements Comparable<HeapNode> {

        final Integer node;
        final double fScore;

        HeapNode(Integer node, double fScore) {
            this.node = node;
            this.fScore = fScore;
        }

        @Override
        public int compareTo(HeapNode o) {
            return Double.compare(fScore, o.fScore);
        }
    }
}
Coordinates.java
package net.coderodde.graph.pathfinding.beamsearch;
import java.awt.geom.Point2D;
import java.util.HashMap;
import java.util.Map;
/**
 * A mutable table mapping graph nodes to points in the plane.
 */
public final class Coordinates {

    /** Location of every registered node. */
    private final Map<Integer, Point2D.Double> pointsByNode;

    /**
     * Creates an empty coordinate table.
     */
    public Coordinates() {
        pointsByNode = new HashMap<>();
    }

    /**
     * Returns the point stored for {@code node}, or {@code null} if none.
     *
     * @param node the node to look up.
     * @return the node's point, or {@code null}.
     */
    public Point2D.Double get(Integer node) {
        return pointsByNode.get(node);
    }

    /**
     * Stores (or replaces) the point of {@code node}.
     *
     * @param node  the node.
     * @param point its location.
     */
    public void put(Integer node, Point2D.Double point) {
        pointsByNode.put(node, point);
    }
}
DefaultHeuristicFunction.java
package net.coderodde.graph.pathfinding.beamsearch;
import java.util.Objects;
/**
 * Heuristic that estimates the distance between two nodes as the Euclidean
 * distance between their points in a {@link Coordinates} table.
 */
public final class DefaultHeuristicFunction
        implements HeuristicFunction<Integer> {

    /** Supplies the planar location of every node. */
    private final Coordinates coordinates;

    /**
     * @param coordinates the node locations; must not be {@code null}.
     * @throws NullPointerException if {@code coordinates} is {@code null}.
     */
    public DefaultHeuristicFunction(Coordinates coordinates) {
        Objects.requireNonNull(coordinates,
                               "The coordinate function is null.");
        this.coordinates = coordinates;
    }

    /**
     * Returns the straight-line distance between the two nodes' points.
     */
    @Override
    public double estimate(Integer source, Integer target) {
        return this.coordinates.get(source)
                               .distance(this.coordinates.get(target));
    }
}
Demo.java
package net.coderodde.graph.pathfinding.beamsearch;
import java.awt.geom.Point2D;
import java.util.List;
import java.util.Random;
import net.coderodde.graph.AbstractGraph;
import net.coderodde.graph.DirectedGraph;
/**
 * Compares unidirectional beam search, A* and bidirectional beam search on a
 * randomly generated directed graph whose nodes lie in a plane.
 */
public final class Demo {

    /**
     * The width of the plane containing all the graph nodes.
     */
    private static final double GRAPH_LAYOUT_WIDTH = 1000.0;

    /**
     * The height of the plane containing all the graph nodes.
     */
    private static final double GRAPH_LAYOUT_HEIGHT = 1000.0;

    /**
     * Given two nodes {@code u} and {@code v}, the cost of the arc
     * {@code (u,v)} will be their Euclidean distance times this factor.
     */
    private static final double ARC_LENGTH_FACTOR = 1.2;

    /**
     * The number of nodes in the graph.
     */
    private static final int NODES = 250_000;

    /**
     * The number of arcs in the graph.
     */
    private static final int ARCS = 1_500_000;

    /**
     * The beam width used in the demonstration.
     */
    private static final int BEAM_WIDTH = 4;

    private static final long NANOS_PER_MILLI = 1_000_000L;

    public static void main(String[] args) {
        long seed = System.currentTimeMillis();
        Random random = new Random(seed);
        System.out.println("Seed = " + seed);

        GraphData data = createDirectedGraph(NODES, ARCS, random);

        // Both phases draw their (source, target) pair from a fresh generator
        // with the same seed, so the warm-up runs the exact query that is
        // timed afterwards.
        warmup(data.graph, data.heuristicFunction, new Random(seed));
        benchmark(data.graph, data.heuristicFunction, new Random(seed));
    }

    /**
     * Runs the queries once without printing, to warm up the JIT.
     */
    private static void warmup(DirectedGraph graph,
                               HeuristicFunction<Integer> heuristicFunction,
                               Random random) {
        perform(graph, heuristicFunction, random, false);
    }

    /**
     * Runs the queries and prints their results and timings.
     */
    private static void benchmark(DirectedGraph graph,
                                  HeuristicFunction<Integer> heuristicFunction,
                                  Random random) {
        perform(graph, heuristicFunction, random, true);
    }

    /**
     * Picks a random query and runs beam search, A* (beam search with an
     * unbounded beam) and bidirectional beam search on it.
     *
     * <p>NOTE(review): with a narrow beam, {@code search(...)} can throw
     * {@code PathNotFoundException} even when a path exists; that currently
     * aborts the demo.
     */
    private static void perform(DirectedGraph graph,
                                HeuristicFunction<Integer> heuristicFunction,
                                Random random,
                                boolean output) {
        Integer sourceNode = random.nextInt(graph.size());
        Integer targetNode = random.nextInt(graph.size());

        BeamSearchPathfinder finder1 = new BeamSearchPathfinder();
        BidirectionalBeamSearchPathfinder finder2 =
                new BidirectionalBeamSearchPathfinder();
        finder1.setBeamWidth(BEAM_WIDTH);
        finder2.setBeamWidth(BEAM_WIDTH);

        // System.nanoTime() is the monotonic clock intended for measuring
        // elapsed time; currentTimeMillis() may jump with wall-clock changes.
        long start = System.nanoTime();
        List<Integer> path1 = finder1.search(graph,
                                             sourceNode,
                                             targetNode,
                                             heuristicFunction);
        long elapsed = System.nanoTime() - start;

        if (output) {
            printResult(finder1.getClass().getSimpleName(),
                        path1,
                        graph,
                        elapsed);
        }

        // An unbounded beam prunes nothing, so this run is plain A* and
        // serves as the optimal reference.
        finder1.setBeamWidth(Integer.MAX_VALUE);
        start = System.nanoTime();
        List<Integer> optimalPath = finder1.search(graph,
                                                   sourceNode,
                                                   targetNode,
                                                   heuristicFunction);
        elapsed = System.nanoTime() - start;

        if (output) {
            printResult("A*", optimalPath, graph, elapsed);
        }

        start = System.nanoTime();
        List<Integer> path2 = finder2.search(graph,
                                             sourceNode,
                                             targetNode,
                                             heuristicFunction);
        elapsed = System.nanoTime() - start;

        if (output) {
            printResult(finder2.getClass().getSimpleName(),
                        path2,
                        graph,
                        elapsed);
        }
    }

    /**
     * Prints a finder's label, its path with the path's length, and the
     * elapsed time in milliseconds.
     */
    private static void printResult(String label,
                                    List<Integer> path,
                                    AbstractGraph graph,
                                    long elapsedNanos) {
        System.out.println(label + ":");
        System.out.println("Path: " + path + ", length = " +
                           getPathLength(path, graph));
        System.out.println("Time: " + (elapsedNanos / NANOS_PER_MILLI) +
                           " milliseconds.");
    }

    /**
     * Returns the sum of the arc weights along {@code path}.
     */
    private static double getPathLength(List<Integer> path,
                                        AbstractGraph graph) {
        double sum = 0.0;

        for (int i = 0; i < path.size() - 1; ++i) {
            sum += graph.getEdgeWeight(path.get(i), path.get(i + 1));
        }

        return sum;
    }

    /**
     * A generated graph together with the heuristic matching its layout.
     */
    private static final class GraphData {
        DirectedGraph graph;
        HeuristicFunction<Integer> heuristicFunction;
    }

    /**
     * Assigns every node of {@code graph} a uniformly random point in the
     * layout rectangle.
     */
    private static Coordinates getRandomCoordinates(AbstractGraph graph,
                                                    Random random) {
        Coordinates coordinates = new Coordinates();

        for (Integer node : graph.getAllNodes()) {
            coordinates.put(node, createRandomPoint(GRAPH_LAYOUT_WIDTH,
                                                    GRAPH_LAYOUT_HEIGHT,
                                                    random));
        }

        return coordinates;
    }

    /**
     * Returns a uniformly random point in
     * {@code [0, graphLayoutWidth) x [0, graphLayoutHeight)}.
     */
    private static Point2D.Double
        createRandomPoint(double graphLayoutWidth,
                          double graphLayoutHeight,
                          Random random) {
        return new Point2D.Double(random.nextDouble() * graphLayoutWidth,
                                  random.nextDouble() * graphLayoutHeight);
    }

    /**
     * Builds a directed graph with {@code nodes} nodes and {@code arcs}
     * randomly chosen arcs. Each arc weighs {@link #ARC_LENGTH_FACTOR} times
     * the Euclidean distance between its endpoints.
     */
    private static GraphData createDirectedGraph(int nodes,
                                                 int arcs,
                                                 Random random) {
        DirectedGraph graph = new DirectedGraph();

        for (int node = 0; node < nodes; ++node) {
            graph.addNode(node);
        }

        Coordinates coordinates = getRandomCoordinates(graph, random);
        HeuristicFunction<Integer> heuristicFunction =
                new DefaultHeuristicFunction(coordinates);

        for (int arc = 0; arc < arcs; ++arc) {
            Integer source = random.nextInt(nodes);
            Integer target = random.nextInt(nodes);
            double euclideanDistance = heuristicFunction.estimate(source,
                                                                  target);
            graph.addEdge(source,
                          target,
                          ARC_LENGTH_FACTOR * euclideanDistance);
        }

        GraphData data = new GraphData();
        data.graph = graph;
        data.heuristicFunction = heuristicFunction;
        return data;
    }
}
HeuristicFunction.java
package net.coderodde.graph.pathfinding.beamsearch;
/**
* This interface defines the API for heuristic functions.
*
* @author Rodion "rodde" Efremov
* @param <Node> the actual node type.
* @version 1.6 (Sep 10, 2017)
*/
public interface HeuristicFunction<Node> {
/**
* Returns an optimistic estimate for the path from {@code source} to
* {@code target}.
*
* @param source the source node.
* @param target the target node.
* @return distance estimate.
*/
public double estimate(Node source, Node target);
}
PathNotFoundException.java
package net.coderodde.graph.pathfinding.beamsearch;
/**
 * Unchecked exception thrown by a {@code Pathfinder} when its search ends
 * without producing a path from the source to the target.
 */
public final class PathNotFoundException extends RuntimeException {

    /**
     * @param detail human-readable description of the failed query.
     */
    public PathNotFoundException(String detail) {
        super(detail);
    }
}
Pathfinder.java
package net.coderodde.graph.pathfinding.beamsearch;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import net.coderodde.graph.AbstractGraph;
/**
 * A point-to-point pathfinder over an {@link AbstractGraph}, together with
 * shared helpers for validating queries and reconstructing paths.
 */
public interface Pathfinder {

    /**
     * Searches for a path from {@code source} to {@code target} in
     * {@code graph} using {@code heuristicFunction} as a guide.
     *
     * @param graph the graph to search in.
     * @param source the source (start) node.
     * @param target the target (goal) node.
     * @param heuristicFunction the heuristic function.
     * @return the nodes of the found path, from {@code source} to
     *         {@code target}, both inclusive.
     */
    List<Integer> search(AbstractGraph graph,
                         Integer source,
                         Integer target,
                         HeuristicFunction<Integer> heuristicFunction);

    /**
     * Rebuilds a path by following {@code parents} from {@code target} until
     * a node mapped to {@code null} (or absent) is reached.
     *
     * @param target  the last node of the path.
     * @param parents maps each node to its predecessor on the path.
     * @return a mutable list holding the path in source-to-target order.
     */
    default List<Integer> tracebackPath(Integer target,
                                        Map<Integer, Integer> parents) {
        List<Integer> path = new ArrayList<>();
        Integer currentNode = target;

        while (currentNode != null) {
            path.add(currentNode);
            currentNode = parents.get(currentNode);
        }

        Collections.reverse(path);
        return path;
    }

    /**
     * Rebuilds a path found by a bidirectional search. It joins the forward
     * half ending at {@code touch} with the backward half starting after it.
     *
     * @param touch           the node where the two searches met.
     * @param forwardParents  predecessor links of the forward search.
     * @param backwardParents successor links of the backward search.
     * @return a mutable list holding the path in source-to-target order;
     *         {@code touch} appears exactly once.
     */
    default List<Integer> tracebackPath(Integer touch,
                                        Map<Integer, Integer> forwardParents,
                                        Map<Integer, Integer> backwardParents) {
        List<Integer> prefixPath = tracebackPath(touch, forwardParents);
        Integer currentNode = backwardParents.get(touch);

        while (currentNode != null) {
            prefixPath.add(currentNode);
            currentNode = backwardParents.get(currentNode);
        }

        return prefixPath;
    }

    /**
     * Makes sure that both {@code source} and {@code target} are in the
     * {@code graph}.
     *
     * @param graph the graph.
     * @param source the source node.
     * @param target the target node.
     * @throws IllegalArgumentException if either node is not in the graph.
     */
    default void checkNodes(AbstractGraph graph, Integer source, Integer target) {
        if (!graph.hasNode(source)) {
            throw new IllegalArgumentException(
                    "The source node " + source + " is not in the graph.");
        }

        if (!graph.hasNode(target)) {
            throw new IllegalArgumentException(
                    "The target node " + target + " is not in the graph.");
        }
    }
}
Dependency: this program relies on this Maven project.
Critique request
Please tell me anything that comes to mind.
1 Answer 1
This looks good; I don't have much to say beyond some superficial syntax/formatting things:
Typo in BeamSearchPathfinder and BidirectionalBeamSearchPathfinder:
private static final int MINIMUM_BEAM_WIDHT = 1;
WIDHT -> WIDTH here and in setBeamWidth().
open.add(
new HeapNode(childNode,
tentativeDistance +
heuristicFunction.estimate(
childNode,
targetNode)));
Much as I love Lisp (and this is totally subjective), I find this indentation and parenthesis placement kind of weird (at least the initial lining up under the parenthesis), and would lean towards something like
open.add(
new HeapNode(
childNode,
tentativeDistance + heuristicFunction.estimate(childNode, targetNode)
)
);
I believe you could choose to annotate HeuristicFunction as a @FunctionalInterface, but it's not a big deal.
Collections.sort(successors, (a, b) -> {
return Double.compare(costMap.get(a), costMap.get(b));
});
This can use a single-expression lambda:
Collections.sort(successors,
(a, b) -> Double.compare(costMap.get(a), costMap.get(b)));
which can then be further simplified (thanks @Roland Illig!):
successors.sort(Comparator.comparing(costMap::get));
-
2 — Do you know Comparator.comparing? That could be even simpler: successors.sort(Comparator.comparing(costMap::get)). — Roland Illig, commented Sep 12, 2017 at 11:09