Skip to main content
Code Review

Return to Question

Commonmark migration
Source Link

#What

What

#Code

Code

#Critique request

Critique request

#What

#Code

#Critique request

What

Code

Critique request

Source Link
coderodde
  • 31.7k
  • 15
  • 77
  • 202

Uni- and bidirectional beam search in Java

#What

Beam search is a best-first search algorithm that does not necessary find an optimal path, yet has smaller memory-footprint. In this program, I attempted to answer a question how does it compare to A* and whether bidirectional beam search provides any improvement over unidirectional variant what comes to running time and optimality of the result path.

#Code

BeamSearchPathfinder.java

package net.coderodde.graph.pathfinding.beamsearch;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import net.coderodde.graph.AbstractGraph;
public final class BeamSearchPathfinder implements Pathfinder {
 /**
 * The default width of the beam.
 */
 private static final int DEFAULT_BEAM_WIDTH = Integer.MAX_VALUE;
 /**
 * The minimum allowed beam width.
 */
 private static final int MINIMUM_BEAM_WIDHT = 1;
 /**
 * The current beam width.
 */
 private int beamWidth = DEFAULT_BEAM_WIDTH;
 public int getBeamWidth() {
 return beamWidth;
 }
 public void setBeamWidth(int beamWidth) {
 this.beamWidth = Math.max(beamWidth, MINIMUM_BEAM_WIDHT);
 }
 @Override
 public List<Integer> search(AbstractGraph graph,
 Integer sourceNode,
 Integer targetNode, 
 HeuristicFunction<Integer> heuristicFunction) {
 Objects.requireNonNull(graph, "The input graph is null.");
 Objects.requireNonNull(sourceNode, "The source node is null.");
 Objects.requireNonNull(targetNode, "The target node is null.");
 Objects.requireNonNull(heuristicFunction,
 "The heuristic function is null.");
 checkNodes(graph, sourceNode, targetNode);
 Queue<HeapNode> open = new PriorityQueue<>();
 Set<Integer> closed = new HashSet<>();
 Map<Integer, Integer> parents = new HashMap<>();
 Map<Integer, Double> distances = new HashMap<>();
 open.add(new HeapNode(sourceNode, 0.0));
 parents.put(sourceNode, null);
 distances.put(sourceNode, 0.0);
 while (!open.isEmpty()) {
 Integer currentNode = open.remove().node;
 if (currentNode.equals(targetNode)) {
 return tracebackPath(targetNode, parents);
 }
 if (closed.contains(currentNode)) {
 continue;
 }
 closed.add(currentNode);
 List<Integer> successorNodes = getSuccessors(graph,
 currentNode,
 targetNode,
 distances,
 heuristicFunction,
 beamWidth);
 for (Integer childNode : successorNodes) {
 if (closed.contains(childNode)) {
 continue;
 }
 double tentativeDistance = 
 distances.get(currentNode) +
 graph.getEdgeWeight(currentNode, childNode);
 if (!distances.containsKey(childNode)
 || distances.get(childNode) > tentativeDistance) {
 distances.put(childNode, tentativeDistance);
 parents.put(childNode, currentNode);
 open.add(
 new HeapNode(childNode, 
 tentativeDistance + 
 heuristicFunction.estimate(
 childNode, 
 targetNode)));
 }
 }
 }
 throw new PathNotFoundException(
 "Path from " + sourceNode + " to " + targetNode + 
 " not found.");
 }
 private static List<Integer> 
 getSuccessors(AbstractGraph graph,
 Integer currentNode,
 Integer targetNode,
 Map<Integer, Double> distances,
 HeuristicFunction<Integer> heuristicFunction,
 int beamWidth) {
 List<Integer> successors = new ArrayList<>();
 Map<Integer, Double> costMap = new HashMap<>();
 for (Integer successor : graph.getChildrenOf(currentNode)) {
 successors.add(successor);
 costMap.put(
 successor, 
 distances.get(currentNode) + 
 graph.getEdgeWeight(currentNode, successor) +
 heuristicFunction.estimate(successor, targetNode));
 }
 Collections.sort(successors, (a, b) -> {
 return Double.compare(costMap.get(a), costMap.get(b));
 });
 return successors.subList(0, Math.min(successors.size(), beamWidth));
 }
 private static final class HeapNode implements Comparable<HeapNode> {
 Integer node;
 double fScore;
 HeapNode(Integer node, double fScore) {
 this.node = node;
 this.fScore = fScore;
 }
 @Override
 public int compareTo(HeapNode o) {
 return Double.compare(fScore, o.fScore);
 }
 }
}

BidirectionalBeamSearchPathfinder.java

package net.coderodde.graph.pathfinding.beamsearch;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import net.coderodde.graph.AbstractGraph;
public final class BidirectionalBeamSearchPathfinder implements Pathfinder {
 /**
 * The default width of the beam.
 */
 private static final int DEFAULT_BEAM_WIDTH = Integer.MAX_VALUE;
 /**
 * The minimum allowed beam width.
 */
 private static final int MINIMUM_BEAM_WIDHT = 1;
 /**
 * The current beam width.
 */
 private int beamWidth = DEFAULT_BEAM_WIDTH;
 public int getBeamWidth() {
 return beamWidth;
 }
 public void setBeamWidth(int beamWidth) {
 this.beamWidth = Math.max(beamWidth, MINIMUM_BEAM_WIDHT);
 }
 @Override
 public List<Integer> search(AbstractGraph graph, 
 Integer sourceNode, 
 Integer targetNode, 
 HeuristicFunction<Integer> heuristicFunction) {
 Objects.requireNonNull(graph, "The input graph is null.");
 Objects.requireNonNull(sourceNode, "The source node is null.");
 Objects.requireNonNull(targetNode, "The target node is null.");
 Objects.requireNonNull(heuristicFunction,
 "The heuristic function is null.");
 checkNodes(graph, sourceNode, targetNode);
 Queue<HeapNode> openForward = new PriorityQueue<>();
 Queue<HeapNode> openBackward = new PriorityQueue<>();
 Set<Integer> closedForward = new HashSet<>();
 Set<Integer> closedBackward = new HashSet<>();
 Map<Integer, Integer> parentsForward = new HashMap<>();
 Map<Integer, Integer> parentsBackward = new HashMap<>();
 Map<Integer, Double> distancesForward = new HashMap<>();
 Map<Integer, Double> distancesBackward = new HashMap<>();
 double bestPathLength = Double.POSITIVE_INFINITY;
 Integer touchNode = null;
 openForward.add(new HeapNode(sourceNode, 0.0));
 openBackward.add(new HeapNode(targetNode, 0.0));
 parentsForward.put(sourceNode, null);
 parentsBackward.put(targetNode, null);
 distancesForward.put(sourceNode, 0.0);
 distancesBackward.put(targetNode, 0.0);
 while (!openForward.isEmpty() && !openBackward.isEmpty()) {
 if (touchNode != null) {
 Integer minA = openForward.peek().node;
 Integer minB = openBackward.peek().node;
 double distanceA = distancesForward.get(minA) +
 heuristicFunction.estimate(minA, targetNode);
 double distanceB = distancesBackward.get(minB) +
 heuristicFunction.estimate(minB, sourceNode);
 if (bestPathLength <= Math.max(distanceA, distanceB)) {
 return tracebackPath(touchNode, 
 parentsForward, 
 parentsBackward);
 }
 }
 if (openForward.size() + closedForward.size() <
 openBackward.size() + closedBackward.size()) {
 Integer currentNode = openForward.remove().node;
 if (closedForward.contains(currentNode)) {
 continue;
 }
 closedForward.add(currentNode);
 List<Integer> successors = 
 getForwardSuccessors(graph,
 openBackward.peek().node,
 currentNode, 
 targetNode,
 distancesForward,
 heuristicFunction,
 beamWidth);
 for (Integer childNode : successors) {
 if (closedForward.contains(childNode)) {
 continue;
 }
 double tentativeScore = 
 distancesForward.get(currentNode) +
 graph.getEdgeWeight(currentNode, childNode);
 if (!distancesForward.containsKey(childNode) 
 || distancesForward.get(childNode) > 
 tentativeScore) {
 distancesForward.put(childNode, tentativeScore);
 parentsForward.put(childNode, currentNode);
 openForward.add(
 new HeapNode(
 childNode, 
 tentativeScore + heuristicFunction
 .estimate(childNode, targetNode)));
 if (closedBackward.contains(childNode)) {
 double pathLength = 
 distancesBackward.get(childNode) +
 tentativeScore;
 if (bestPathLength > pathLength) {
 bestPathLength = pathLength;
 touchNode = childNode;
 }
 }
 }
 }
 } else {
 Integer currentNode = openBackward.remove().node;
 if (closedBackward.contains(currentNode)) {
 continue;
 }
 closedBackward.add(currentNode);
 List<Integer> successors = 
 getBackwardSuccessors(graph,
 openForward.peek().node,
 currentNode, 
 sourceNode,
 distancesBackward,
 heuristicFunction,
 beamWidth);
 for (Integer parentNode : successors) {
 if (closedBackward.contains(parentNode)) {
 continue;
 }
 double tentativeScore = 
 distancesBackward.get(currentNode) +
 graph.getEdgeWeight(parentNode, currentNode);
 if (!distancesBackward.containsKey(parentNode)
 || distancesBackward.get(parentNode) >
 tentativeScore) {
 distancesBackward.put(parentNode, tentativeScore);
 parentsBackward.put(parentNode, currentNode);
 openBackward.add(
 new HeapNode(
 parentNode,
 tentativeScore + heuristicFunction
 .estimate(parentNode, sourceNode)));
 if (closedForward.contains(parentNode)) {
 double pathLength = 
 distancesForward.get(parentNode) + 
 tentativeScore;
 if (bestPathLength > pathLength) {
 bestPathLength = pathLength;
 touchNode = parentNode;
 }
 }
 }
 }
 }
 }
 throw new PathNotFoundException(
 "Target node " + targetNode + " is not reachable from " +
 sourceNode);
 }
 private static List<Integer> 
 getForwardSuccessors(AbstractGraph graph,
 Integer backwardTop,
 Integer currentNode,
 Integer targetNode,
 Map<Integer, Double> distances,
 HeuristicFunction<Integer> heuristicFunction,
 int beamWidth) {
 List<Integer> successors = new ArrayList<>();
 Map<Integer, Double> costMap = new HashMap<>();
 for (Integer successor : graph.getChildrenOf(currentNode)) {
 successors.add(successor);
 costMap.put(
 successor,
 distances.get(currentNode) + 
 graph.getEdgeWeight(currentNode, successor) +
 heuristicFunction.estimate(successor, backwardTop));
 }
 Collections.sort(successors, (a, b) -> {
 return Double.compare(costMap.get(a), costMap.get(b));
 });
 return successors.subList(0, Math.min(successors.size(), 
 beamWidth));
 }
 private static List<Integer>
 getBackwardSuccessors(AbstractGraph graph,
 Integer forwardTop,
 Integer currentNode, 
 Integer sourceNode,
 Map<Integer, Double> distances,
 HeuristicFunction<Integer> heuristicFunction,
 int beamWidth) {
 List<Integer> successors = new ArrayList<>();
 Map<Integer, Double> costMap = new HashMap<>();
 for (Integer successor : graph.getParentsOf(currentNode)) {
 successors.add(successor);
 costMap.put(
 successor,
 distances.get(currentNode) +
 graph.getEdgeWeight(successor, currentNode) +
 heuristicFunction.estimate(successor, forwardTop));
 }
 Collections.sort(successors, (a, b) -> {
 return Double.compare(costMap.get(a), costMap.get(b));
 });
 return successors.subList(0, Math.min(successors.size(),
 beamWidth));
 }
 private static final class HeapNode implements Comparable<HeapNode> {
 Integer node;
 double fScore;
 HeapNode(Integer node, double fScore) {
 this.node = node;
 this.fScore = fScore;
 }
 @Override
 public int compareTo(HeapNode o) {
 return Double.compare(fScore, o.fScore);
 }
 }
}

Coordinates.java

package net.coderodde.graph.pathfinding.beamsearch;
import java.awt.geom.Point2D;
import java.util.HashMap;
import java.util.Map;
public final class Coordinates {
 private final Map<Integer, Point2D.Double> map = new HashMap<>();
 public Point2D.Double get(Integer node) {
 return map.get(node);
 }
 public void put(Integer node, Point2D.Double point) {
 map.put(node, point);
 }
}

DefaultHeuristicFunction.java

package net.coderodde.graph.pathfinding.beamsearch;
import java.util.Objects;
public final class DefaultHeuristicFunction 
 implements HeuristicFunction<Integer> {
 private final Coordinates coordinates;
 public DefaultHeuristicFunction(Coordinates coordinates) {
 this.coordinates = 
 Objects.requireNonNull(coordinates, 
 "The coordinate function is null.");
 }
 @Override
 public double estimate(Integer source, Integer target) {
 return coordinates.get(source).distance(coordinates.get(target));
 }
}

Demo.java

package net.coderodde.graph.pathfinding.beamsearch;
import java.awt.geom.Point2D;
import java.util.List;
import java.util.Random;
import net.coderodde.graph.AbstractGraph;
import net.coderodde.graph.DirectedGraph;
public final class Demo {
 /**
 * The width of the plane containing all the graph nodes.
 */
 private static final double GRAPH_LAYOUT_WIDTH = 1000.0;
 /**
 * The height of the plane containing all the graph nodes.
 */
 private static final double GRAPH_LAYOUT_HEIGHT = 1000.0;
 /**
 * Given two nodes {@code u} and {@code v}, the cost of the arc
 * {@code (u,v)} will be their Euclidean distance times this factor.
 */
 private static final double ARC_LENGTH_FACTOR = 1.2;
 /**
 * The number of nodes in the graph.
 */
 private static final int NODES = 250_000;
 /**
 * The number of arcs in the graph.
 */
 private static final int ARCS = 1_500_000;
 /**
 * The beam width used in the demonstration.
 */
 private static final int BEAM_WIDTH = 4;
 public static void main(String[] args) {
 long seed = System.currentTimeMillis();
 Random random = new Random(seed);
 System.out.println("Seed = " + seed);
 GraphData data = createDirectedGraph(NODES, ARCS, random);
 warmup(data.graph, data.heuristicFunction, new Random(seed));
 benchmark(data.graph, data.heuristicFunction, new Random(seed));
 }
 private static final void 
 warmup(DirectedGraph graph, 
 HeuristicFunction<Integer> heuristicFunction,
 Random random) {
 perform(graph, heuristicFunction, random, false);
 }
 private static final void 
 benchmark(DirectedGraph graph, 
 HeuristicFunction<Integer> heuristicFunction,
 Random random) {
 perform(graph, heuristicFunction, random, true);
 }
 private static final void 
 perform(DirectedGraph graph, 
 HeuristicFunction<Integer> heuristicFunction,
 Random random,
 boolean output) {
 Integer sourceNode = random.nextInt(graph.size());
 Integer targetNode = random.nextInt(graph.size());
 BeamSearchPathfinder finder1 = new BeamSearchPathfinder();
 BidirectionalBeamSearchPathfinder finder2 = 
 new BidirectionalBeamSearchPathfinder();
 finder1.setBeamWidth(BEAM_WIDTH);
 finder2.setBeamWidth(BEAM_WIDTH);
 long start = System.currentTimeMillis();
 List<Integer> path1 = finder1.search(graph,
 sourceNode,
 targetNode,
 heuristicFunction);
 long end = System.currentTimeMillis();
 if (output) {
 System.out.println(finder1.getClass().getSimpleName() + ":");
 System.out.println("Path: " + path1 + ", length = " +
 getPathLength(path1, graph));
 System.out.println("Time: " + (end - start) + " milliseconds.");
 }
 finder1.setBeamWidth(Integer.MAX_VALUE);
 start = System.currentTimeMillis();
 List<Integer> optimalPath = finder1.search(graph, 
 sourceNode, 
 targetNode, 
 heuristicFunction);
 end = System.currentTimeMillis();
 if (output) {
 System.out.println("A*:");
 System.out.println("Path: " + optimalPath + ", length = " +
 getPathLength(optimalPath, graph));
 System.out.println("Time: " + (end - start) + " milliseconds.");
 }
 start = System.currentTimeMillis();
 List<Integer> path2 = finder2.search(graph,
 sourceNode, 
 targetNode, 
 heuristicFunction);
 end = System.currentTimeMillis();
 if (output) {
 System.out.println(finder2.getClass().getSimpleName() + ":");
 System.out.println("Path: " + path2 + ", length = " + 
 getPathLength(path2, graph));
 System.out.println("Time: " + (end - start) + " milliseconds.");
 }
 }
 private static double getPathLength(List<Integer> path,
 AbstractGraph graph) {
 double sum = 0.0;
 for (int i = 0; i < path.size() - 1; ++i) {
 sum += graph.getEdgeWeight(path.get(i), path.get(i + 1));
 }
 return sum;
 }
 private static final class GraphData {
 DirectedGraph graph;
 HeuristicFunction<Integer> heuristicFunction;
 }
 private static final Coordinates getRandomCoordinates(AbstractGraph graph,
 Random random) {
 Coordinates coordinates = new Coordinates();
 for (Integer node : graph.getAllNodes()) {
 coordinates.put(node, createRandomPoint(GRAPH_LAYOUT_WIDTH,
 GRAPH_LAYOUT_HEIGHT,
 random));
 }
 return coordinates;
 }
 private static final Point2D.Double
 createRandomPoint(double graphLayoutWidth,
 double graphLayoutHeight,
 Random random) {
 return new Point2D.Double(random.nextDouble() * graphLayoutWidth,
 random.nextDouble() * graphLayoutHeight);
 }
 private static final GraphData createDirectedGraph(int nodes,
 int arcs,
 Random random) {
 DirectedGraph graph = new DirectedGraph();
 for (int node = 0; node < nodes; ++node) {
 graph.addNode(node);
 }
 Coordinates coordinates = getRandomCoordinates(graph, random);
 HeuristicFunction<Integer> heuristicFunction =
 new DefaultHeuristicFunction(coordinates);
 for (int arc = 0; arc < arcs; ++arc) {
 Integer source = random.nextInt(nodes);
 Integer target = random.nextInt(nodes);
 double euclideanDistance = heuristicFunction.estimate(source,
 target);
 graph.addEdge(source,
 target, 
 ARC_LENGTH_FACTOR * euclideanDistance);
 }
 GraphData data = new GraphData();
 data.graph = graph;
 data.heuristicFunction = heuristicFunction;
 return data;
 }
}

HeuristicFunction.java

package net.coderodde.graph.pathfinding.beamsearch;
/**
 * This interface defines the API for heuristic functions.
 * 
 * @author Rodion "rodde" Efremov
 * @param <Node> the actual node type.
 * @version 1.6 (Sep 10, 2017)
 */
public interface HeuristicFunction<Node> {
 /**
 * Returns an optimistic estimate for the path from {@code source} to 
 * {@code target}.
 * 
 * @param source the source node.
 * @param target the target node.
 * @return distance estimate.
 */
 public double estimate(Node source, Node target);
}

PathNotFoundException.java

package net.coderodde.graph.pathfinding.beamsearch;
public final class PathNotFoundException extends RuntimeException {
 public PathNotFoundException(String message) {
 super(message);
 }
}

Pathfinder.java

package net.coderodde.graph.pathfinding.beamsearch;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import net.coderodde.graph.AbstractGraph;
public interface Pathfinder {
 /**
 * Searches for a path from {@code source} to {@code target} in 
 * {@code graph} using {@code heuristicFunction} as a guide.
 * 
 * @param graph the graph to search in.
 * @param source the source (start) node.
 * @param target the target (goal) node.
 * @param heuristicFunction the heuristic function.
 * @return 
 */
 public List<Integer> search(AbstractGraph graph,
 Integer source, 
 Integer target,
 HeuristicFunction<Integer> heuristicFunction);
 default List<Integer> tracebackPath(Integer target,
 Map<Integer, Integer> parents) {
 List<Integer> path = new ArrayList<>();
 Integer currentNode = target;
 while (currentNode != null) {
 path.add(currentNode);
 currentNode = parents.get(currentNode);
 }
 Collections.<Integer>reverse(path);
 return path;
 }
 default List<Integer> tracebackPath(Integer touch, 
 Map<Integer, Integer> forwardParents,
 Map<Integer, Integer> backwardParents) {
 List<Integer> prefixPath = tracebackPath(touch, forwardParents);
 Integer currentNode = backwardParents.get(touch);
 while (currentNode != null) {
 prefixPath.add(currentNode);
 currentNode = backwardParents.get(currentNode);
 }
 return prefixPath;
 }
 /**
 * Makes sure that both {@code source} and {@code target} are in the
 * {@code graph}.
 * 
 * @param graph the graph.
 * @param source the source node.
 * @param target the target node.
 */
 default void checkNodes(AbstractGraph graph, Integer source, Integer target) {
 if (!graph.hasNode(source)) {
 throw new IllegalArgumentException(
 "The source node " + source + " is not in the graph.");
 }
 if (!graph.hasNode(target)) {
 throw new IllegalArgumentException(
 "The target node " + target + " is not in the graph.");
 }
 }
}

Dependency This program relies on this Maven project.

#Critique request

Please tell me anything that comes to mind.

lang-java

AltStyle によって変換されたページ (->オリジナル) /