\$\begingroup\$

What

Beam search is a best-first search algorithm that does not necessarily find an optimal path, yet has a smaller memory footprint. In this program, I attempted to find out how it compares to A* and whether bidirectional beam search improves on the unidirectional variant in terms of running time and optimality of the resulting path.
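
To make the pruning idea concrete, here is a minimal sketch of how a beam keeps only the cheapest few successors of a node. It is not the code under review: the names `BeamPruningSketch` and `keepBest` are made up for illustration, and it assumes every candidate has an entry in the cost map.

```java
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public final class BeamPruningSketch {

    /**
     * Hypothetical helper: returns at most beamWidth candidates, cheapest
     * first. Assumes every candidate has an entry in estimatedCost.
     */
    public static List<Integer> keepBest(List<Integer> candidates,
                                         Map<Integer, Double> estimatedCost,
                                         int beamWidth) {
        return candidates.stream()
                .sorted(Comparator.comparingDouble(
                        (Integer node) -> estimatedCost.get(node)))
                .limit(beamWidth)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        Map<Integer, Double> cost = Map.of(1, 7.5, 2, 3.0, 3, 9.1, 4, 4.2);

        // A narrow beam drops the two most expensive successors.
        System.out.println(keepBest(List.of(1, 2, 3, 4), cost, 2));

        // An effectively unbounded beam keeps every successor.
        System.out.println(
                keepBest(List.of(1, 2, 3, 4), cost, Integer.MAX_VALUE));
    }
}
```

With a beam width of Integer.MAX_VALUE nothing is pruned, which is why the demo below uses that setting as its A* baseline.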

Code

BeamSearchPathfinder.java

package net.coderodde.graph.pathfinding.beamsearch;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import net.coderodde.graph.AbstractGraph;

public final class BeamSearchPathfinder implements Pathfinder {

    /**
     * The default width of the beam.
     */
    private static final int DEFAULT_BEAM_WIDTH = Integer.MAX_VALUE;

    /**
     * The minimum allowed beam width.
     */
    private static final int MINIMUM_BEAM_WIDHT = 1;

    /**
     * The current beam width.
     */
    private int beamWidth = DEFAULT_BEAM_WIDTH;

    public int getBeamWidth() {
        return beamWidth;
    }

    public void setBeamWidth(int beamWidth) {
        this.beamWidth = Math.max(beamWidth, MINIMUM_BEAM_WIDHT);
    }

    @Override
    public List<Integer> search(AbstractGraph graph,
                                Integer sourceNode,
                                Integer targetNode,
                                HeuristicFunction<Integer> heuristicFunction) {
        Objects.requireNonNull(graph, "The input graph is null.");
        Objects.requireNonNull(sourceNode, "The source node is null.");
        Objects.requireNonNull(targetNode, "The target node is null.");
        Objects.requireNonNull(heuristicFunction,
                               "The heuristic function is null.");

        checkNodes(graph, sourceNode, targetNode);

        Queue<HeapNode> open = new PriorityQueue<>();
        Set<Integer> closed = new HashSet<>();
        Map<Integer, Integer> parents = new HashMap<>();
        Map<Integer, Double> distances = new HashMap<>();

        open.add(new HeapNode(sourceNode, 0.0));
        parents.put(sourceNode, null);
        distances.put(sourceNode, 0.0);

        while (!open.isEmpty()) {
            Integer currentNode = open.remove().node;

            if (currentNode.equals(targetNode)) {
                return tracebackPath(targetNode, parents);
            }

            if (closed.contains(currentNode)) {
                continue;
            }

            closed.add(currentNode);

            List<Integer> successorNodes = getSuccessors(graph,
                                                         currentNode,
                                                         targetNode,
                                                         distances,
                                                         heuristicFunction,
                                                         beamWidth);

            for (Integer childNode : successorNodes) {
                if (closed.contains(childNode)) {
                    continue;
                }

                double tentativeDistance =
                        distances.get(currentNode) +
                        graph.getEdgeWeight(currentNode, childNode);

                if (!distances.containsKey(childNode)
                        || distances.get(childNode) > tentativeDistance) {
                    distances.put(childNode, tentativeDistance);
                    parents.put(childNode, currentNode);
                    open.add(
                            new HeapNode(childNode,
                                         tentativeDistance +
                                         heuristicFunction.estimate(
                                                 childNode,
                                                 targetNode)));
                }
            }
        }

        throw new PathNotFoundException(
                "Path from " + sourceNode + " to " + targetNode +
                " not found.");
    }

    private static List<Integer>
            getSuccessors(AbstractGraph graph,
                          Integer currentNode,
                          Integer targetNode,
                          Map<Integer, Double> distances,
                          HeuristicFunction<Integer> heuristicFunction,
                          int beamWidth) {
        List<Integer> successors = new ArrayList<>();
        Map<Integer, Double> costMap = new HashMap<>();

        for (Integer successor : graph.getChildrenOf(currentNode)) {
            successors.add(successor);
            costMap.put(
                    successor,
                    distances.get(currentNode) +
                    graph.getEdgeWeight(currentNode, successor) +
                    heuristicFunction.estimate(successor, targetNode));
        }

        Collections.sort(successors, (a, b) -> {
            return Double.compare(costMap.get(a), costMap.get(b));
        });

        return successors.subList(0, Math.min(successors.size(), beamWidth));
    }

    private static final class HeapNode implements Comparable<HeapNode> {

        Integer node;
        double fScore;

        HeapNode(Integer node, double fScore) {
            this.node = node;
            this.fScore = fScore;
        }

        @Override
        public int compareTo(HeapNode o) {
            return Double.compare(fScore, o.fScore);
        }
    }
}

BidirectionalBeamSearchPathfinder.java

package net.coderodde.graph.pathfinding.beamsearch;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import net.coderodde.graph.AbstractGraph;

public final class BidirectionalBeamSearchPathfinder implements Pathfinder {

    /**
     * The default width of the beam.
     */
    private static final int DEFAULT_BEAM_WIDTH = Integer.MAX_VALUE;

    /**
     * The minimum allowed beam width.
     */
    private static final int MINIMUM_BEAM_WIDHT = 1;

    /**
     * The current beam width.
     */
    private int beamWidth = DEFAULT_BEAM_WIDTH;

    public int getBeamWidth() {
        return beamWidth;
    }

    public void setBeamWidth(int beamWidth) {
        this.beamWidth = Math.max(beamWidth, MINIMUM_BEAM_WIDHT);
    }

    @Override
    public List<Integer> search(AbstractGraph graph,
                                Integer sourceNode,
                                Integer targetNode,
                                HeuristicFunction<Integer> heuristicFunction) {
        Objects.requireNonNull(graph, "The input graph is null.");
        Objects.requireNonNull(sourceNode, "The source node is null.");
        Objects.requireNonNull(targetNode, "The target node is null.");
        Objects.requireNonNull(heuristicFunction,
                               "The heuristic function is null.");

        checkNodes(graph, sourceNode, targetNode);

        Queue<HeapNode> openForward = new PriorityQueue<>();
        Queue<HeapNode> openBackward = new PriorityQueue<>();
        Set<Integer> closedForward = new HashSet<>();
        Set<Integer> closedBackward = new HashSet<>();
        Map<Integer, Integer> parentsForward = new HashMap<>();
        Map<Integer, Integer> parentsBackward = new HashMap<>();
        Map<Integer, Double> distancesForward = new HashMap<>();
        Map<Integer, Double> distancesBackward = new HashMap<>();

        double bestPathLength = Double.POSITIVE_INFINITY;
        Integer touchNode = null;

        openForward.add(new HeapNode(sourceNode, 0.0));
        openBackward.add(new HeapNode(targetNode, 0.0));
        parentsForward.put(sourceNode, null);
        parentsBackward.put(targetNode, null);
        distancesForward.put(sourceNode, 0.0);
        distancesBackward.put(targetNode, 0.0);

        while (!openForward.isEmpty() && !openBackward.isEmpty()) {
            if (touchNode != null) {
                Integer minA = openForward.peek().node;
                Integer minB = openBackward.peek().node;

                double distanceA = distancesForward.get(minA) +
                                   heuristicFunction.estimate(minA, targetNode);

                double distanceB = distancesBackward.get(minB) +
                                   heuristicFunction.estimate(minB, sourceNode);

                if (bestPathLength <= Math.max(distanceA, distanceB)) {
                    return tracebackPath(touchNode,
                                         parentsForward,
                                         parentsBackward);
                }
            }

            if (openForward.size() + closedForward.size() <
                openBackward.size() + closedBackward.size()) {
                Integer currentNode = openForward.remove().node;

                if (closedForward.contains(currentNode)) {
                    continue;
                }

                closedForward.add(currentNode);

                List<Integer> successors =
                        getForwardSuccessors(graph,
                                             openBackward.peek().node,
                                             currentNode,
                                             targetNode,
                                             distancesForward,
                                             heuristicFunction,
                                             beamWidth);

                for (Integer childNode : successors) {
                    if (closedForward.contains(childNode)) {
                        continue;
                    }

                    double tentativeScore =
                            distancesForward.get(currentNode) +
                            graph.getEdgeWeight(currentNode, childNode);

                    if (!distancesForward.containsKey(childNode)
                            || distancesForward.get(childNode) >
                               tentativeScore) {
                        distancesForward.put(childNode, tentativeScore);
                        parentsForward.put(childNode, currentNode);
                        openForward.add(
                                new HeapNode(
                                        childNode,
                                        tentativeScore + heuristicFunction
                                        .estimate(childNode, targetNode)));

                        if (closedBackward.contains(childNode)) {
                            double pathLength =
                                    distancesBackward.get(childNode) +
                                    tentativeScore;

                            if (bestPathLength > pathLength) {
                                bestPathLength = pathLength;
                                touchNode = childNode;
                            }
                        }
                    }
                }
            } else {
                Integer currentNode = openBackward.remove().node;

                if (closedBackward.contains(currentNode)) {
                    continue;
                }

                closedBackward.add(currentNode);

                List<Integer> successors =
                        getBackwardSuccessors(graph,
                                              openForward.peek().node,
                                              currentNode,
                                              sourceNode,
                                              distancesBackward,
                                              heuristicFunction,
                                              beamWidth);

                for (Integer parentNode : successors) {
                    if (closedBackward.contains(parentNode)) {
                        continue;
                    }

                    double tentativeScore =
                            distancesBackward.get(currentNode) +
                            graph.getEdgeWeight(parentNode, currentNode);

                    if (!distancesBackward.containsKey(parentNode)
                            || distancesBackward.get(parentNode) >
                               tentativeScore) {
                        distancesBackward.put(parentNode, tentativeScore);
                        parentsBackward.put(parentNode, currentNode);
                        openBackward.add(
                                new HeapNode(
                                        parentNode,
                                        tentativeScore + heuristicFunction
                                        .estimate(parentNode, sourceNode)));

                        if (closedForward.contains(parentNode)) {
                            double pathLength =
                                    distancesForward.get(parentNode) +
                                    tentativeScore;

                            if (bestPathLength > pathLength) {
                                bestPathLength = pathLength;
                                touchNode = parentNode;
                            }
                        }
                    }
                }
            }
        }

        throw new PathNotFoundException(
                "Target node " + targetNode + " is not reachable from " +
                sourceNode);
    }

    private static List<Integer>
            getForwardSuccessors(AbstractGraph graph,
                                 Integer backwardTop,
                                 Integer currentNode,
                                 Integer targetNode,
                                 Map<Integer, Double> distances,
                                 HeuristicFunction<Integer> heuristicFunction,
                                 int beamWidth) {
        List<Integer> successors = new ArrayList<>();
        Map<Integer, Double> costMap = new HashMap<>();

        for (Integer successor : graph.getChildrenOf(currentNode)) {
            successors.add(successor);
            costMap.put(
                    successor,
                    distances.get(currentNode) +
                    graph.getEdgeWeight(currentNode, successor) +
                    heuristicFunction.estimate(successor, backwardTop));
        }

        Collections.sort(successors, (a, b) -> {
            return Double.compare(costMap.get(a), costMap.get(b));
        });

        return successors.subList(0, Math.min(successors.size(),
                                              beamWidth));
    }

    private static List<Integer>
            getBackwardSuccessors(AbstractGraph graph,
                                  Integer forwardTop,
                                  Integer currentNode,
                                  Integer sourceNode,
                                  Map<Integer, Double> distances,
                                  HeuristicFunction<Integer> heuristicFunction,
                                  int beamWidth) {
        List<Integer> successors = new ArrayList<>();
        Map<Integer, Double> costMap = new HashMap<>();

        for (Integer successor : graph.getParentsOf(currentNode)) {
            successors.add(successor);
            costMap.put(
                    successor,
                    distances.get(currentNode) +
                    graph.getEdgeWeight(successor, currentNode) +
                    heuristicFunction.estimate(successor, forwardTop));
        }

        Collections.sort(successors, (a, b) -> {
            return Double.compare(costMap.get(a), costMap.get(b));
        });

        return successors.subList(0, Math.min(successors.size(),
                                              beamWidth));
    }

    private static final class HeapNode implements Comparable<HeapNode> {

        Integer node;
        double fScore;

        HeapNode(Integer node, double fScore) {
            this.node = node;
            this.fScore = fScore;
        }

        @Override
        public int compareTo(HeapNode o) {
            return Double.compare(fScore, o.fScore);
        }
    }
}

Coordinates.java

package net.coderodde.graph.pathfinding.beamsearch;

import java.awt.geom.Point2D;
import java.util.HashMap;
import java.util.Map;

public final class Coordinates {

    private final Map<Integer, Point2D.Double> map = new HashMap<>();

    public Point2D.Double get(Integer node) {
        return map.get(node);
    }

    public void put(Integer node, Point2D.Double point) {
        map.put(node, point);
    }
}

DefaultHeuristicFunction.java

package net.coderodde.graph.pathfinding.beamsearch;

import java.util.Objects;

public final class DefaultHeuristicFunction
        implements HeuristicFunction<Integer> {

    private final Coordinates coordinates;

    public DefaultHeuristicFunction(Coordinates coordinates) {
        this.coordinates =
                Objects.requireNonNull(coordinates,
                                       "The coordinate function is null.");
    }

    @Override
    public double estimate(Integer source, Integer target) {
        return coordinates.get(source).distance(coordinates.get(target));
    }
}

Demo.java

package net.coderodde.graph.pathfinding.beamsearch;

import java.awt.geom.Point2D;
import java.util.List;
import java.util.Random;
import net.coderodde.graph.AbstractGraph;
import net.coderodde.graph.DirectedGraph;

public final class Demo {

    /**
     * The width of the plane containing all the graph nodes.
     */
    private static final double GRAPH_LAYOUT_WIDTH = 1000.0;

    /**
     * The height of the plane containing all the graph nodes.
     */
    private static final double GRAPH_LAYOUT_HEIGHT = 1000.0;

    /**
     * Given two nodes {@code u} and {@code v}, the cost of the arc
     * {@code (u,v)} will be their Euclidean distance times this factor.
     */
    private static final double ARC_LENGTH_FACTOR = 1.2;

    /**
     * The number of nodes in the graph.
     */
    private static final int NODES = 250_000;

    /**
     * The number of arcs in the graph.
     */
    private static final int ARCS = 1_500_000;

    /**
     * The beam width used in the demonstration.
     */
    private static final int BEAM_WIDTH = 4;

    public static void main(String[] args) {
        long seed = System.currentTimeMillis();
        Random random = new Random(seed);
        System.out.println("Seed = " + seed);

        GraphData data = createDirectedGraph(NODES, ARCS, random);
        warmup(data.graph, data.heuristicFunction, new Random(seed));
        benchmark(data.graph, data.heuristicFunction, new Random(seed));
    }

    private static final void
            warmup(DirectedGraph graph,
                   HeuristicFunction<Integer> heuristicFunction,
                   Random random) {
        perform(graph, heuristicFunction, random, false);
    }

    private static final void
            benchmark(DirectedGraph graph,
                      HeuristicFunction<Integer> heuristicFunction,
                      Random random) {
        perform(graph, heuristicFunction, random, true);
    }

    private static final void
            perform(DirectedGraph graph,
                    HeuristicFunction<Integer> heuristicFunction,
                    Random random,
                    boolean output) {
        Integer sourceNode = random.nextInt(graph.size());
        Integer targetNode = random.nextInt(graph.size());

        BeamSearchPathfinder finder1 = new BeamSearchPathfinder();
        BidirectionalBeamSearchPathfinder finder2 =
                new BidirectionalBeamSearchPathfinder();

        finder1.setBeamWidth(BEAM_WIDTH);
        finder2.setBeamWidth(BEAM_WIDTH);

        long start = System.currentTimeMillis();
        List<Integer> path1 = finder1.search(graph,
                                             sourceNode,
                                             targetNode,
                                             heuristicFunction);
        long end = System.currentTimeMillis();

        if (output) {
            System.out.println(finder1.getClass().getSimpleName() + ":");
            System.out.println("Path: " + path1 + ", length = " +
                               getPathLength(path1, graph));
            System.out.println("Time: " + (end - start) + " milliseconds.");
        }

        finder1.setBeamWidth(Integer.MAX_VALUE);

        start = System.currentTimeMillis();
        List<Integer> optimalPath = finder1.search(graph,
                                                   sourceNode,
                                                   targetNode,
                                                   heuristicFunction);
        end = System.currentTimeMillis();

        if (output) {
            System.out.println("A*:");
            System.out.println("Path: " + optimalPath + ", length = " +
                               getPathLength(optimalPath, graph));
            System.out.println("Time: " + (end - start) + " milliseconds.");
        }

        start = System.currentTimeMillis();
        List<Integer> path2 = finder2.search(graph,
                                             sourceNode,
                                             targetNode,
                                             heuristicFunction);
        end = System.currentTimeMillis();

        if (output) {
            System.out.println(finder2.getClass().getSimpleName() + ":");
            System.out.println("Path: " + path2 + ", length = " +
                               getPathLength(path2, graph));
            System.out.println("Time: " + (end - start) + " milliseconds.");
        }
    }

    private static double getPathLength(List<Integer> path,
                                        AbstractGraph graph) {
        double sum = 0.0;

        for (int i = 0; i < path.size() - 1; ++i) {
            sum += graph.getEdgeWeight(path.get(i), path.get(i + 1));
        }

        return sum;
    }

    private static final class GraphData {
        DirectedGraph graph;
        HeuristicFunction<Integer> heuristicFunction;
    }

    private static final Coordinates getRandomCoordinates(AbstractGraph graph,
                                                          Random random) {
        Coordinates coordinates = new Coordinates();

        for (Integer node : graph.getAllNodes()) {
            coordinates.put(node, createRandomPoint(GRAPH_LAYOUT_WIDTH,
                                                    GRAPH_LAYOUT_HEIGHT,
                                                    random));
        }

        return coordinates;
    }

    private static final Point2D.Double
            createRandomPoint(double graphLayoutWidth,
                              double graphLayoutHeight,
                              Random random) {
        return new Point2D.Double(random.nextDouble() * graphLayoutWidth,
                                  random.nextDouble() * graphLayoutHeight);
    }

    private static final GraphData createDirectedGraph(int nodes,
                                                       int arcs,
                                                       Random random) {
        DirectedGraph graph = new DirectedGraph();

        for (int node = 0; node < nodes; ++node) {
            graph.addNode(node);
        }

        Coordinates coordinates = getRandomCoordinates(graph, random);
        HeuristicFunction<Integer> heuristicFunction =
                new DefaultHeuristicFunction(coordinates);

        for (int arc = 0; arc < arcs; ++arc) {
            Integer source = random.nextInt(nodes);
            Integer target = random.nextInt(nodes);
            double euclideanDistance = heuristicFunction.estimate(source,
                                                                  target);
            graph.addEdge(source,
                          target,
                          ARC_LENGTH_FACTOR * euclideanDistance);
        }

        GraphData data = new GraphData();
        data.graph = graph;
        data.heuristicFunction = heuristicFunction;
        return data;
    }
}

HeuristicFunction.java

package net.coderodde.graph.pathfinding.beamsearch;

/**
 * This interface defines the API for heuristic functions.
 *
 * @author Rodion "rodde" Efremov
 * @param <Node> the actual node type.
 * @version 1.6 (Sep 10, 2017)
 */
public interface HeuristicFunction<Node> {

    /**
     * Returns an optimistic estimate for the path from {@code source} to
     * {@code target}.
     *
     * @param source the source node.
     * @param target the target node.
     * @return distance estimate.
     */
    public double estimate(Node source, Node target);
}

PathNotFoundException.java

package net.coderodde.graph.pathfinding.beamsearch;

public final class PathNotFoundException extends RuntimeException {

    public PathNotFoundException(String message) {
        super(message);
    }
}

Pathfinder.java

package net.coderodde.graph.pathfinding.beamsearch;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import net.coderodde.graph.AbstractGraph;

public interface Pathfinder {

    /**
     * Searches for a path from {@code source} to {@code target} in
     * {@code graph} using {@code heuristicFunction} as a guide.
     *
     * @param graph the graph to search in.
     * @param source the source (start) node.
     * @param target the target (goal) node.
     * @param heuristicFunction the heuristic function.
     * @return the path from {@code source} to {@code target} as a list of
     *         nodes.
     */
    public List<Integer> search(AbstractGraph graph,
                                Integer source,
                                Integer target,
                                HeuristicFunction<Integer> heuristicFunction);

    default List<Integer> tracebackPath(Integer target,
                                        Map<Integer, Integer> parents) {
        List<Integer> path = new ArrayList<>();
        Integer currentNode = target;

        while (currentNode != null) {
            path.add(currentNode);
            currentNode = parents.get(currentNode);
        }

        Collections.<Integer>reverse(path);
        return path;
    }

    default List<Integer> tracebackPath(Integer touch,
                                        Map<Integer, Integer> forwardParents,
                                        Map<Integer, Integer> backwardParents) {
        List<Integer> prefixPath = tracebackPath(touch, forwardParents);
        Integer currentNode = backwardParents.get(touch);

        while (currentNode != null) {
            prefixPath.add(currentNode);
            currentNode = backwardParents.get(currentNode);
        }

        return prefixPath;
    }

    /**
     * Makes sure that both {@code source} and {@code target} are in the
     * {@code graph}.
     *
     * @param graph the graph.
     * @param source the source node.
     * @param target the target node.
     */
    default void checkNodes(AbstractGraph graph, Integer source, Integer target) {
        if (!graph.hasNode(source)) {
            throw new IllegalArgumentException(
                    "The source node " + source + " is not in the graph.");
        }

        if (!graph.hasNode(target)) {
            throw new IllegalArgumentException(
                    "The target node " + target + " is not in the graph.");
        }
    }
}

Dependency

This program relies on this Maven project.

Critique request

Please tell me anything that comes to mind.

asked Sep 11, 2017 at 17:15
\$\endgroup\$

1 Answer

\$\begingroup\$

This looks good; I don't have much to say beyond some superficial syntax and formatting points:


Typo in BeamSearchPathfinder and BidirectionalBeamSearchPathfinder:

 private static final int MINIMUM_BEAM_WIDHT = 1;

WIDHT -> WIDTH here and in setBeamWidth()


open.add(
        new HeapNode(childNode,
                     tentativeDistance +
                     heuristicFunction.estimate(
                             childNode,
                             targetNode)));

Much as I love Lisp (and this is totally subjective), I find this indentation and parenthesis placement kind of weird, at least the initial lining up under the parenthesis, and would lean towards something like

open.add(
    new HeapNode(
        childNode,
        tentativeDistance + heuristicFunction.estimate(childNode, targetNode)
    )
);

I believe you could choose to annotate HeuristicFunction as a @FunctionalInterface, but it's not a big deal.
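
As a rough illustration of what the annotation buys you, the sketch below redeclares an interface of the same shape as HeuristicFunction inside a made-up class, `FunctionalInterfaceSketch`, and implements it with a lambda. The absolute-difference heuristic is purely hypothetical. Since the interface already has exactly one abstract method, lambdas work without the annotation too; what `@FunctionalInterface` adds is a compile-time error if a second abstract method is ever introduced.

```java
public final class FunctionalInterfaceSketch {

    // Same shape as the HeuristicFunction in the question, redeclared here
    // only so that the sketch is self-contained.
    @FunctionalInterface
    interface HeuristicFunction<Node> {
        double estimate(Node source, Node target);
    }

    public static void main(String[] args) {
        // Hypothetical heuristic: treat node IDs as positions on a line.
        HeuristicFunction<Integer> heuristic =
                (source, target) -> Math.abs(source - target);

        System.out.println(heuristic.estimate(3, 10)); // prints 7.0
    }
}
```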


Collections.sort(successors, (a, b) -> {
    return Double.compare(costMap.get(a), costMap.get(b));
});

This can use a single-expression lambda:

Collections.sort(successors,
                 (a, b) -> Double.compare(costMap.get(a), costMap.get(b)));

which can then be further simplified (thanks @Roland Illig!):

successors.sort(Comparator.comparing(costMap::get));
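
If you want to convince yourself that the three forms order the list identically, a small self-contained check could look like the following. The class name `SortVariantsSketch`, the method `sortThreeWays` and the sample costs are invented for this sketch, and it assumes every element has a cost in the map.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public final class SortVariantsSketch {

    /** Sorts copies of the input with each of the three forms shown above. */
    public static List<List<Integer>> sortThreeWays(
            List<Integer> successors, Map<Integer, Double> costMap) {
        List<Integer> blockLambda = new ArrayList<>(successors);
        List<Integer> expressionLambda = new ArrayList<>(successors);
        List<Integer> keyExtractor = new ArrayList<>(successors);

        Collections.sort(blockLambda, (a, b) -> {
            return Double.compare(costMap.get(a), costMap.get(b));
        });

        Collections.sort(expressionLambda,
                (a, b) -> Double.compare(costMap.get(a), costMap.get(b)));

        keyExtractor.sort(Comparator.comparing(costMap::get));

        return List.of(blockLambda, expressionLambda, keyExtractor);
    }

    public static void main(String[] args) {
        Map<Integer, Double> costMap = new HashMap<>();
        costMap.put(10, 2.5);
        costMap.put(20, 0.5);
        costMap.put(30, 1.5);

        // Prints the same ordering three times: [20, 30, 10]
        for (List<Integer> sorted : sortThreeWays(List.of(10, 20, 30),
                                                  costMap)) {
            System.out.println(sorted);
        }
    }
}
```

Comparator.comparingDouble(costMap::get) should behave the same way here while comparing primitive doubles instead of boxed Double keys.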
answered Sep 12, 2017 at 7:04
\$\endgroup\$
Do you know Comparator.comparing? That could be even simpler: successors.sort(Comparator.comparing(costMap::get)). Commented Sep 12, 2017 at 11:09
