CSI 4106 - Fall 2024
Version: Nov 28, 2024 14:35
In the introductory lecture on state space search, I used Monte Carlo Tree Search (MCTS), a key component of AlphaGo, to exemplify the role of search algorithms in reasoning.
Today, we conclude this series by examining the implementation details of this algorithm.
A Monte Carlo algorithm is a computational method that uses random sampling to obtain numerical results, often used for optimization, numerical integration, and probability distribution estimation.
It is characterized by its ability to handle complex problems with probabilistic solutions, trading exactness for efficiency and scalability.
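For instance, one can estimate \(\pi\) by random sampling; the short function below is a minimal sketch of this idea, not part of MCTS itself:

import random

# Minimal sketch: estimate pi by sampling points in the unit square
# and counting the fraction that falls inside the quarter circle.
def estimate_pi(samples=100_000):
    inside = sum(1 for _ in range(samples)
                 if random.random() ** 2 + random.random() ** 2 <= 1)
    return 4 * inside / samples

print(estimate_pi())  # about 3.14; accuracy improves with more samples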
Like other algorithms previously discussed, such as BFS, DFS, and \(A^\star\), Monte Carlo Tree Search (MCTS) maintains a frontier of unexpanded nodes.
Similar to \(A^\star\), Monte Carlo Tree Search (MCTS) employs a heuristic, referred to as a policy, to determine the optimal node for expansion.
However, in \(A^\star\), the heuristic is typically a static function estimating cost to a goal, whereas in MCTS, the “policy” involves dynamic evaluation.
Similar to Simulated Annealing and Genetic Algorithms, Monte Carlo Tree Search (MCTS) incorporates a mechanism to balance exploration and exploitation.
In contrast to previous algorithms with implicit search trees, MCTS constructs an explicit tree structure during execution.
Initially, the tree consists of a single node, the root \(S_0\).
We add its children, and we are ready to start.
Monte Carlo Tree Search builds its search tree incrementally.
With each iteration, the following steps occur:
Selection: Identify the optimal node by traversing a single path in the tree, guided by UCB1.
Expansion: Expand the node if it is a leaf of the MCTS tree that has already been visited (\(n \gt 0\)).
Rollout: Simulate a game from the current state to a terminal state by randomly selecting actions.
Backpropagation: Use the obtained information to update the current node and all parent nodes up to the root.
Each node records its total score and visit count.
This information is used to calculate a value that guides tree traversal, balancing exploration and exploitation.
\[ \mathrm{UCB1}(S_i) = \overline{V_i} + C \sqrt{\frac{\ln(N)}{n_i}} \]
The usual value for \(C\) is \(\sqrt{2}\).
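To make the formula concrete, here is a small numeric check; the ucb1 helper and its sample scores and visit counts are invented for illustration:

import math

def ucb1(total_score, visits, parent_visits, c=math.sqrt(2)):
    # UCB1 = average value + C * sqrt(ln(N) / n_i)
    return total_score / visits + c * math.sqrt(math.log(parent_visits) / visits)

# Two children with the same average score (0.5) but different visit counts:
print(ucb1(5, 10, 40))  # about 1.36
print(ucb1(1, 2, 40))   # about 2.42; the less-visited node is preferred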
Exploration kicks in when two nodes have approximately the same average score: MCTS then favours the node with fewer visits, since the exploration term divides by \(n_i\).
For \(n_i \lt \ln(N)\), the value of the ratio is greater than 1, whereas for \(n_i \gt \ln(N)\), the ratio becomes less than 1.
Since \(\ln(N)\) grows slowly, exploration dominates only a small fraction of the time. Even then, the contribution of the term is quite tame: we take the square root of the ratio and multiply it by \(C = \sqrt{2} \approx 1.414\).
As a concrete illustration, \(\ln(100) \approx 4.61\), \(\ln(1000) \approx 6.91\), and \(\ln(10000) \approx 9.21\): the logarithm grows very slowly.
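These values can be reproduced in one line with NumPy:

import numpy as np

print(np.log([100, 1000, 10000]))  # [4.60517019 6.90775528 9.21034037]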
import numpy as np
import matplotlib.pyplot as plt
num_iterations = 10
# Define the range for n and N
n_values = np.arange(1, num_iterations + 1)
N_values = np.arange(1, num_iterations + 1)
# Prepare a meshgrid for n and N
N, n = np.meshgrid(N_values, n_values)
# Compute the expression for each pair (n, N)
Z = np.sqrt(2) * np.sqrt(np.log(N) / n)
# Plotting
plt.figure(figsize=(8, 6))
plt.contourf(N, n, Z, cmap='viridis')
plt.colorbar(label=r'$\sqrt{2} \times \sqrt{\frac{\log{N}}{n}}$')
plt.xlabel('N')
plt.ylabel('n')
plt.title(r'Visualization of $\sqrt{2} \times \sqrt{\frac{\log{N}}{n}}$ for $n, N = 1..\mathrm{num\_iterations}$')
plt.show()
import numpy as np
import matplotlib.pyplot as plt
num_iterations = 100
# Define the range for n and N
n_values = np.arange(1, num_iterations + 1)
N_values = np.arange(1, num_iterations + 1)
# Prepare a meshgrid for n and N
N, n = np.meshgrid(N_values, n_values)
# Compute the expression for each pair (n, N)
Z = np.sqrt(2) * np.sqrt(np.log(N) / n)
# Plotting
plt.figure(figsize=(8, 6))
plt.contourf(N, n, Z, cmap='viridis')
plt.colorbar(label=r'$\sqrt{2} \times \sqrt{\frac{\log{N}}{n}}$')
plt.xlabel('N')
plt.ylabel('n')
plt.title(r'Visualization of $\sqrt{2} \times \sqrt{\frac{\log{N}}{n}}$ for $n, N = 1..\mathrm{num\_iterations}$')
plt.show()
import math
import random

# Game class for Tic-Tac-Toe
class TicTacToe:
def __init__(self):
# Initialize the 3x3 board with empty spaces
self.size = 3
self.board = np.full((self.size, self.size), ' ')
def get_valid_moves(self, state):
"""
Return a list of available positions on the board.
"""
moves = [(i, j) for i in range(self.size) for j in range(self.size) if state[i][j] == ' ']
return moves
def make_move(self, state, move, player):
"""
Place the player's symbol on the board at the specified move.
Return the new state.
"""
new_state = state.copy()
new_state[move[0], move[1]] = player
return new_state
def get_opponent(self, player):
"""
Return the opponent of the given player.
"""
return 'O' if player == 'X' else 'X'
def is_terminal(self, state):
"""
Check if the game has ended, either by win or draw.
"""
return self.evaluate(state) != 0 or ' ' not in state
def evaluate(self, state):
"""
Evaluate the board state.
Return:
1 if 'X' wins,
-1 if 'O' wins,
0 otherwise (draw or ongoing game).
"""
lines = []
# Rows and columns
for i in range(self.size):
lines.append(state[i, :]) # Row i
lines.append(state[:, i]) # Column i
# Diagonals
lines.append(np.diag(state))
lines.append(np.diag(np.fliplr(state)))
# Check for a winner
for line in lines:
if np.all(line == 'X'):
return 1
elif np.all(line == 'O'):
return -1
# No winner
return 0
def display(self, state):
"""
Print the board to the console.
"""
print("\nCurrent board:")
for i in range(self.size):
row = '|'.join(state[i])
print(row)
if i < self.size - 1:
print('-' * (self.size * 2 - 1))
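A quick sanity check of the class; the moves below are arbitrary, chosen only for illustration:

game = TicTacToe()
state = game.make_move(game.board, (1, 1), 'X')  # X takes the centre
state = game.make_move(state, (0, 0), 'O')       # O takes a corner
game.display(state)
print(game.get_valid_moves(state))  # seven positions remain
print(game.is_terminal(state))      # False: no winner, board not full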
Node
# Node class for MCTS
class Node:
def __init__(self, state, parent=None, move=None, player='X'):
self.state = state # Game state at this node
self.parent = parent # Parent node
self.children = [] # List of child nodes
self.visits = 0 # Number of times node has been visited
self.wins = 0 # Number of wins from this node
self.move = move # Move that led to this node
        self.player = player          # Player whose turn it is to move at this node
def is_fully_expanded(self, game):
"""
Check if all possible moves from this node have been explored.
"""
return len(self.children) == len(game.get_valid_moves(self.state))
def best_child(self, c_param=math.sqrt(2)):
"""
Select the child node with the highest UCB1 value.
"""
choices_weights = []
for child in self.children:
if child.visits == 0:
ucb1 = float('inf')
else:
                win_rate = child.wins / child.visits
                # UCB1 as defined above: average value plus C * sqrt(ln(N) / n_i)
                exploration = c_param * math.sqrt(math.log(self.visits) / child.visits)
                ucb1 = win_rate + exploration
choices_weights.append(ucb1)
return self.children[np.argmax(choices_weights)]
def most_visited_child(self):
"""
Select the child node with the highest visit count.
"""
visits = [child.visits for child in self.children]
return self.children[np.argmax(visits)]
mcts
# MCTS Algorithm
def mcts(game, root, iterations):
for _ in range(iterations):
node = tree_policy(game, root)
reward = default_policy(game, node.state, node.player)
backup(node, reward)
return root.most_visited_child()
def tree_policy(game, node):
"""
Selection and Expansion phases.
"""
while not game.is_terminal(node.state):
if not node.is_fully_expanded(game):
return expand(game, node)
else:
node = node.best_child()
return node
def expand(game, node):
"""
Expand a child node from the current node.
"""
tried_moves = [child.move for child in node.children]
possible_moves = game.get_valid_moves(node.state)
for move in possible_moves:
if move not in tried_moves:
next_state = game.make_move(node.state, move, node.player)
child_node = Node(state=next_state, parent=node, move=move, player=game.get_opponent(node.player))
node.children.append(child_node)
return child_node
def default_policy(game, state, player):
"""
Simulation phase: play out the game randomly from the given state.
"""
current_state = state.copy()
current_player = player
while not game.is_terminal(current_state):
possible_moves = game.get_valid_moves(current_state)
move = random.choice(possible_moves)
current_state = game.make_move(current_state, move, current_player)
current_player = game.get_opponent(current_player)
result = game.evaluate(current_state)
return result
def backup(node, reward):
"""
Backpropagation phase: update the node statistics.
"""
while node is not None:
node.visits += 1
        # Update wins. node.player is the player to move at this node,
        # so the move leading here was made by the opponent; credit a
        # win to that player (the one who just played).
        if (node.player == 'X' and reward == -1) or (node.player == 'O' and reward == 1):
            node.wins += 1
        elif reward == 0:
            node.wins += 0.5  # count a draw as half a win
node = node.parent
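Before wiring everything together, a single random playout can be tested in isolation; the outcome varies from run to run, since moves are chosen at random:

# One random playout from the empty board
game = TicTacToe()
print(default_policy(game, game.board, 'X'))  # 1 (X wins), -1 (O wins), or 0 (draw)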
test_tic_tac_toe_mcts
# Main function to demonstrate the application
def test_tic_tac_toe_mcts():
game = TicTacToe()
current_state = game.board.copy()
current_player = 'X'
root_node = Node(state=current_state, player=current_player)
while not game.is_terminal(current_state):
game.display(current_state)
print(f"Player {current_player}'s turn.")
if current_player == 'X':
# AI's turn using MCTS
iterations = 1000 # Adjust as needed
best_child = mcts(game, root_node, iterations)
current_state = best_child.state
root_node = best_child
else:
# Human player's turn
possible_moves = game.get_valid_moves(current_state)
print("Possible moves:", possible_moves)
move = None
while move not in possible_moves:
try:
move_input = input("Enter your move as 'row,col': ")
move = tuple(int(x.strip()) for x in move_input.split(','))
                except ValueError:
print("Invalid input. Please enter row and column numbers separated by a comma.")
current_state = game.make_move(current_state, move, current_player)
# Update the tree: find or create the child node corresponding to the move
matching_child = None
for child in root_node.children:
if child.move == move:
matching_child = child
break
if matching_child:
root_node = matching_child
else:
root_node = Node(state=current_state, parent=None, move=move, player=game.get_opponent(current_player))
# Switch player
current_player = game.get_opponent(current_player)
# Game over
game.display(current_state)
result = game.evaluate(current_state)
if result == 1:
print("X wins!")
elif result == -1:
print("O wins!")
else:
print("It's a draw!")
if __name__ == "__main__":
test_tic_tac_toe_mcts()
Implement code to visualize the search tree, either as text or using Graphviz; a minimal text-based starting point is sketched after this list.
Incorporate heuristics to detect when a winning move is achievable in a single step.
Experiment with varying the number of iterations and the constant \(C\).
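For the first exercise, a minimal text-based visualization might look like the sketch below; print_tree and its parameters are illustrative additions, not part of the code above:

def print_tree(node, depth=0, max_depth=2):
    # Print each node's move and statistics, indented by depth,
    # most-visited children first.
    if depth > max_depth:
        return
    print('  ' * depth + f"move={node.move} wins={node.wins} visits={node.visits}")
    for child in sorted(node.children, key=lambda c: -c.visits):
        print_tree(child, depth + 1, max_depth)

# For example: print_tree(root_node) after calling mcts(game, root_node, 1000)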
Marcel Turcotte
School of Electrical Engineering and Computer Science (EECS)
University of Ottawa