Monte Carlo Tree Search

CSI 4106 - Fall 2024

Marcel Turcotte

Version: Nov 28, 2024 14:35

Preamble

Learning objectives

  • Explain the concept and key steps of Monte Carlo Tree Search (MCTS).
  • Compare MCTS with other search algorithms like BFS, DFS, A*, Simulated Annealing, and Genetic Algorithms.
  • Analyze how MCTS balances exploration and exploitation using the UCB1 formula.
  • Implement MCTS in practical applications such as Tic-Tac-Toe.

Introduction

Monte Carlo Tree Search (MCTS)

In the introductory lecture on state space search, I used Monte Carlo Tree Search (MCTS), a key component of AlphaGo, to exemplify the role of search algorithms in reasoning.

Today, we conclude this series by examining the implementation details of this algorithm.

Applications

  • De novo drug design
  • Electronic circuit routing
  • Load monitoring in smart grids
  • Lane keeping and overtaking tasks
  • Motion planning in autonomous driving
  • Even solving the travelling salesman problem

Historical Notes

Definition

A Monte Carlo algorithm is a computational method that uses random sampling to obtain numerical results, often used for optimization, numerical integration, and probability distribution estimation.

It is characterized by its ability to handle complex problems by producing probabilistic, approximate answers, trading exactness for efficiency and scalability.
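
To make the definition concrete, here is a minimal sketch (an illustration, not taken from the lecture) that estimates \(\pi\) by uniform random sampling in the unit square:

Code
import random

def estimate_pi(num_samples=100_000):
    """Estimate pi from the fraction of random points inside the quarter circle."""
    inside = 0
    for _ in range(num_samples):
        x, y = random.random(), random.random()
        if x * x + y * y <= 1.0:  # the point falls within the quarter circle
            inside += 1
    return 4 * inside / num_samples

print(estimate_pi())  # ~3.14, up to sampling error

The estimate trades exactness for scalability: its standard error shrinks in proportion to \(1/\sqrt{n}\) for \(n\) samples.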

Algorithm

  1. Selection (tree traversal)
  2. Node expansion
  3. Rollout (simulation)
  4. Back-propagation

Discussion

Like other algorithms previously discussed, such as BFS, DFS, and \(A^\star\), Monte Carlo Tree Search (MCTS) maintains a frontier of unexpanded nodes.

Discussion

Similar to \(A^\star\), Monte Carlo Tree Search (MCTS) employs a heuristic, referred to as a policy, to determine the optimal node for expansion.

However, in \(A^\star\), the heuristic is typically a static function estimating the cost to reach a goal, whereas in MCTS the “policy” is evaluated dynamically, using statistics gathered during the search.

Discussion

Similar to Simulated Annealing and Genetic Algorithms, Monte Carlo Tree Search (MCTS) incorporates a mechanism to balance exploration and exploitation.

Discussion

  • MCTS leverages all visited nodes in its decision-making process, unlike \(A^\star\), which primarily focuses on the current frontier.
  • Additionally, MCTS iteratively updates the value of its nodes based on simulations, whereas \(A^\star\) typically uses a static heuristic.

Discussion

In contrast to previous algorithms with implicit search trees, MCTS constructs an explicit tree structure during execution.

Walk-through

Russell and Norvig

Summary (1)

Initially, the tree consists of a single node, the initial state \(S_0\).

We add its descendants, and we are ready to start.

Monte Carlo Tree Search builds its search tree incrementally, growing it a little with each iteration.

Summary (2)

With each iteration, the following steps occur:

  1. Selection: Identify the most promising node by traversing a single path down the tree, guided by UCB1.

  2. Expansion: If the selected node is a leaf of the MCTS tree that has already been visited (\(n \gt 0\)), add its children and move to one of them; otherwise, proceed directly to the rollout.

  3. Rollout: Simulate a game from the current state to a terminal state by randomly selecting actions.

  4. Backpropagation: Use the outcome to update the statistics of the current node and of all its ancestors up to the root.

Summary (3)

Each node records its total score and visit count.

This information is used to calculate a value that guides tree traversal, balancing exploration and exploitation.

Summary (4)

\[ \mathrm{UCB1}(S_i) = \overline{V_i} + C \sqrt{\frac{\ln(N)}{n_i}} \]

where \(\overline{V_i}\) is the average value of node \(S_i\), \(n_i\) is its visit count, and \(N\) is the visit count of its parent.

The usual value for \(C\) is \(\sqrt{2}\).

Exploration essentially occurs when two nodes have approximately the same average score: MCTS then favours the node with fewer visits, since \(n_i\) appears in the denominator.

For \(n_i \lt \ln(N)\), the ratio \(\ln(N)/n_i\) is greater than 1, whereas for \(n_i \gt \ln(N)\), it becomes less than 1.

So exploration kicks in only a small fraction of the time, and even then its contribution is quite tame: we take the square root of the ratio and multiply by \(C = \sqrt{2} \approx 1.414\).
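
As a numerical illustration (the values below are chosen for this example, not taken from the slides), consider a parent visited \(N = 20\) times with two children of equal average score:

Code
import math

C = math.sqrt(2)

def ucb1(avg_value, n_child, n_parent):
    """UCB1 value of a child node, as defined above."""
    return avg_value + C * math.sqrt(math.log(n_parent) / n_child)

# Two children with the same average score but different visit counts
print(ucb1(0.5, 15, 20))  # frequently visited child: ~1.13
print(ucb1(0.5, 5, 20))   # rarely visited child: ~1.59, selected next

With equal averages, the \(\sqrt{\ln(N)/n_i}\) term breaks the tie in favour of the less-visited child, which is precisely the exploration behaviour described above.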

Summary (5)

Note how slowly \(\ln(N)\) grows with the number of visits:

[4.60517019 6.90775528 9.21034037]
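
These values are consistent with the natural logarithm at \(N = 100\), \(1000\), and \(10000\); a cell along these lines (reconstructed, not the original) reproduces them:

Code
import numpy as np

print(np.log([100, 1000, 10000]))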

Summary (5)

Code
import numpy as np
import matplotlib.pyplot as plt

num_iterations = 10

# Define the range for n and N
n_values = np.arange(1, num_iterations + 1)
N_values = np.arange(1, num_iterations + 1)

# Prepare a meshgrid for n and N
N, n = np.meshgrid(N_values, n_values)

# Compute the expression for each pair (n, N)
Z = np.sqrt(2) * np.sqrt(np.log(N) / n)

# Plotting
plt.figure(figsize=(8, 6))
plt.contourf(N, n, Z, cmap='viridis')
plt.colorbar(label=r'$\sqrt{2} \times \sqrt{\frac{\log{N}}{n}}$')
plt.xlabel('N')
plt.ylabel('n')
plt.title(r'Visualization of $\sqrt{2} \times \sqrt{\frac{\log{N}}{n}}$ for $n, N = 1..\mathrm{num\_iterations}$')
plt.show()

Summary (5)

Code
import numpy as np
import matplotlib.pyplot as plt

num_iterations = 100

# Define the range for n and N
n_values = np.arange(1, num_iterations + 1)
N_values = np.arange(1, num_iterations + 1)

# Prepare a meshgrid for n and N
N, n = np.meshgrid(N_values, n_values)

# Compute the expression for each pair (n, N)
Z = np.sqrt(2) * np.sqrt(np.log(N) / n)

# Plotting
plt.figure(figsize=(8, 6))
plt.contourf(N, n, Z, cmap='viridis')
plt.colorbar(label=r'$\sqrt{2} \times \sqrt{\frac{\log{N}}{n}}$')
plt.xlabel('N')
plt.ylabel('n')
plt.title(r'Visualization of $\sqrt{2} \times \sqrt{\frac{\log{N}}{n}}$ for $n, N = 1..\mathrm{num\_iterations}$')
plt.show()

Tic-tac-toe

import numpy as np

# Game class for Tic-Tac-Toe
class TicTacToe:
    def __init__(self):
        # Initialize the 3x3 board with empty spaces
        self.size = 3
        self.board = np.full((self.size, self.size), ' ')

    def get_valid_moves(self, state):
        """
        Return a list of available positions on the board.
        """
        moves = [(i, j) for i in range(self.size) for j in range(self.size) if state[i][j] == ' ']
        return moves

    def make_move(self, state, move, player):
        """
        Place the player's symbol on the board at the specified move.
        Return the new state.
        """
        new_state = state.copy()
        new_state[move[0], move[1]] = player
        return new_state

    def get_opponent(self, player):
        """
        Return the opponent of the given player.
        """
        return 'O' if player == 'X' else 'X'

    def is_terminal(self, state):
        """
        Check if the game has ended, either by win or draw.
        """
        # NumPy evaluates `' ' not in state` element-wise: it is True
        # when no empty squares remain on the board.
        return self.evaluate(state) != 0 or ' ' not in state

    def evaluate(self, state):
        """
        Evaluate the board state.
        Return:
            1 if 'X' wins,
            -1 if 'O' wins,
            0 otherwise (draw or ongoing game).
        """
        lines = []

        # Rows and columns
        for i in range(self.size):
            lines.append(state[i, :])      # Row i
            lines.append(state[:, i])      # Column i

        # Diagonals
        lines.append(np.diag(state))
        lines.append(np.diag(np.fliplr(state)))

        # Check for a winner
        for line in lines:
            if np.all(line == 'X'):
                return 1
            elif np.all(line == 'O'):
                return -1

        # No winner
        return 0

    def display(self, state):
        """
        Print the board to the console.
        """
        print("\nCurrent board:")
        for i in range(self.size):
            row = '|'.join(state[i])
            print(row)
            if i < self.size - 1:
                print('-' * (self.size * 2 - 1))
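
As a quick sanity check, here is a minimal, hypothetical usage of the class above (the moves are arbitrary):

Code
game = TicTacToe()
state = game.board.copy()
state = game.make_move(state, (0, 0), 'X')
state = game.make_move(state, (1, 1), 'O')
state = game.make_move(state, (0, 1), 'X')
game.display(state)
print(game.evaluate(state))     # 0: no winner yet
print(game.is_terminal(state))  # False: empty squares remain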

Node

import math

import numpy as np

# Node class for MCTS
class Node:
    def __init__(self, state, parent=None, move=None, player='X'):
        self.state = state            # Game state at this node
        self.parent = parent          # Parent node
        self.children = []            # List of child nodes
        self.visits = 0               # Number of times node has been visited
        self.wins = 0                 # Accumulated reward for the player who moved into this node
        self.move = move              # Move that led to this node
        self.player = player          # Player to move in this state

    def is_fully_expanded(self, game):

        """
        Check if all possible moves from this node have been explored.
        """

        return len(self.children) == len(game.get_valid_moves(self.state))

    def best_child(self, c_param=math.sqrt(2)):

        """
        Select the child node with the highest UCB1 value.
        """

        choices_weights = []
        for child in self.children:
            if child.visits == 0:
                ucb1 = float('inf')
            else:
                win_rate = child.wins / child.visits
                exploration = c_param * math.sqrt(math.log(self.visits) / child.visits)  # C * sqrt(ln(N) / n_i), as in the UCB1 formula
                ucb1 = win_rate + exploration
            choices_weights.append(ucb1)

        return self.children[np.argmax(choices_weights)]

    def most_visited_child(self):

        """
        Select the child node with the highest visit count.
        """

        visits = [child.visits for child in self.children]
        return self.children[np.argmax(visits)]

mcts

import random

# MCTS Algorithm
def mcts(game, root, iterations):

    for _ in range(iterations):
        node = tree_policy(game, root)
        reward = default_policy(game, node.state, node.player)
        backup(node, reward)

    return root.most_visited_child()

def tree_policy(game, node):

    """
    Selection and Expansion phases.
    """

    while not game.is_terminal(node.state):
        if not node.is_fully_expanded(game):
            return expand(game, node)
        else:
            node = node.best_child()

    return node

def expand(game, node):

    """
    Expand a child node from the current node.
    """

    tried_moves = [child.move for child in node.children]
    possible_moves = game.get_valid_moves(node.state)

    for move in possible_moves:
        if move not in tried_moves:
            next_state = game.make_move(node.state, move, node.player)
            child_node = Node(state=next_state, parent=node, move=move, player=game.get_opponent(node.player))
            node.children.append(child_node)
            return child_node

def default_policy(game, state, player):

    """
    Simulation phase: play out the game randomly from the given state.
    """

    current_state = state.copy()
    current_player = player

    while not game.is_terminal(current_state):
        possible_moves = game.get_valid_moves(current_state)
        move = random.choice(possible_moves)
        current_state = game.make_move(current_state, move, current_player)
        current_player = game.get_opponent(current_player)

    result = game.evaluate(current_state)

    return result

def backup(node, reward):

    """
    Backpropagation phase: update the node statistics.
    """

    while node is not None:
        node.visits += 1
        # node.player is the player to move at this node, so the move into this
        # node was made by the opponent. Credit a win when the reward favours
        # that opponent, i.e., the player who just played.
        if (node.player == 'X' and reward == -1) or (node.player == 'O' and reward == 1):
            node.wins += 1
        elif reward == 0:
            node.wins += 0.5  # Consider draw as half win
        node = node.parent

test_tic_tac_toe_mcts

# Main function to demonstrate the application
def test_tic_tac_toe_mcts():

    game = TicTacToe()
    current_state = game.board.copy()
    current_player = 'X'
    root_node = Node(state=current_state, player=current_player)

    while not game.is_terminal(current_state):

        game.display(current_state)
        print(f"Player {current_player}'s turn.")

        if current_player == 'X':
            # AI's turn using MCTS
            iterations = 1000  # Adjust as needed
            best_child = mcts(game, root_node, iterations)
            current_state = best_child.state
            root_node = best_child
        else:
            # Human player's turn
            possible_moves = game.get_valid_moves(current_state)
            print("Possible moves:", possible_moves)
            move = None
            while move not in possible_moves:
                try:
                    move_input = input("Enter your move as 'row,col': ")
                    move = tuple(int(x.strip()) for x in move_input.split(','))
                except ValueError:  # int() fails on non-numeric input
                    print("Invalid input. Please enter row and column numbers separated by a comma.")
            current_state = game.make_move(current_state, move, current_player)
            # Update the tree: find or create the child node corresponding to the move
            matching_child = None
            for child in root_node.children:
                if child.move == move:
                    matching_child = child
                    break
            if matching_child:
                root_node = matching_child
            else:
                root_node = Node(state=current_state, parent=None, move=move, player=game.get_opponent(current_player))
        # Switch player
        current_player = game.get_opponent(current_player)

    # Game over
    game.display(current_state)
    result = game.evaluate(current_state)
    if result == 1:
        print("X wins!")
    elif result == -1:
        print("O wins!")
    else:
        print("It's a draw!")

if __name__ == "__main__":
    test_tic_tac_toe_mcts()

Exploration

  • Implement code to visualize the search tree, either as text or using Graphviz.

  • Incorporate heuristics to detect when a winning move is achievable in a single step (a starting sketch follows this list).

  • Experiment with varying the number of iterations and the constant \(C\).
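
For the second exercise, a possible starting point (a sketch, not a prescribed solution) is a helper that scans for an immediate winning move using the TicTacToe class above:

Code
def winning_move(game, state, player):
    """Return a move that wins immediately for `player`, or None."""
    target = 1 if player == 'X' else -1
    for move in game.get_valid_moves(state):
        if game.evaluate(game.make_move(state, move, player)) == target:
            return move
    return None

Such a check could bias the random rollouts in default_policy, or be applied at the root before calling mcts.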

Prologue

Summary

  • Monte Carlo Tree Search (MCTS) is a search algorithm used for decision-making in complex games.
  • MCTS operates through four main steps: Selection, Expansion, Rollout (Simulation), and Backpropagation.
  • It balances exploration and exploitation using the UCB1 formula, which guides node selection based on visit counts and scores.
  • MCTS maintains an explicit search tree, updating node values iteratively based on simulations.
  • The algorithm has wide-ranging applications, including AI gaming, drug design, circuit routing, and autonomous driving.
  • Popularized as a general framework for game AI in the late 2000s (Chaslot et al. 2008), MCTS gained worldwide prominence with its use in AlphaGo in 2016 (Silver et al. 2016).
  • Unlike traditional algorithms like \(A^\star\), MCTS uses dynamic policies and leverages all visited nodes for decision-making.
  • Implementing MCTS involves tracking node statistics and applying the UCB1 formula to guide search.
  • A practical example of MCTS is demonstrated through implementing Tic-Tac-Toe.
  • Further exploration includes combining MCTS with deep learning, as in AlphaZero and MuZero.

Further exploration

The End

  • Consult the course website for information on the final examination.

References

Chaslot, Guillaume, Sander Bakkes, Istvan Szita, and Pieter Spronck. 2008. “Monte-Carlo Tree Search: A New Framework for Game AI.” In Proceedings of the Fourth AAAI Conference on Artificial Intelligence and Interactive Digital Entertainment, 216–17. AIIDE’08. Stanford, California: AAAI Press.
Kemmerling, Marco, Daniel Lütticke, and Robert H. Schmitt. 2024. “Beyond Games: A Systematic Review of Neural Monte Carlo Tree Search Applications.” Applied Intelligence 54 (1): 1020–46. https://doi.org/10.1007/s10489-023-05240-w.
Russell, Stuart, and Peter Norvig. 2020. Artificial Intelligence: A Modern Approach. 4th ed. Pearson. http://aima.cs.berkeley.edu/.
Silver, David, Aja Huang, Chris J. Maddison, Arthur Guez, Laurent Sifre, George van den Driessche, Julian Schrittwieser, et al. 2016. “Mastering the Game of Go with Deep Neural Networks and Tree Search.” Nature 529 (7587): 484–89. https://doi.org/10.1038/nature16961.

Marcel Turcotte

Marcel.Turcotte@uOttawa.ca

School of Electrical Engineering and Computer Science (EECS)

University of Ottawa