
Vectorized AlphaZero


There have been countless articles written on AlphaZero and Monte Carlo Tree Search, so rather than explaining it again I'll devote my writing to details of the vectorized implementation.

Before reading, I'd suggest you read the original AlphaZero paper. I'd also recommend this primer on Monte Carlo Tree Search, and finally this article from Stanford walking through a toy implementation of AlphaZero.

These papers and articles focus on single-process AlphaZero, where state is maintained for a single instance of MCTS and a single environment. There have been numerous attempts to optimize and parallelize MCTS and AlphaZero, such as collecting batches of states within a single process prior to model inference, or distributing episode collection across many processes (this is what was done in the AlphaZero paper, with 5000 TPUs!). Both techniques are effective ways to scale up AlphaZero, but neither takes full advantage of the GPUs/TPUs already being used for model inference, since the MCTS algorithm and environment logic are still computed on the CPU.

Instead of process-based parallelism, this project uses vectorization to parallelize environment logic, MCTS, and model inference without data ever leaving the GPU. By encoding environment and MCTS state as multi-dimensional tensors and implementing all operations as tensor operations, every element of the AlphaZero algorithm is vectorized: instead of being constrained by CPU/GPU instances, processes, and memory, we are constrained only by GPU memory and throughput. This means the TurboZero implementation can collect episodes and train much, much faster than traditional implementations by vectorizing across thousands of environments and MCTS instances at once, performing each step of simulation, search, and inference across all instances simultaneously. This page will mostly focus on vectorized MCTS, as it is the main departure from traditional AlphaZero implementations.

Implementation Details

Rather than try to explain the vectorized implementation in broad strokes, it'll be easier to walk through it line-by-line, explaining as we go. If you'd like to follow along, the source code can be found here.

State

First, let's set up a way to maintain state. We'll do so with a few PyTorch tensors.

Every tensor maintaining state in this implementation follows the same pattern: a 0th dimension of size N, where N is the number of parallel environments.

Recall that Monte Carlo Tree Search maintains $Q(S)$, $N(S)$, $W(S)$, and $P(S)$ for each state $S$ in the search tree. Storing $Q(S)$ is unnecessary, since

$Q(s) = \frac{W(s)}{N(s)}$

so we can just maintain tensors for $N(S)$, $W(S)$, and $P(S)$:

# garbage node + root node + other nodes
self.total_slots = 2 + self.max_nodes 
self.p_vals = torch.zeros(
    (self.parallel_envs, self.total_slots, self.policy_size),
    dtype=torch.float32,
    device=self.device,
    requires_grad=False
)

self.n_vals = torch.zeros(
    (self.parallel_envs, self.total_slots, self.policy_size),
    dtype=torch.long,
    device=self.device,
    requires_grad=False
)

self.w_vals = torch.zeros(
    (self.parallel_envs, self.total_slots, self.policy_size),
    dtype=torch.float32,
    device=self.device,
    requires_grad=False
)

Each of these tensors has shape $(N, M, A)$, where $N$ is the number of parallel environments, $M$ is the total number of node slots (self.total_slots), and $A$ is the size of the (flattened) environment action space. You can just think of $A$ as the number of distinct discrete actions that could ever be taken in the environment. We store data for explored nodes in each environment along the 1st dimension, e.g.

self.p_vals[0,1,15] 

yields the p-value for action_id = 15, node_id = 1, environment_index = 0. You'll notice we allocate an extra node slot (1st-dim) at index 0. We use this as a garbage index and never store real data there; it enables a few tricks later on.
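To see why the garbage slot is useful, here's a minimal standalone sketch (toy shapes and hypothetical values, not the evaluator's actual tensors): an environment whose update should be skipped can simply have its write routed to node index 0, so every environment still executes the exact same indexing operation.

import torch

parallel_envs, total_slots, policy_size = 2, 4, 3
w_vals = torch.zeros((parallel_envs, total_slots, policy_size))
env_indices = torch.arange(parallel_envs)

# env 0 updates node 2; env 1's update is "masked out" by routing it to the garbage slot (node 0)
target_nodes = torch.tensor([2, 0])
w_vals[env_indices, target_nodes] += 1.0

print(w_vals[0, 2])   # tensor([1., 1., 1.]) -- the real node was updated
print(w_vals[1, 1:])  # all zeros -- only env 1's garbage slot was touched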

In order to keep track of the current node id for each instance of MCTS, we can use a one-dimensional tensor:

self.cur_nodes = torch.ones(
     (self.parallel_envs,), 
     dtype=torch.int64, 
     device=self.device, 
     requires_grad=False
) 

Next, we need a way to map an action taken at a given node to the corresponding child node id. We can do that with a similarly sized tensor:

self.next_idx = torch.zeros(
    (self.parallel_envs, self.total_slots, self.policy_size),
    dtype=torch.long,
    device=self.device,
    requires_grad=False
)

such that:

self.next_idx[3,2,8] 

gives the child node_id when action_id = 8 is taken at node_id=2, in environment_id=3.

We also need to maintain a few tensors for the purpose of backpropagation.

self.actions = torch.zeros(
    (self.parallel_envs, self.max_nodes), 
    dtype=torch.int64, 
    device=self.device, 
    requires_grad=False
)
self.visits = torch.zeros(
    (self.parallel_envs, self.max_nodes),
    dtype=torch.int64, 
    device=self.device, 
    requires_grad=False
)

self.actions will track the actions taken on the current path from the root, and self.visits will track the node ids visited. Maintaining these tensors lets us perform backpropagation across all environments in a single operation rather than doing so iteratively.
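To make this concrete, here is a hypothetical snapshot (toy values, standalone) of what these tensors might hold for a single environment after traversing from the root (node 1) via action 2 to node 4, then via action 5 to node 7:

import torch

parallel_envs, max_nodes = 1, 8
visits = torch.zeros((parallel_envs, max_nodes), dtype=torch.int64)
actions = torch.zeros((parallel_envs, max_nodes), dtype=torch.int64)

# path so far: root (node 1) --action 2--> node 4 --action 5--> node 7
visits[0, :3] = torch.tensor([1, 4, 7])   # node ids visited, indexed by depth
actions[0, :2] = torch.tensor([2, 5])     # action taken at each visited node

print(visits[0])   # tensor([1, 4, 7, 0, 0, 0, 0, 0])
print(actions[0])  # tensor([2, 5, 0, 0, 0, 0, 0, 0])

With the path stored this way, a single advanced-indexing operation (like the one used in the backpropagation step later) can touch every (node, action) pair on the path at once.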

Finally, we also allocate a few helper tensors to let us execute common indexing operations:

self.env_indices = torch.arange(self.parallel_envs, dtype=torch.int64, device=self.device, requires_grad=False)
self.env_indices_expnd = self.env_indices.view(-1, 1)
self.slots_aranged = torch.arange(self.total_slots, dtype=torch.int64, device=self.device, requires_grad=False)
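These helpers exist to make a common advanced-indexing pattern concise: selecting one node's data per environment in a single operation. A minimal standalone sketch of the pattern (toy shapes, hypothetical values):

import torch

parallel_envs, total_slots, policy_size = 4, 10, 6
p_vals = torch.rand((parallel_envs, total_slots, policy_size))
cur_nodes = torch.tensor([1, 3, 1, 2])             # current node id per environment
env_indices = torch.arange(parallel_envs)

# pick out the policy stored at each environment's current node, all at once
current_policies = p_vals[env_indices, cur_nodes]  # shape: (parallel_envs, policy_size)
print(current_policies.shape)                      # torch.Size([4, 6])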

There are actually a few more components to state that we'll touch on as they come up, but we have enough pieces to begin explaining the algorithm itself.

MCTS Iteration

Recall that in AlphaZero, MCTS runs for a fixed number of iterations per state evaluation, where each iteration traverses from the root until an unvisited node is encountered, which is then added to the tree. Parameterizing MCTS by its number of iterations is useful in single-process applications: it's simple and grokkable.

In our vectorized implementation, this gets a bit trickier. A single traditional MCTS iteration from root to leaf traverses a variable number of edges, depending on the state of the evaluator, environment, and search tree. This means that the computation required to execute a single MCTS iteration across many environments could vary greatly from environment to environment, which is something we'd like to avoid in a vectorized implementation. Instead, we'd like to apply the same operations to each environment in parallel, and apply the same quantity of total operations to each environment.

To achieve this, instead of defining an MCTS iteration as traversing from a root to a leaf, we instead define an iteration as traversing a single edge: we give an MCTS evaluator a budget of edges to travel along. As you'll see, traversing an edge requires a set of fixed operations, meaning that we can apply the same operations to each environment and search tree, regardless of state.

Operations like backpropagation will of course be applied selectively to leaf nodes, but it is not hard to apply operations selectively while still keeping all operations vectorizable.
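In practice, "applying an operation selectively" in a vectorized setting usually just means multiplying by a boolean mask, so every environment executes the same instructions but only the masked environments see any effect. A tiny standalone sketch of the idea (hypothetical values):

import torch

values = torch.tensor([10.0, 20.0, 30.0, 40.0])
deltas = torch.tensor([1.0, 1.0, 1.0, 1.0])
is_leaf = torch.tensor([True, False, True, False])

# every environment runs the same add, but non-leaf environments are multiplied by zero
values += deltas * is_leaf
print(values)  # tensor([11., 20., 31., 40.])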

So our pseudo-code for traversal-budgeted MCTS looks something like:

def mcts_evaluate():
    # initialize
    for i in range(num_iterations):
        # choose an action
        # traverse the corresponding edge
        # backpropagate
    
    # return root node n-vals

There are a few things to take care of before we start building a tree:

  • initialize search memory
  • get policy for root node
  • apply Dirichlet noise to the policy
  • save the root node state

This is done with the following code:

    self.reset_search()

    # get root node policy
    with torch.no_grad():
        policy_logits, _ = evaluation_fn(self.env)

    # set root node policy, apply dirichlet noise
    self.p_vals[self.env_indices, self.cur_nodes] = \
        (torch.softmax(policy_logits, dim=1) * (1 - self.dirichlet_e)) \
        + (self.dirilecht.rsample(torch.Size((self.parallel_envs,))) * self.dirichlet_e)

    # save root node, so that we can load it again when a leaf node is reached
    saved = self.env.save_node()

    # store player id of current player at root node
    cur_players = self.env.cur_players.clone()

Re-initializing Search Memory

The relevant parts of MCTS memory are reset using self.reset_search(), which is implemented as follows:

def reset_search(self) -> None:
    self.depths.fill_(1)
    self.cur_nodes.fill_(1)
    self.visits.zero_()
    self.visits[:, 0] =  1
    self.actions.zero_()

Since we begin iterating at the root node, depths are all reset to 1. Similarly, cur_nodes is set equal to the root node_id (1), the visits tensor is cleared out save for the root node, and the actions tensor is cleared.

Note that we do not reset any of the node data tensors ($W, P, N$, next_idx). One of the neat features of this implementation of MCTS is that subtrees are persisted (a feature that DeepMind's mctx project lacks). Persisting subtrees is implemented separately; we'll touch on that later. For now we can assume that all other state tensors are ready to go.

Root Node Policy + Dirichlet Noise

Next, we need to populate the root node policy. We use an evaluation function passed to MCTS; this could be a heuristic or even a random number generator, but in the case of AlphaZero it's a neural network. It doesn't really matter what it is: the point is that it generates a policy distribution (a set of probabilities over all possible actions) and an evaluation value. We just care about the policy distribution right now:

with torch.no_grad():
     policy_logits, _ = evaluation_fn(self.env)

Next, since we are at the root node, we apply Dirichlet noise to the probability distribution according to parameters $\alpha, \epsilon$:

$$P_{root}(S) = (P(S) * (1 - \epsilon)) + (Dir(\alpha) * \epsilon)$$

(This is covered on page 14 of the AlphaZero paper.) We perform this operation and store the noisy distribution with:

self.p_vals[self.env_indices, self.cur_nodes] = \
    (torch.softmax(policy_logits, dim=1) * (1 - self.dirichlet_e)) \
    + (self.dirilecht.rsample(torch.Size((self.parallel_envs,))) * self.dirichlet_e)

The Dirichlet distribution (self.dirilecht in the code) is initialized like this:

self.dirilecht = torch.distributions.dirichlet.Dirichlet(torch.full(
    (self.policy_size,), 
    self.dirichlet_a, 
    device=self.device, 
    dtype=torch.float32, 
    requires_grad=False)
)
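As a quick sanity check of the mixing formula (a standalone sketch with toy shapes, not the evaluator's state), the noisy root policy remains a valid probability distribution because it is a convex combination of two distributions:

import torch

parallel_envs, policy_size = 3, 4
dirichlet_a, dirichlet_e = 0.3, 0.25

dirichlet = torch.distributions.dirichlet.Dirichlet(
    torch.full((policy_size,), dirichlet_a)
)
policy_logits = torch.randn(parallel_envs, policy_size)

noisy_policy = (torch.softmax(policy_logits, dim=1) * (1 - dirichlet_e)) \
    + (dirichlet.rsample(torch.Size((parallel_envs,))) * dirichlet_e)

print(noisy_policy.sum(dim=1))  # each row still sums to ~1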

Saving/Retrieving Root Node State

Upon reaching a leaf node, we'll need to return to the root node to begin traversing the tree again. The simplest way to do this is to cache the root state. All TurboZero Envs implement save_node() and load_node(env_ids), which we use to save and load the root node as needed during MCTS.

saved = self.env.save_node()

We also keep track of the current player id at the root node for purposes of reward assignment later on:

cur_players = self.env.cur_players.clone()

Now we're ready to jump into the code for a single iteration (remember: this means a single edge traversal). We can break an edge traversal down into a few main steps:

  1. Choose an action to take for each environment
  2. Take those actions
  3. If the resultant state is a leaf node, give it a new id and tell the parent where to find it
  4. Evaluate the new state
  5. Propagate rewards back along the traversal from root to leaf if this node is a leaf node
  6. Update MCTS state tensors
  7. Return to the root if we've reached a leaf

The full code for one iteration looks like this; we'll break down each piece below:

# STEP 1 (Choose an action to take for each environment)
actions = self.choose_action()
# STEP 2 (Take those actions)
self.env.step(actions)
master_action_indices = self.next_idx[self.env_indices, self.cur_nodes, actions]
# STEP 3 (If the resultant state is a leaf node, give it a new id and tell the parent where to find it)
unvisited = master_action_indices == 0
in_bounds = ~((self.next_empty >= self.total_slots) & unvisited)
master_action_indices += self.next_empty * in_bounds * unvisited
self.next_empty += 1 * in_bounds * unvisited
self.next_idx[self.env_indices, self.cur_nodes, actions] = master_action_indices
self.visits[self.env_indices, self.depths] = master_action_indices
self.actions[self.env_indices, self.depths - 1] = actions
self.parents[self.env_indices, master_action_indices] = self.cur_nodes
self.cur_nodes = master_action_indices
# STEP 4 (Evaluate the new state)
with torch.no_grad():
    policy_logits, values = evaluation_fn(self.env)
self.p_vals[self.env_indices, self.cur_nodes] = torch.softmax(policy_logits, dim=1)
# STEP 5 (Propagate rewards back along the traversal from root to leaf if this node is a leaf node)
terminated = self.env.is_terminal()
rewards = (self.env.get_rewards() * terminated) + (values.view(-1) * ~terminated)
rewards = (rewards * (cur_players == self.env.cur_players)) + ((1-rewards) * (cur_players != self.env.cur_players))
rewards.unsqueeze_(1)
is_leaf = unvisited | terminated
valid = torch.roll(self.visits, -1, 1) > 0
valid[:, -1] = 0
leaf_inc = valid * is_leaf.long().view(-1, 1)
self.n_vals[self.env_indices_expnd, self.visits, self.actions] += leaf_inc
self.w_vals[self.env_indices_expnd, self.visits, self.actions] += (rewards * leaf_inc * self.reward_indices) + ((1-rewards) * leaf_inc * ~self.reward_indices)
# STEP 6 (Update MCTS state tensors)
self.depths *= ~is_leaf
self.depths += 1
self.max_depths = torch.max(self.max_depths, self.depths)
self.visits[:, 1:] *= ~is_leaf.view(-1, 1)
self.actions *= ~is_leaf.view(-1, 1)
# STEP 7 (Return to the root if we've reached a leaf)
self.env.load_node(is_leaf, saved)
self.cur_nodes *= ~is_leaf
self.cur_nodes += is_leaf

Choosing Actions

Recall that AlphaZero uses a variant of the PUCT (Polynomial Upper Confidence Trees) algorithm to choose which action $a_{t}$ to explore, which is given as follows:

$Q(s, a) = \frac{W(s,a)}{N(s,a)}$

$U(s,a) = c_{puct}P(s,a)\frac{\sqrt{\sum_{b}{N(s,b)}}}{1+N(s,a)}$

$a_{t} = argmax_{a}(Q(s,a) + U(s,a))$

where $N(s,a), P(s,a), W(s,a)$ are the $N, P, W$ values for an action $a$ taken in state $s$, and $c_{puct}$ is a constant hyperparameter. This algorithm initially tends to prefer actions with high prior probability $P$ and low visit count, and over time shifts preference toward actions with high estimated value $Q$.

This is fairly trivial to implement in a vectorized fashion:

def choose_action(self) -> torch.Tensor:
    visits = self.n_vals[self.env_indices, self.cur_nodes]
    zero_visits = (visits == 0)
    visits_augmented = visits + zero_visits
    q_values = self.w_vals[self.env_indices, self.cur_nodes] / visits_augmented
    n_sum = visits.sum(dim=1, keepdim=True)
    probs = self.p_vals[self.env_indices, self.cur_nodes]
    puct_scores = q_values + (self.puct_coeff * probs * torch.sqrt(1 + n_sum) / (1 + visits))
    puct_scores *= self.env.get_legal_actions()
    return torch.argmax(puct_scores, dim=1)

First, we extract the $N$-values (visit counts) for the current node across each of the environment instances:

visits = self.n_vals[self.env_indices, self.cur_nodes]

Next, $Q(S)$ is calculated as $Q(S) = \frac{W(S)}{N(S)}$. To avoid division by zero, 1 is added to any $N(S,a)$ where $N(S,a) = 0$.

visits = self.n_vals[self.env_indices, self.cur_nodes]
zero_visits = (visits == 0)
visits_augmented = visits + zero_visits
q_values = self.w_vals[self.env_indices, self.cur_nodes] / visits_augmented

Next, we calculate $U(S)$ for each action per its formula:

$U(s,a) = c_{puct}P(s,a)\frac{\sqrt{\sum_{b}{N(s,b)}}}{1+N(s,a)}$

and get scores for each action, given by:

$Q(s,a) + U(s,a)$

n_sum = visits.sum(dim=1, keepdim=True) 
probs = self.p_vals[self.env_indices, self.cur_nodes] 
puct_scores = q_values + (self.puct_coeff * probs * torch.sqrt(1 + n_sum) / (1 + visits)) 

Next, we zero out any illegal actions, so that they are not chosen.

puct_scores *= self.env.get_legal_actions()

Finally, we choose the legal action in each environment instance with the highest puct score:

return torch.argmax(puct_scores, dim=1)
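As a standalone sanity check of the scoring (toy numbers for one environment with three actions, mirroring the implementation's sqrt(1 + n_sum) variant of the formula):

import torch

puct_coeff = 1.0
n_vals = torch.tensor([[4.0, 1.0, 0.0]])   # visit counts N(s, a)
w_vals = torch.tensor([[3.0, 0.0, 0.0]])   # accumulated values W(s, a)
p_vals = torch.tensor([[0.2, 0.3, 0.5]])   # prior probabilities P(s, a)

zero_visits = (n_vals == 0)
q_values = w_vals / (n_vals + zero_visits)  # Q = W / N, with unvisited actions left at 0
n_sum = n_vals.sum(dim=1, keepdim=True)
puct_scores = q_values + (puct_coeff * p_vals * torch.sqrt(1 + n_sum) / (1 + n_vals))
print(puct_scores)
# tensor([[0.8480, 0.3674, 1.2247]]) -- the unvisited, high-prior action wins this round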

Taking the Actions

Back in the evaluation loop, we take the chosen actions and apply them to the environments. This will update each environment's internal state.

actions = self.choose_action()
self.env.step(actions)

Create New Nodes

Within a parent node, we use an action id to index into various pieces of information for a direct child node ($P$-value, $W$-value, etc.). To access the child's own children, however, we must visit the child's master index, which we look up by mapping the action (the edge from parent to child) to the child's master index.

We can check if the child node has already been mapped to a master index via:

master_action_indices = self.next_idx[self.env_indices, self.cur_nodes, actions]

If it hasn't yet been mapped (i.e. the node is unvisited), its value in next_idx will be 0:

unvisited = master_action_indices == 0

Next, we need to make sure that we haven't already run out of space for new nodes. During initialization, self.next_empty is allocated with shape $(N,)$, tracking the next empty node slot for each environment:

self.next_empty = torch.full_like(
    self.env_indices, 
    2, # initialize to 2, 0th index is garbage, 1st index is root
    dtype=torch.int64, 
    device=self.device, 
    requires_grad=False
)

We go out of bounds if:

  • Space for a new node is needed
  • next_empty >= total # of slots

So we check if we stay in-bounds via:

in_bounds = ~((self.next_empty >= self.total_slots) & unvisited)

Then, if the next node is unvisited, we assign it a new master index, if there's space, and update next_empty:

master_action_indices += self.next_empty * in_bounds * unvisited
self.next_empty += 1 * in_bounds * unvisited

Note that for instances where the master index has already been assigned (i.e. the node has already been visited), these operations have no effect. Next, we update self.next_idx to reflect any newly assigned master indices.

self.next_idx[self.env_indices, self.cur_nodes, actions] = master_action_indices

During initialization, self.depths is also initialized, which simply tracks the current depth from the root across each instance.

self.depths = torch.ones(
    (self.parallel_envs,), 
    dtype=torch.int64, 
    device=self.device, 
    requires_grad=False
)

We use this to add the traversed edges to self.visits and self.actions at the appropriate index:

self.visits[self.env_indices, self.depths] = master_action_indices
self.actions[self.env_indices, self.depths - 1] = actions

Finally, we update self.parents to store the mapping from each child node back to its parent, and set self.cur_nodes to the new index.

self.parents[self.env_indices, master_action_indices] = self.cur_nodes
# cur nodes should now reflect the taken actions
self.cur_nodes = master_action_indices

Evaluating the New State

Now we're ready to evaluate the new state. As we discussed, MCTS supports any kind of evaluation function, but the one we're using for AlphaZero will produce an approximate policy and state evaluation using a trained neural network:

# get (policy distribution, evaluation) from evaluation function
with torch.no_grad():
    policy_logits, values = evaluation_fn(self.env)

The returned policy logits will not be normalized, so we'll need to pass them through a softmax prior to storing them:

self.p_vals[self.env_indices, self.cur_nodes] = torch.softmax(policy_logits, dim=1)

Backpropagation

Now that we have an evaluation for the current state from our evaluation function, it's time to backpropagate values and update visit counts in whichever instances have encountered a leaf node.

The values propagated should be:

  • The final reward, if the environment is now terminated
  • The evaluation from the evaluation function, otherwise

We'll assign these values to rewards:

terminated = self.env.is_terminal()
rewards = (self.env.get_rewards() * terminated) + (values.view(-1) * ~terminated)

We'd actually like to have the rewards from the perspective of the player whose turn it was at the root node, which we can achieve with:

rewards = (rewards * (cur_players == self.env.cur_players)) \
    + ((1-rewards) * (cur_players != self.env.cur_players))

Next, for all leaf (or terminated) nodes, we'd like to increment the $N$ and $W$ values for all nodes visited on our path from root to leaf. self.actions and self.visits already hold these indices, so all that's left to do is properly index into self.n_vals and self.w_vals.

is_leaf = unvisited | terminated
valid = torch.roll(self.visits, -1, 1) > 0
valid[:,-1] = 0
leaf_inc = valid * is_leaf.long().view(-1, 1)
self.n_vals[self.env_indices_expnd, self.visits, self.actions] += leaf_inc

To increment $W$ values, we need to make sure rewards reflect the current player's perspective at each node along the path. We initialize a tensor self.reward_indices:

self.reward_indices = self.build_reward_indices(env.num_players)

def build_reward_indices(self, num_players: int) -> torch.Tensor:
    num_repeats = math.ceil(self.max_nodes / num_players)
    return torch.tensor(
        [1] + [0] * (num_players - 1), dtype=torch.bool, device=self.device
    ).repeat(num_repeats)[:self.max_nodes].view(1, -1)

which holds a boolean tensor whose indices reflect which nodes along the path from root to leaf should have their rewards flipped. This works well for single and two-player environments but may need to be adjusted/modified for environments with more than two players.
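For example, with two players and a hypothetical max_nodes = 6, reproducing the same repeat-and-truncate logic standalone gives an alternating mask:

import math
import torch

max_nodes, num_players = 6, 2
num_repeats = math.ceil(max_nodes / num_players)
reward_indices = (
    torch.tensor([1] + [0] * (num_players - 1), dtype=torch.bool)
    .repeat(num_repeats)[:max_nodes]
    .view(1, -1)
)

print(reward_indices)
# tensor([[ True, False,  True, False,  True, False]])
# True  -> keep the (root-perspective) reward as-is
# False -> use 1 - reward (the opponent's perspective)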

We use this mask to flip the rewards along the traversal path accordingly.

self.w_vals[self.env_indices_expnd, self.visits, self.actions] += (rewards * leaf_inc * self.reward_indices) + ((1-rewards) * leaf_inc * ~self.reward_indices)

Keep in mind that we are only backpropagating values in envs that have reached a leaf/terminal node! These operations do nothing for instances where the current node is not terminal or a leaf.

Other State Updates

Now that we've run backpropagation for any leaf nodes encountered, it's time to update our other state tensors.

First, we'll increment depths to reflect having traveled 1 node further, and reset it to 1 if we've reached a leaf node:

self.depths *= ~is_leaf
self.depths += 1
self.max_depths = torch.max(self.max_depths, self.depths)

(we maintain self.max_depths for persisting subtrees)

Similarly, we reset self.actions and self.visits if a leaf/terminal state has been reached:

self.visits[:, 1:] *= ~is_leaf.view(-1, 1)
self.actions *= ~is_leaf.view(-1, 1)

(leaf states will be multiplied by zero!)

Returning to the Root Node

The final step of one iteration is resetting self.cur_nodes to the root node (index = 1) for any leaf nodes:

self.env.load_node(is_leaf, saved)
self.cur_nodes *= ~is_leaf
self.cur_nodes += is_leaf

We use env.load_node to load the environment state that we cached earlier!

Final Evaluation

Once we've completed the specified number of iterations, we need to return the final policy distribution and evaluation.

The policy distribution is simply $N(root)$.

The evaluation is taken to be the Q-value of the most-visited root action, $Q(root, a^{*})$ where $a^{*} = \arg\max_{a} N(root, a)$, which is how I implement it. This value isn't actually used for any computation or training, but it's nice to have an evaluation for debugging and demonstrative purposes. We train the neural network by comparing the initial root evaluation to the final game reward, rather than using an evaluation obtained via MCTS.

# return to the root node
self.cur_nodes.fill_(1)
# reload all envs to the root node
self.env.load_node(self.cur_nodes.bool(), saved)
# self.cur_nodes.bool() is a tensor filled with ones, so we reload every environment
# return visited counts at the root node
max_inds = self.n_vals[self.env_indices, self.cur_nodes].argmax(dim=1)
max_q_vals = self.w_vals[self.env_indices, self.cur_nodes, max_inds] / self.n_vals[self.env_indices, self.cur_nodes, max_inds]
return self.n_vals[self.env_indices, self.cur_nodes], max_q_vals

We just made it through every line of an MCTS evaluation! You can see the full source code here: mcts.py

AlphaZero itself is configured to sample from this policy distribution, with a temperature parameter $\tau$ applied:

$P_{f}(S) = P(S)^{1/\tau}$

This is a hyperparameter that can be adjusted. A value of 1.0 has no effect on the distribution. Values smaller than 1.0 make high-probability actions even more likely to be selected, while values larger than 1.0 soften the distribution so that the probabilities are closer together. Setting $\tau$ to a very small value is effectively the same as taking $argmax$.

def choose_actions(self, visits: torch.Tensor) -> torch.Tensor:
    if self.config.temperature > 0:
        return torch.multinomial(
            torch.pow(visits, 1/self.config.temperature),
            1,
            replacement=True
        ).flatten()
    else:
        return rand_argmax_2d(visits).flatten()

(from alphazero.py)
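To see the effect of $\tau$ numerically, here's a quick standalone check with toy visit counts:

import torch

visits = torch.tensor([10.0, 5.0, 1.0])

for tau in (1.0, 0.5, 4.0):
    weights = torch.pow(visits, 1.0 / tau)
    print(tau, (weights / weights.sum()).tolist())

# tau = 1.0 -> [0.625, 0.3125, 0.0625]   (distribution unchanged)
# tau = 0.5 -> [~0.79, ~0.20, ~0.01]     (sharper, closer to argmax)
# tau = 4.0 -> [~0.42, ~0.35, ~0.23]     (softer, probabilities closer together)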

Persisting MCTS Subtrees

One of the main contributions of this project is a vectorized implementation of MCTS that supports persisting subtrees. I am not aware of another implementation that has this feature.

DeepMind's mctx omits this feature, because it has more of a focus on algorithms like MuZero, which search over learned state representations rather than true environment states. MuZero can freely discard search trees after an MCTS evaluation, since each search tree node is represented by an approximation of the actual state rather than the state itself; as soon as an action is taken the true next state is available, which makes saving subtrees unnecessary.

With AlphaZero, we know the true state throughout the entire search process, so holding onto the relevant subtree after taking an action is very useful, and lets us avoid repeating a lot of work, which improves the performance of the algorithm by allowing it to visit new states rather than re-visit old ones.

Subtrees are persisted using the following two functions, which I'll explain line-by-line:

def propogate_root_subtrees(self):
    self.subtrees.zero_()
    self.subtrees += self.slots_aranged
    self.parents[:, 0] = 0
    for _ in range(self.max_depths.max() + 1):
        parent_subtrees = self.subtrees[self.env_indices_expnd, self.parents]
        self.subtrees = (parent_subtrees * (parent_subtrees > 1)) + (self.subtrees * (parent_subtrees <= 1))

def load_subtree(self, actions: torch.Tensor):
    self.propogate_root_subtrees()

    subtree_master_indices = self.next_idx[self.env_indices, 1, actions]
    is_real = subtree_master_indices > 1
    new_nodes = (self.subtrees == subtree_master_indices.view(-1, 1))

    translation = new_nodes * is_real.view(-1, 1) * new_nodes.long().cumsum(dim=1)

    old_subtree_idxs = self.slots_aranged * new_nodes
    self.next_empty = torch.amax(translation, dim=1) + 1
    self.next_empty.clamp_(min=2)
    erase = self.slots_aranged * (self.slots_aranged >= self.next_empty.view(-1, 1))

    self.w_vals[self.env_indices_expnd, translation] = self.w_vals[self.env_indices_expnd, old_subtree_idxs]
    self.w_vals[self.env_indices_expnd, erase] = 0

    self.n_vals[self.env_indices_expnd, translation] = self.n_vals[self.env_indices_expnd, old_subtree_idxs]
    self.n_vals[self.env_indices_expnd, erase] = 0

    self.p_vals[self.env_indices_expnd, translation] = self.p_vals[self.env_indices_expnd, old_subtree_idxs]
    self.p_vals[self.env_indices_expnd, erase] = 0

    self.next_idx[self.env_indices_expnd, translation] = translation[self.env_indices.view(-1, 1, 1), self.next_idx]
    self.next_idx[self.env_indices_expnd, erase] = 0

    self.parents[self.env_indices_expnd, translation] = translation[self.env_indices_expnd, self.parents]
    self.parents[self.env_indices_expnd, erase] = 0

    self.max_depths -= 1
    self.max_depths.clamp_(min=1)

We can divide persisting subtrees into a few steps:

  1. Assign each node to a root subtree
  2. Re-index subtree nodes
  3. Copy data to new node indices
  4. Erase any other data that's left
  5. Adjust MCTS state tensors

Assigning Nodes to Root Subtrees

First, we need to figure out which root subtree each node in the search tree belongs to. We'd like the new tree to only include nodes contained within the subtree corresponding to a chosen action.

When MCTS is initialized, we create self.subtrees (shape $N \times M$) to store the root subtree that each node belongs to.

self.subtrees = torch.zeros(
    (self.parallel_envs, self.total_slots), 
    dtype=torch.long, 
    device=self.device, 
    requires_grad=False
) # index by master id

First, we zero out each subtree, and set each node's subtree equal to its own ID.

def propogate_root_subtrees(self):
    self.subtrees.zero_()
    self.subtrees += self.slots_aranged # node index value = node index
    self.parents[:, 0] = 0

The plan is to iteratively set each node's subtree equal to its parent's subtree, until every node's subtree is one of the root node's children. We will have to do this at most $D$ times for each instance, where $D$ is the maximum depth of the tree. Fortunately, we saved this value for each MCTS instance in self.max_depths during MCTS iteration, so we know how many times we'll need to iterate at most: the maximum of self.max_depths, or self.max_depths.max()!

for _ in range(self.max_depths.max() + 1):
    parent_subtrees = self.subtrees[self.env_indices_expnd, self.parents]
    self.subtrees = (parent_subtrees * (parent_subtrees > 1)) + \
        (self.subtrees * (parent_subtrees <= 1))

For each of $D_{max}$ iterations, we set each node's subtree equal to its parent subtree, except in the case where its parent subtree is the root, so that we only collapse as far as the root's children. After we're done iterating, self.subtrees contains the immediate root subtree that each and every node belongs to.
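As a tiny standalone worked example (a hand-built tree for one environment, using the same update rule), the loop converges each node's label to one of the root's children:

import torch

# hand-built tree (index 0 = garbage slot, index 1 = root):
#   root(1) -> 2 -> 4 -> 5
#   root(1) -> 3
parents  = torch.tensor([0, 0, 1, 1, 2, 4])
subtrees = torch.arange(6)   # each node starts labeled with its own id
max_depth = 3

for _ in range(max_depth + 1):
    parent_subtrees = subtrees[parents]
    # inherit the parent's label unless the parent is the root (or garbage)
    subtrees = (parent_subtrees * (parent_subtrees > 1)) + (subtrees * (parent_subtrees <= 1))

print(subtrees)  # tensor([0, 1, 2, 3, 2, 2]) -- nodes 4 and 5 belong to the subtree rooted at node 2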

Re-indexing Subtree Nodes

Next, we need to select the subtree root indices corresponding to each action taken in the environments. We can find these by accessing the action-node_id mapping of the current tree root node.

subtree_master_indices = self.next_idx[self.env_indices, 1, actions]

In edge cases, it's possible the chosen action does not have a node assigned to it yet. We can check via:

is_real = subtree_master_indices > 1

Then, we can filter nodes that are included in the action's subtree:

new_nodes = (self.subtrees == subtree_master_indices.view(-1, 1))

This gives us a boolean tensor with False at indices of nodes not included in the subtree, and True at indices of nodes that do belong to the given subtree.

Next, we'd like to compress the subtree nodes, such that they take up the first $T_{a}$ slots, where $T_{a}$ is the size of the subtree for the taken action. We can employ a trick using torch.cumsum to map a node's current index to its index relative to the other subtree nodes. It's a pretty slick line of code.

translation = new_nodes * is_real.view(-1, 1) * new_nodes.long().cumsum(dim=1)

We can also create a filtered index containing current node ids for nodes in the subtree:

old_subtree_idxs = self.slots_aranged * new_nodes
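Here's the cumsum trick in isolation, on a toy membership mask (hypothetical values): the cumulative sum assigns each subtree node its rank among the subtree's members, which becomes its new, compacted master index, while every non-member is sent to the garbage slot.

import torch

# suppose nodes 2, 4 and 5 belong to the chosen subtree (index 0 is the garbage slot)
new_nodes = torch.tensor([[False, False, True, False, True, True, False]])
is_real = torch.tensor([True])   # the chosen action really has a subtree

translation = new_nodes * is_real.view(-1, 1) * new_nodes.long().cumsum(dim=1)
print(translation)
# tensor([[0, 0, 1, 0, 2, 3, 0]])
# node 2 -> new slot 1 (the new root), node 4 -> slot 2, node 5 -> slot 3,
# every non-member maps to the garbage slot 0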

Copying Data to New Node Indices

Now we're ready to re-index all of the subtree node data. We need to re-index $W$, $N$ and $P$ values, the action-node_id mapping self.next_idx, and self.parents. We can actually achieve this with a single line of code:

self.w_vals[self.env_indices_expnd, translation] = self.w_vals[self.env_indices_expnd, old_subtree_idxs]

We use the translation to index into each data tensor (in this case, $W$), and set it equal to the values at the current/old node indices of nodes in the subtree. Any node not in the subtree will have an index of 0, which we have assigned to be a garbage index, letting us complete this operation elegantly in a single line!

We can do the same for each of the other data tensors:

self.n_vals[self.env_indices_expnd, translation] = self.n_vals[self.env_indices_expnd, old_subtree_idxs]
self.p_vals[self.env_indices_expnd, translation] = self.p_vals[self.env_indices_expnd, old_subtree_idxs]
self.parents[self.env_indices_expnd, translation] = translation[self.env_indices_expnd, self.parents]
self.next_idx[self.env_indices_expnd, translation] = translation[self.env_indices.view(-1, 1, 1), self.next_idx]

self.parents and self.next_idx require a little more thought, see if you can figure out the ideas!

Erasing Extra Data

Now that the subtree data is in the correct location, any indices we haven't yet touched will contain old data, which we need to zero out. We just need to know where the new subtree ends for each MCTS instance, and then zero out the slice of data occurring after that index. We know that the new next empty tensor should be the maximum new index + 1:

self.next_empty = torch.amax(translation, dim=1) + 1

I'm sure there is a way to use this tensor directly as an index to slice our desired data, but I'm not sure how. Instead, we can gather all the indices not in our subtree occurring after this value:

self.next_empty.clamp_(min=2)
erase = self.slots_aranged * (self.slots_aranged >= self.next_empty.view(-1, 1))

Then, just set these to 0!

self.w_vals[self.env_indices_expnd, erase] = 0
self.n_vals[self.env_indices_expnd, erase] = 0
self.p_vals[self.env_indices_expnd, erase] = 0
self.next_idx[self.env_indices_expnd, erase] = 0
self.parents[self.env_indices_expnd, erase] = 0

Adjusting MCTS State

State should now be in good shape: we've remapped all the subtree nodes and updated self.next_empty. All that's left to do is update self.max_depths, which we can simply decrease by 1, since any root subtree has height $h_{s} \le h_{root} - 1$. This is an approximation, since the deepest subtree could have been in a branch that was pruned, but an upper bound is good enough for our purposes.

self.max_depths -= 1
self.max_depths.clamp_(min=1)

And that's it! This allows us to save lots of work by persisting subtrees from previous MCTS evaluations! You can view the full source code here.

Parameters

  • num_iters: number of iterations/edges traversed before final evaluation
  • max_nodes: memory for this many nodes is allocated -- if no memory is free, every iteration that would result in traversing to a new, unexplored node instead resets back to the root node
  • puct_coeff: $c_{puct}$ in $U(s,a) = c_{puct}P(s,a)\frac{\sqrt{\sum_{b}{N(s,b)}}}{1+N(s,a)}$, higher values bias towards exploration instead of exploitation
  • dirichlet_epsilon: $d_{e}$ in $P_{n}(S) = (P(S) * (1-d_{e})) + (Dirichlet(d_{a}) * d_{e})$, proportion of root policy composed of Dirichlet noise
  • dirichlet_alpha: $d_{a}$ in $P_{n}(S) = (P(S) * (1-d_{e})) + (Dirichlet(d_{a}) * d_{e})$, magnitude of Dirichlet noise added to root policy
  • temperature: $\tau$ in $P_{f}(S) = P(S)^{1/\tau}$, adjusts policy sampling