Here is some code I have written as a Connect Four negamax evaluation function. The state of the board is stored inside a class state. state.get() is $O(1)$, state.getBoard() is $O(1)$, and state.getWinnerColor() and state.isDraw() are also constant (although they have a higher constant coefficient). The real meat of the problem is the connectedStrength() function, which is the most accurate component of the heuristic I have. With it commented out of evaluate(), evaluate() runs in around 2500 nanoseconds; uncommented, it jumps to around 10800 nanoseconds, roughly 4.3 times as long!
private int[][] bS = {{3, 4, 5, 7, 5, 4, 3}, // coefficients for board values
                      {4, 6, 8, 10, 8, 6, 4},
                      {5, 8, 11, 13, 11, 8, 5},
                      {5, 8, 11, 13, 11, 8, 5},
                      {4, 6, 8, 10, 8, 6, 4},
                      {3, 4, 5, 7, 5, 4, 3}};
private int[] cS = {0, 10, 100}; // coefficients for x in a row
public int evaluate() {
    int winner = state.getWinnerColor();
    if (winner == 1) {
        return Integer.MAX_VALUE;
    }
    else if (winner == -1) {
        return -Integer.MAX_VALUE;
    }
    else if (state.isDraw()) {
        return 0;
    }
    int total = 0;
    total += boardStrength();
    total += connectedStrength();
    return total;
}
private int boardStrength() {
    int total = 0;
    for (int row = 0; row < 6; ++row) {
        for (int col = 0; col < 7; ++col) {
            total += state.getBoard()[row][col] * bS[row][col];
        }
    }
    return total;
}
public int connectedStrength() {
    int total = 0, color;
    IntegerPair a = null, b = null;
    int[][] s = {{1,1,0,-1},{0,1,1,1}}; //up/down, up-right/down-left, right/left, down-right/up-left
    IntegerPair last;
    boolean[][] visited = new boolean[6][7];
    for (int row = 0; row < 6; ++row) {
        for (int col = 0; col < 7; ++col) { // loop through each tile position possible
            if (!visited[row][col]) {
                last = new IntegerPair(row, col);
                color = state.get(row, col);
                if (color != 0) { // if square isn't empty
                    for (int i = 0; i < 4; ++i) { // loop through four directions
                        a = check(last, s[0][i], s[1][i], visited); // get lengths
                        b = check(last, -s[0][i], -s[1][i], visited);
                        if (a.second() + b.second() >= 3) { // if possible to create a four with this chain
                            total += color * cS[a.first() + b.first()]; // add to heuristic
                        }
                    }
                }
                visited[row][col] = true;
            }
        }
    }
    return total;
}
private IntegerPair check(IntegerPair last, int d1, int d2, boolean[][] visited) { // returns the # in a row
    // in row direction d1 and col direction d2 and also if the next one is free
    int len = 1, player = state.get(last.first(), last.second());
    while (last.first() + len * d1 >= 0
            && last.second() + len * d2 >= 0
            && last.first() + len * d1 <= 5
            && last.second() + len * d2 <= 6 // while inbounds
            && state.get(last.first() + len * d1, last.second() + len * d2) == player) { // and while the tiles are the same color as the player's
        visited[last.first() + len * d1][last.second() + len * d2] = true;
        len += 1;
    }
    int same = len;
    while (last.first() + len * d1 >= 0
            && last.second() + len * d2 >= 0
            && last.first() + len * d1 <= 5
            && last.second() + len * d2 <= 6 // again, while inbounds
            && state.get(last.first() + len * d1, last.second() + len * d2) == 0) { // while the tiles are not the enemy's
        visited[last.first() + len * d1][last.second() + len * d2] = true;
        len += 1;
    }
    return new IntegerPair(same-1, len-1);
}
The connectedStrength() function is quite simple: for each unvisited tile that is occupied (!= 0), it calls check() in each direction. check() returns a pair of ints: the first is how long the current linear chain is, and the second is how long the chain could potentially extend to. Another way of thinking about it: the first count ends at an empty square, an enemy tile, or the wall, and the second ends at an enemy tile or the wall. Neither count includes the starting tile itself. If the second values of the two opposite directions sum to less than 3 (so a chain of four can never be constructed), the value is thrown away; otherwise, the value is added to the total (or subtracted, depending on the player).
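To make the returned pair concrete, here is a minimal sketch that mirrors check() on a bare int[][] board instead of the state class. The class name ChainProbe, the method probe() and the int[] return value are stand-ins invented for this illustration (the original uses IntegerPair and also marks visited cells, which is left out here); 0 is assumed to mean empty and 1 / -1 the two players.

```java
final class ChainProbe {
    static final int ROWS = 6, COLS = 7;

    static boolean inBounds(int row, int col) {
        return row >= 0 && row < ROWS && col >= 0 && col < COLS;
    }

    /**
     * Returns {same, reach} for the tile at (row, col) in direction (dRow, dCol):
     * same  = number of further tiles of the same color directly adjacent in that direction,
     * reach = same plus the number of empty cells that follow them.
     * Neither value counts the starting tile.
     */
    static int[] probe(int[][] board, int row, int col, int dRow, int dCol) {
        int player = board[row][col];
        int len = 1;
        // First loop: walk over tiles of the same color.
        while (inBounds(row + len * dRow, col + len * dCol)
                && board[row + len * dRow][col + len * dCol] == player) {
            len++;
        }
        int same = len;
        // Second loop: keep walking over empty cells.
        while (inBounds(row + len * dRow, col + len * dCol)
                && board[row + len * dRow][col + len * dCol] == 0) {
            len++;
        }
        return new int[] {same - 1, len - 1};
    }

    public static void main(String[] args) {
        int[][] board = new int[ROWS][COLS];
        board[0] = new int[] {1, 1, 0, 0, -1, 0, 0};
        int[] right = probe(board, 0, 0, 0, 1);  // towards higher columns
        int[] left = probe(board, 0, 0, 0, -1);  // towards the wall
        System.out.println(right[0] + ", " + right[1]);
        System.out.println(left[0] + ", " + left[1]);
    }
}
```

For the tile at row 0, column 0, the probe to the right gives (1, 3) and the probe to the left gives (0, 0). The second values sum to 3, so this chain passes the "can still become a four" test and would be scored with cS[1].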
Are there any places where this function could be optimized?
3 Answers
Code duplication and magic numbers
You have some code duplication in the check() method, which can be extracted to a separate method and simplified to do fewer calculations.
Let us first add an isValidColorPosition() method to extract this ugly while condition:
private final int LEFT_BOTTOM_BORDER = 0;
private final int RIGHT_BORDER = 6;
private final int TOP_BORDER = 5;
private boolean isValidColorPosition(int row, int col, int color) {
    return row >= LEFT_BOTTOM_BORDER
            && col >= LEFT_BOTTOM_BORDER
            && row <= TOP_BORDER
            && col <= RIGHT_BORDER
            && state.get(row, col) == color;
}
Great, we also removed the magic numbers.
Now we use this method in the extracted getLength() method:
private int getLength(IntegerPair position, IntegerPair directions, int currentColor, boolean[][] visited, int count) {
    int row = position.first() + count * directions.first();
    int col = position.second() + count * directions.second();
    while (isValidColorPosition(row, col, currentColor)) {
        visited[row][col] = true;
        count += 1;
        row += directions.first();
        col += directions.second();
    }
    return count;
}
which is called by the former check() method, now renamed to getLengthPair():
private final int EMPTY_COLOR = 0;
private IntegerPair getLengthPair(IntegerPair last, IntegerPair direction, int playerColor, boolean[][] visited) {
    IntegerPair position = new IntegerPair(last.first() + direction.first(),
                                           last.second() + direction.second());
    int playerLength = getLength(position, direction, playerColor, visited, 0);
    int possibleLength = getLength(position, direction, EMPTY_COLOR, visited, playerLength);
    return new IntegerPair(playerLength, possibleLength);
}
And again one magic number is gone.
This method will be called by the new getStrength() method:
private int getStrength(IntegerPair last, int color, boolean[][] visited, int[][] directions) {
    int strength = 0;
    for (int i = 0; i < 4; ++i) {
        IntegerPair direction = new IntegerPair(directions[0][i], directions[1][i]);
        IntegerPair a = getLengthPair(last, direction, color, visited);
        direction = new IntegerPair(-directions[0][i], -directions[1][i]);
        IntegerPair b = getLengthPair(last, direction, color, visited);
        if (a.second() + b.second() >= 3) {
            strength += color * cS[a.first() + b.first()];
        }
    }
    return strength;
}
which is called by the connectedStrength() method:
public int connectedStrength() {
    int total = 0;
    int[][] s = {{1, 1, 0, -1}, {0, 1, 1, 1}};
    boolean[][] visited = new boolean[6][7];
    for (int row = 0; row < 6; ++row) {
        for (int col = 0; col < 7; ++col) {
            if (visited[row][col]) {
                continue;
            }
            int color = state.get(row, col);
            if (color != 0) {
                total += getStrength(new IntegerPair(row, col), color, visited, s);
            }
            visited[row][col] = true;
        }
    }
    return total;
}
Comments should be used only to describe why something is done in the way it is done. The code itself should describe what is done by using meaningful names for classes, methods and variables.
By using guard clauses you can avoid the deep nesting known as the Arrow Anti Pattern.
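As a small illustration of the guard-clause point, the sketch below writes the same loop twice: once with nested conditions and once with early continue statements. The class GuardClauseDemo and both method names are made up for this example, and the loop simply sums the unvisited, non-empty cells; it is not part of the evaluation function.

```java
final class GuardClauseDemo {
    // Nested version: every extra condition pushes the real work one level deeper.
    static int sumUnvisitedNested(int[][] board, boolean[][] visited) {
        int total = 0;
        for (int row = 0; row < board.length; ++row) {
            for (int col = 0; col < board[row].length; ++col) {
                if (!visited[row][col]) {
                    if (board[row][col] != 0) {
                        total += board[row][col];
                    }
                }
            }
        }
        return total;
    }

    // Guard-clause version: uninteresting cells are rejected first,
    // so the real work stays at a single indentation level.
    static int sumUnvisitedGuarded(int[][] board, boolean[][] visited) {
        int total = 0;
        for (int row = 0; row < board.length; ++row) {
            for (int col = 0; col < board[row].length; ++col) {
                if (visited[row][col]) {
                    continue;
                }
                if (board[row][col] == 0) {
                    continue;
                }
                total += board[row][col];
            }
        }
        return total;
    }
}
```

Both methods return the same value for any input; only the shape of the control flow differs.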
I think this line:
&& state.get(last.first() + len * d1, last.second() + len * d2) == 0) { // while the tiles are not the enemy's
doesn't do what the comment says it should do. Assuming that a blank tile is 0, one player is 1, and the other player is -1, I think that line should be changed to this:
&& state.get(last.first() + len * d1, last.second() + len * d2) != -player) { // while the tiles are not the enemy's
If you do that, then it will find chains in the form of O--O, where O is a checker and - is a blank tile. Right now, O--O will be considered a chain of length 3 instead of length 4.
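The difference between the two conditions can be seen on a single line of cells. The following is only a sketch under the same assumption as above (0 is blank, 1 and -1 are the players); the class ReachComparison and its reach() method are hypothetical helpers that walk in one direction along a one-dimensional array rather than through the real board.

```java
final class ReachComparison {
    /**
     * Counts how many cells beyond `start` the chain could still occupy,
     * walking towards higher indices. The starting cell is not counted.
     */
    static int reach(int[] line, int start, boolean ownTilesCountAsOpen) {
        int player = line[start];
        int len = 1;
        // Phase 1: directly adjacent tiles of the same color.
        while (start + len < line.length && line[start + len] == player) {
            len++;
        }
        // Phase 2: cells the chain could still grow into.
        while (start + len < line.length
                && (ownTilesCountAsOpen
                        ? line[start + len] != -player // suggested condition: stop only at the enemy
                        : line[start + len] == 0)) {   // original condition: stop at any non-blank cell
            len++;
        }
        return len - 1;
    }

    public static void main(String[] args) {
        int[] line = {1, 0, 0, 1, -1, 0, 0}; // O--O followed by an enemy tile
        System.out.println(reach(line, 0, false)); // original condition
        System.out.println(reach(line, 0, true));  // suggested condition
    }
}
```

With the original condition the walk from the left O stops in front of the second O and reports a reach of 2 (a span of 3 cells including the start). With the suggested condition it continues over the second O and reports 3, which is the span of 4 described above.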
Readability
Your check method is pretty hard to read, for a couple of reasons:
- bad naming: check what? why? how? what's d1 and d2? last what?
- no JavaDoc method comment describing what the method does, what arguments it accepts and what it returns
- hardcoded magic numbers: what do these represent? why 5 and 6?
- two nearly identical loops: it's a lot of work to find the one small difference between them. I would extract the loop to a method which accepts that one difference as argument.
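One way to extract the loop, sketched here with a plain int[][] board in place of the state class: the only difference between the two loops is the test applied to each cell, so that test can be passed in as an IntPredicate. The names RunCounter, extend() and lengths() are invented for this sketch, and the visited bookkeeping of the original is deliberately left out.

```java
import java.util.function.IntPredicate;

final class RunCounter {
    static final int ROWS = 6, COLS = 7;

    /**
     * Starting `steps` cells away from (row, col) in direction (dRow, dCol),
     * keeps stepping while the cell is on the board and accepted by the predicate.
     * Returns the updated step count.
     */
    static int extend(int[][] board, int row, int col, int dRow, int dCol,
                      int steps, IntPredicate accepts) {
        int r = row + steps * dRow;
        int c = col + steps * dCol;
        while (r >= 0 && r < ROWS && c >= 0 && c < COLS && accepts.test(board[r][c])) {
            steps++;
            r += dRow;
            c += dCol;
        }
        return steps;
    }

    /** Returns {same, reach}, neither counting the starting tile. */
    static int[] lengths(int[][] board, int row, int col, int dRow, int dCol) {
        int player = board[row][col];
        // The two former loops now differ only in the predicate they pass.
        int same = extend(board, row, col, dRow, dCol, 1, cell -> cell == player);
        int reach = extend(board, row, col, dRow, dCol, same, cell -> cell == 0);
        return new int[] {same - 1, reach - 1};
    }
}
```

Because the second call starts where the first one stopped, no cell is examined twice, and changing the rule for "open" cells means changing a single lambda.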
These problems also occur in other parts of your code. None of the methods have comments, and bS, cS, s, a, b, etc. are all bad names. Even some of the longer names are not all that expressive. second and first what of what? last what?
I'm not sure connectedStrength() does what you think it should do. First, why set visited inside check()? If you visit a tile while checking for a horizontal chain, you won't check for a vertical chain starting on that tile because it has already been visited. Second, you don't consider "tile space tile space" as a possible 4 chain because check stops searching after the first series of spaces.
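The first point can be reproduced with a small stand-alone replica of the posted logic. This is a sketch, not the original class: it works on a bare int[][] board, returns int[] pairs instead of IntegerPair, and the names VisitedDemo, CHAIN_SCORE and DIRS are invented here. Like the original, it assumes the board contains no completed four, since evaluate() returns before reaching this code in that case.

```java
final class VisitedDemo {
    static final int ROWS = 6, COLS = 7;
    static final int[] CHAIN_SCORE = {0, 10, 100};                   // same values as cS
    static final int[][] DIRS = {{1, 0}, {1, 1}, {0, 1}, {-1, 1}};   // same directions as s

    static boolean inBounds(int r, int c) {
        return r >= 0 && r < ROWS && c >= 0 && c < COLS;
    }

    // Mirrors check(): returns {same, reach} and marks every cell it walks over.
    static int[] check(int[][] b, int row, int col, int dr, int dc, boolean[][] visited) {
        int player = b[row][col];
        int len = 1;
        while (inBounds(row + len * dr, col + len * dc)
                && b[row + len * dr][col + len * dc] == player) {
            visited[row + len * dr][col + len * dc] = true;
            len++;
        }
        int same = len;
        while (inBounds(row + len * dr, col + len * dc)
                && b[row + len * dr][col + len * dc] == 0) {
            visited[row + len * dr][col + len * dc] = true;
            len++;
        }
        return new int[] {same - 1, len - 1};
    }

    // Mirrors connectedStrength() with one visited array shared by all directions.
    static int connectedStrength(int[][] b) {
        int total = 0;
        boolean[][] visited = new boolean[ROWS][COLS];
        for (int row = 0; row < ROWS; ++row) {
            for (int col = 0; col < COLS; ++col) {
                if (visited[row][col]) {
                    continue;
                }
                int color = b[row][col];
                if (color != 0) {
                    for (int[] d : DIRS) {
                        int[] a = check(b, row, col, d[0], d[1], visited);
                        int[] o = check(b, row, col, -d[0], -d[1], visited);
                        if (a[1] + o[1] >= 3) {
                            total += color * CHAIN_SCORE[a[0] + o[0]];
                        }
                    }
                }
                visited[row][col] = true;
            }
        }
        return total;
    }

    public static void main(String[] args) {
        int[][] board = new int[ROWS][COLS];
        board[0][0] = 1;
        board[0][1] = 1;
        board[1][1] = 1;
        System.out.println(connectedStrength(board)); // prints 20
    }
}
```

The three tiles form three pairs: (0,0)-(0,1), (0,0)-(1,1) and (0,1)-(1,1), each with enough empty cells to grow to four. The first two are scored from (0,0), giving 20, but both (0,1) and (1,1) are marked visited during those checks, so the pair between them is never scored.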