Skip to main content
Stack Overflow
  1. About
  2. For Teams

Return to Revisions

13 of 24
deleted 557 characters in body

How to implement a recursive distance function for Hex game AI evaluation?

I'm working on a Hex game AI that uses the alpha-beta pruning algorithm, and as part of evaluating the board state.

More detailed information about this distance metric can be found in this paper (see pages 14–17).

To simplify the problem and avoid unnecessary details, I’ve reduced the graph to a minimal example that isolates the issue.

the main problem is figuring out how to to calculate a specific distance metric between two cells for use in the evaluation function.

Conceptually, the problem comes down to calculating the distance between two cells (u and v) on a graph. I believe it's somewhat similar to BFS, DFS, or Dijkstra’s algorithm — but implemented recursively.

The Graph and Neighborhood Concept

Here is an example visualization of the graph, where I want to compute the distance from the node u to the node (4, 1):

Graph showing the structure for calculating distance from u to (4,1)

In this graph, each node is a cell on the Hex board, and connections represent valid moves depending on the board state.

To compute the distance, there are two main components:


1. Neighborhood of a Cell u (N(u))

This neighborhood is dynamic and not the standard definition from graph theory. It changes depending on the current board state.

Here’s a visualization where the shaded cells represent N(u) — the neighborhood of node u:

Shaded example showing dynamic neighbors of u

The neighborhood includes:

  • The immediate empty neighbors of u.
  • The neighbors of other friendly stones (connected to u) and the virtual nodes (edges of the board).
  • The neighbors of all connected chains that u is part of.

2. The Distance Function d(u, v)

Here is the high-level structure of the recursive distance function I’m trying to implement:

Recursive distance function diagram

The function c_k(u, v) returns the number of cells in the neighborhood of u whose distance to v is less than k:

Diagram showing the working of c_k(u, v)


Problem Example

Here’s a simple example where the function is called as:

d(u, (4, 1))

With the function defined as described above, the expected output is:

6

Why is the distance 6? I’m quoting from the paper:

"In that figure (see the last picture below), it shows the distancemetric between each unoccupied cell and the upper-left edge piece (the top-left black node) for the position in our running example. Notice that the cell with distancemetric 6 in the bottom corner of the last picture has a neighbourhood that includes the bottom-right black edge piece as well as the cell with distancemetric 5 further up on the right."


Code Implementation

Here’s my current code for the Hex board, graph construction, neighborhood extraction, and the distanceMetric function:

import math
SMALLBOARD = 5
def create_hex_graph(N):
 graph = {}
 directions = [(1, 0), (1, -1), (0, 1), (0, -1), (-1, 1), (-1, 0)]
 for r in range(N):
 for c in range(N):
 neighbors = []
 for dr, dc in directions:
 nr, nc = r + dr, c + dc
 if 0 <= nr < N and 0 <= nc < N:
 neighbors.append((nr, nc))
 graph[(r, c)] = neighbors
 graph['top'] = [(0, c) for c in range(N)]
 graph['bottom'] = [(N - 1, c) for c in range(N)]
 graph['left'] = [(r, 0) for r in range(N)]
 graph['right'] = [(r, N - 1) for r in range(N)]
 for c in range(N):
 graph[(0, c)].append('top')
 graph[(N - 1, c)].append('bottom')
 for r in range(N):
 graph[(r, 0)].append('left')
 graph[(r, N - 1)].append('right')
 return graph
class HexBoard:
 def __init__(self, size=SMALLBOARD):
 self.size = size
 self.board = [['.' for _ in range(size)] for _ in range(size)]
 def display(self):
 for i, row in enumerate(self.board):
 print(' ' * i + ' '.join(row))
 print()
 def get_valid_moves(self):
 return [(r, c) for r in range(self.size) for c in range(self.size) if self.board[r][c] == '.']
 def invalid_moves(self, row, col):
 return not (0 <= row < self.size and 0 <= col < self.size) or self.board[row][col] != '.'
 def make_move(self, row, col, player):
 if self.invalid_moves(row, col):
 print("invalid move!")
 return False
 self.board[row][col] = player
 return True
 def reset(self):
 self.board = [['.' for _ in range(self.size)] for _ in range(self.size)]
def filtered_board(board: HexBoard, player_symbol: str) -> dict:
 opponent = 'O' if player_symbol == 'X' else 'X'
 N = board.size
 full_graph = create_hex_graph(N)
 opponent_cells = {(r, c) for r in range(N) for c in range(N) if board.board[r][c] == opponent}
 filtered_graph = {}
 for vertex, neighbors in full_graph.items():
 if vertex in opponent_cells:
 continue
 filtered_neighbors = [n for n in neighbors if n not in opponent_cells]
 filtered_graph[vertex] = filtered_neighbors
 return filtered_graph
def chainfromFilteredBoard(board: HexBoard, player: str) -> list[set]:
 filtered_graph = filtered_board(board, player)
 visited = set()
 chains = []
 player_cells = {(r, c) for r in range(board.size) for c in range(board.size) if board.board[r][c] == player}
 relevant_virtual_nodes = ['top', 'bottom'] if player == 'X' else ['left', 'right']
 def dfs(start):
 stack = [start]
 current_chain = set()
 while stack:
 node = stack.pop()
 if node in visited:
 continue
 visited.add(node)
 current_chain.add(node)
 for neighbor in filtered_graph.get(node, []):
 if (neighbor in player_cells or neighbor in relevant_virtual_nodes) and neighbor not in visited:
 stack.append(neighbor)
 return current_chain
 for cell in player_cells.union(set(relevant_virtual_nodes)):
 if cell not in visited:
 chains.append(dfs(cell))
 return chains
def neighborhood(u, player: str, board: HexBoard) -> set:
 filtered_graph = filtered_board(board, player)
 neighbors_set = set()
 for neighbor in filtered_graph.get(u, []):
 if isinstance(neighbor, tuple):
 if board.board[neighbor[0]][neighbor[1]] == '.':
 neighbors_set.add(neighbor)
 else:
 if player == 'X' and neighbor in ['top', 'bottom']:
 neighbors_set.add(neighbor)
 elif player == 'O' and neighbor in ['left', 'right']:
 neighbors_set.add(neighbor)
 player_chains = chainfromFilteredBoard(board, player)
 for ch in player_chains:
 if any(u in filtered_graph.get(node, []) for node in ch):
 for node in ch:
 for neighbor in filtered_graph.get(node, []):
 if isinstance(neighbor, tuple):
 if board.board[neighbor[0]][neighbor[1]] == '.':
 neighbors_set.add(neighbor)
 else:
 if player == 'X' and neighbor in ['top', 'bottom']:
 neighbors_set.add(neighbor)
 elif player == 'O' and neighbor in ['left', 'right']:
 neighbors_set.add(neighbor)
 neighbors_set.discard(u)
 return neighbors_set
def distancemetric(u, v, player, board, visited=None, memo=None):
 if visited is None:
 visited = set()
 if memo is None:
 memo = {}
 if (u, v) in memo:
 return memo[(u, v)]
 if u == v:
 return 0
 if u in visited:
 return math.inf
 visited.add(u)
 neighbors_u = neighborhood(u, player, board)
 if v in neighbors_u:
 memo[(u, v)] = 1
 visited.remove(u)
 return 1
 max_k = board.size * 2 + 10
 for k in range(2, max_k):
 count = 0
 for w in neighbors_u:
 dist_wv = distancemetric(w, v, player, board, visited, memo)
 if dist_wv < k:
 count += 1
 if count >= 2:
 memo[(u, v)] = k
 visited.remove(u)
 return k
 memo[(u, v)] = math.inf
 visited.remove(u)
 return math.inf
# ---- Example Usage ----
if __name__ == "__main__":
 board = HexBoard(size=5)
 board.make_move(0, 3, 'X')
 board.make_move(2, 1, 'X')
 board.make_move(2, 2, 'X')
 board.make_move(3, 1, 'O')
 board.make_move(4, 2, 'O')
 board.display()
 u = 'top'
 v = (4, 1)
 player = 'X'
 dist = distancemetric(u, v, player, board)
 print(f"distancemetric({v}, {u}) = {dist}")

What’s Going Wrong?

My distanceMetric function calculates the distance to some target nodes incorrectly and and I believe my recursive approach might be fundamentally wrong

I am able to compute the distance to any target node, but the inconsistencies make it hard to trust the metric during evaluation.

Here is a visualization of the expected distances from u to every other node on the board (these are correct and by the paper verified):

![Expected distances from u to every node](https://i.sstatic.net/

lang-py

AltStyle によって変換されたページ (->オリジナル) /