Vectorized AlphaZero
There have been countless articles written on AlphaZero and Monte Carlo Tree Search, so rather than explaining it again I'll devote my writing to details of the vectorized implementation.
Before reading, I'd suggest you read the original AlphaZero paper. I'd also recommend this primer on Monte Carlo Tree Search, and finally this article from Stanford walking through a toy implementation of AlphaZero.
These papers and articles focus on single-process AlphaZero, where state is maintained for a single instance of MCTS and a single environment. There have been numerous attempts to optimize and parallelize MCTS and AlphaZero, such as collecting batches of states within a single process prior to model inference, or distributing episode collection across many processes (this is what was done in the AlphaZero paper, with 5000 TPUs!). Both techniques provide effective ways to scale up AlphaZero, but neither takes full advantage of the GPUs/TPUs already being used for model inference, as the MCTS algorithm and environment logic are both still computed on the CPU.
Instead of process-based parallelism, this project uses vectorization to parallelize environment logic, MCTS, and model inference without data ever leaving the GPU! By encoding environment and MCTS state as multi-dimensional matrices and implementing all operations as matrix operations, every element of the AlphaZero algorithm is vectorized, meaning that instead of being constrained by CPU/GPU instances, processes, and memory, we are instead constrained by GPU memory and throughput. This means the TurboZero implementation can collect episodes and train much, much quicker than traditional implementations, by vectorizing across thousands of environments and MCTS instances at once and performing each step of simulation, search, and inference across all instances simultaneously. This page will mostly focus on vectorized MCTS, as it is the main departure from traditional AlphaZero implementations.
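To make the "everything is a batched tensor operation" idea concrete before diving into the real implementation, here is a tiny standalone sketch (illustrative only, not TurboZero code) that contrasts looping over environments in Python with a single vectorized operation over all of them:
import torch

N, A = 4096, 7                                    # made-up sizes: environments, actions
device = "cuda" if torch.cuda.is_available() else "cpu"
w = torch.rand(N, A, device=device)               # stand-in for accumulated values
n = torch.randint(1, 10, (N, A), device=device)   # stand-in for visit counts

# loop version: one environment at a time, N separate (tiny) operations
best_loop = torch.stack([torch.argmax(w[i] / n[i]) for i in range(N)])

# vectorized version: every environment handled by a single tensor operation on the GPU
best_vec = torch.argmax(w / n, dim=1)

assert torch.equal(best_loop, best_vec)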
Rather than try to explain the vectorized implementation in broad strokes, it'll be easier to walk through it line-by-line, explaining as we go. If you'd like to follow along, the source code can be found here.
First, let's set up a way to maintain state. We'll do so with a few PyTorch tensors.
Most tensors we create will have a 0th-dimension of size N, where N is the number of environments. All tensors maintaining state in this implementation will follow the same pattern.
Recall that Monte Carlo Tree Search maintains, for each node in the search tree, a set of per-action statistics: a prior probability $P(s,a)$, a visit count $N(s,a)$, and an accumulated value $W(s,a)$. So we can just maintain a tensor for each of these statistics:
# garbage node + root node + other nodes
self.total_slots = 2 + self.max_nodes
self.p_vals = torch.zeros(
(self.parallel_envs, self.total_slots, self.policy_size),
dtype=torch.float32,
device=self.device,
requires_grad=False
)
self.n_vals = torch.zeros(
(self.parallel_envs, self.total_slots, self.policy_size),
dtype=torch.long,
device=self.device,
requires_grad=False
)
self.w_vals = torch.zeros(
(self.parallel_envs, self.total_slots, self.policy_size),
dtype=torch.float32,
device=self.device,
requires_grad=False
)
Each of the tensors will be of shape (parallel_envs, total_slots, policy_size), such that, for example:
self.p_vals[0,1,15]
yields the p-value for action_id = 15, node_id = 1, environment_index = 0. You'll notice we allocate an extra node index (1st-dim) at index 0. We use this as a garbage index and do not store any data there. This enables a few useful tricks later on.
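To illustrate the kind of trick the garbage slot enables, here is a toy sketch with made-up values (not TurboZero code): when only some environments have something to store, the rest can be pointed at slot 0, so a single unconditional write works for every environment and the junk written to slot 0 is simply never read back.
import torch

parallel_envs, total_slots, policy_size = 4, 6, 3
p_vals = torch.zeros(parallel_envs, total_slots, policy_size)
env_indices = torch.arange(parallel_envs)

# envs 0 and 2 write to real slots; envs 1 and 3 are masked out and hit the garbage slot
target_slots = torch.tensor([2, 0, 3, 0])
p_vals[env_indices, target_slots] = 1.0

print(p_vals[0, 2], p_vals[2, 3])   # real data landed where intended
print(p_vals[1, 0])                 # slot 0 holds junk, but we never read it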
In order to keep track of the current node id for each instance of MCTS, we can use a one-dimensional tensor:
self.cur_nodes = torch.ones(
(self.parallel_envs,),
dtype=torch.int64,
device=self.device,
requires_grad=False
)
Next, we need a way to map an action taken at a given node to the corresponding child node id. We can do that with a similarly sized tensor:
self.next_idx = torch.zeros(
(self.parallel_envs, self.total_slots, self.policy_size),
dtype=torch.long,
device=self.device,
requires_grad=False
)
such that:
self.next_idx[3,2,8]
gives the child node_id when action_id = 8 is taken at node_id=2, in environment_id=3.
We also need to maintain a few tensors for the purpose of backpropagation.
self.actions = torch.zeros(
(self.parallel_envs, self.max_nodes),
dtype=torch.int64,
device=self.device,
requires_grad=False
)
self.visits = torch.zeros(
(self.parallel_envs, self.max_nodes),
dtype=torch.int64,
device=self.device,
requires_grad=False
)
self.actions will track the actions taken on the current path from the root, and self.visits will track the node ids visited. Maintaining these tensors lets us perform backpropagation across all environments in a single operation rather than doing so iteratively.
Finally, we also allocate a few helper tensors to let us execute common indexing operations:
self.env_indices = torch.arange(self.parallel_envs, dtype=torch.int64, device=self.device, requires_grad=False)
self.env_indices_expnd = self.env_indices.view(-1, 1)
self.slots_aranged = torch.arange(self.total_slots, dtype=torch.int64, device=self.device, requires_grad=False)
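As a quick illustration of why these helpers are handy (toy shapes, not TurboZero code): pairing self.env_indices with another per-environment index tensor, such as self.cur_nodes, selects "this environment's current node" for every environment at once.
import torch

parallel_envs, total_slots, policy_size = 3, 5, 4
p_vals = torch.rand(parallel_envs, total_slots, policy_size)
env_indices = torch.arange(parallel_envs)
cur_nodes = torch.tensor([1, 3, 2])     # a different current node for each environment

cur_policies = p_vals[env_indices, cur_nodes]
print(cur_policies.shape)               # torch.Size([3, 4]): one policy row per environment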
There are actually a few more components to state that we'll touch on as they come up, but we have enough pieces to begin explaining the algorithm itself.
Recall that in AlphaZero, MCTS will run for a fixed number of iterations per state evaluation, where each iteration traverses from the root until an unvisited node is encountered, which is then added to the tree. Parameterizing MCTS by its number of iterations is useful in single-process applications: it's simple and grokkable.
In our vectorized implementations, this gets a bit trickier. A single traditional MCTS iteration from root to leaf traverses a variable number of edges, depending on the state of the evaluator, environment, and search tree. This means that the computation required to execute a single MCTS iteration across many environments could vary greatly from environment to environment, which is something we'd like to avoid in a vectorized implementation. Instead, we'd like to apply the same operations to each environment in parallel, and apply the same quantity of total operations to each environment.
To achieve this, instead of defining an MCTS iteration as traversing from a root to a leaf, we instead define an iteration as traversing a single edge: we give an MCTS evaluator a budget of edges to travel along. As you'll see, traversing an edge requires a set of fixed operations, meaning that we can apply the same operations to each environment and search tree, regardless of state.
Operations like backpropagation will of course be applied selectively to leaf nodes, but it is not hard to apply operations selectively while still keeping all operations vectorizable.
So our pseudo-code for traversal-budgeted MCTS looks something like:
def mcts_evaluate():
    # initialize
    for i in range(num_iterations):
        # choose an action
        # traverse the corresponding edge
        # backpropagate
    # return root node n-vals
There are a few things to take care of before we start building a tree:
- initialize search memory
- get policy for root node
- apply Dirichlet noise to the policy
- save the root node state
This is done with the following code:
self.reset_search()
# get root node policy
with torch.no_grad():
    policy_logits, _ = evaluation_fn(self.env)
# set root node policy, apply dirichlet noise
self.p_vals[self.env_indices, self.cur_nodes] = \
(torch.softmax(policy_logits, dim=1) * (1 - self.dirichlet_e)) \
+ (self.dirilecht.rsample(torch.Size((self.parallel_envs,))) * self.dirichlet_e)
# save root node, so that we can load it again when a leaf node is reached
saved = self.env.save_node()
# store player id of current player at root node
cur_players = self.env.cur_players.clone()
The relevant parts of MCTS memory are reset using self.reset_search(), which is implemented as follows:
def reset_search(self) -> None:
    self.depths.fill_(1)
    self.cur_nodes.fill_(1)
    self.visits.zero_()
    self.visits[:, 0] = 1
    self.actions.zero_()
Since we begin iterating at the root node, depths are all reset to 1. Similarly cur_nodes is set equal to the root node_id (1), the visits tensor is cleared out save for the root node, and the actions tensor is cleared.
Note that we do not reset any of the node data tensors (self.p_vals, self.n_vals, self.w_vals, and friends); this is what allows subtrees to be persisted between searches, which we'll cover later on.
Next, we need to populate the root node policy. We use an evaluation function passed to MCTS; this could be a heuristic or a random number generator, but in the case of AlphaZero it's a neural network. It doesn't really matter what it is: the point is that it generates a policy distribution (a set of probabilities for each possible action) and an evaluation value. We only care about the policy distribution right now:
with torch.no_grad():
    policy_logits, _ = evaluation_fn(self.env)
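For reference, an evaluation function for AlphaZero might look roughly like the sketch below. This is only an illustration of the expected shapes, not the actual TurboZero API: it assumes the environment exposes a batched observation tensor (the states attribute here is hypothetical) and that the model has a policy head and a value head.
import torch

def make_evaluation_fn(model: torch.nn.Module):
    def evaluation_fn(env):
        obs = env.states                    # hypothetical batched observation, already on the GPU
        policy_logits, values = model(obs)  # (N, policy_size) logits and (N,) value estimates
        return policy_logits, values
    return evaluation_fn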
Next, since we are at the root node, we apply Dirichlet noise to the probability distribution, according to the parameters dirichlet_epsilon ($d_{e}$) and dirichlet_alpha ($d_{a}$): $P_{n}(S) = (P(S) * (1-d_{e})) + (Dirichlet(d_{a}) * d_{e})$
self.p_vals[self.env_indices, self.cur_nodes] = \
(torch.softmax(policy_logits, dim=1) * (1 - self.dirichlet_e)) \
+ (self.dirilecht.rsample(torch.Size((self.parallel_envs,))) * self.dirichlet_e)
self.dirilecht is initialized like this:
self.dirilecht = torch.distributions.dirichlet.Dirichlet(torch.full(
(self.policy_size,),
self.dirichlet_a,
device=self.device,
dtype=torch.float32,
requires_grad=False)
)
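As a standalone sanity check (with made-up sizes and hyperparameters), the mixing step can be sketched like this; note the result still sums to 1 for every environment:
import torch

parallel_envs, policy_size = 4, 7
dirichlet_e, dirichlet_a = 0.25, 0.3

dist = torch.distributions.dirichlet.Dirichlet(torch.full((policy_size,), dirichlet_a))
policy = torch.softmax(torch.rand(parallel_envs, policy_size), dim=1)
noise = dist.sample(torch.Size((parallel_envs,)))    # shape (parallel_envs, policy_size)

noisy_policy = (policy * (1 - dirichlet_e)) + (noise * dirichlet_e)
print(noisy_policy.sum(dim=1))                       # ~1.0 for every environment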
Upon reaching a leaf node, we'll need to return to the root node to begin traversing the tree again. The simplest way to do this is to cache the root state. All TurboZero Envs implement save_node()
and load_node(env_ids)
, which we use to save and load the root node as needed during MCTS:
saved = self.env.save_node()
We also keep track of the current player id at the root node for purposes of reward assignment later on:
cur_players = self.env.cur_players.clone()
Now we're ready to jump into the code for a single iteration (remember: this means a single edge traversal). We can break an edge traversal down into a few main steps:
- Choose an action to take for each environment
- Take those actions
- If the resultant state is a leaf node, give it a new id and tell the parent where to find it
- Evaluate the new state
- Propagate rewards back along the traversal from root to leaf if this node is a leaf node
- Update MCTS state tensors
- Return to the root if we've reached a leaf
The full code for one full iteration looks like this; we will break each piece down below:
# STEP 1 (Choose an action to take for each environment)
actions = self.choose_action()
# STEP 2 (Take those actions)
self.env.step(actions)
master_action_indices = self.next_idx[self.env_indices, self.cur_nodes, actions]
# STEP 3 (If the resultant state is a leaf node, give it a new id and tell the parent where to find it)
unvisited = master_action_indices == 0
in_bounds = ~((self.next_empty >= self.total_slots) & unvisited)
master_action_indices += self.next_empty * in_bounds * unvisited
self.next_empty += 1 * in_bounds * unvisited
self.next_idx[self.env_indices, self.cur_nodes, actions] = master_action_indices
self.visits[self.env_indices, self.depths] = master_action_indices
self.actions[self.env_indices, self.depths - 1] = actions
self.parents[self.env_indices, master_action_indices] = self.cur_nodes
self.cur_nodes = master_action_indices
# STEP 4 (Evaluate the new state)
with torch.no_grad():
    policy_logits, values = evaluation_fn(self.env)
self.p_vals[self.env_indices, self.cur_nodes] = torch.softmax(policy_logits, dim=1)
# STEP 5 (Propagate rewards back along the traversal from root to leaf if this node is a leaf node)
terminated = self.env.is_terminal()
rewards = (self.env.get_rewards() * terminated) + (values.view(-1) * ~terminated)
rewards = (rewards * (cur_players == self.env.cur_players)) + ((1-rewards) * (cur_players != self.env.cur_players))
rewards.unsqueeze_(1)
is_leaf = unvisited | terminated
valid = torch.roll(self.visits, -1, 1) > 0
valid[:,-1] = 0
leaf_inc = valid * is_leaf.long().view(-1, 1)
self.n_vals[self.env_indices_expnd, self.visits, self.actions] += leaf_inc
self.w_vals[self.env_indices_expnd, self.visits, self.actions] += (rewards * leaf_inc * self.reward_indices) + ((1-rewards) * leaf_inc * ~self.reward_indices)
# STEP 6 (Update MCTS state tensors)
self.depths *= ~is_leaf
self.depths += 1
self.max_depths = torch.max(self.max_depths, self.depths)
self.visits[:, 1:] *= ~is_leaf.view(-1, 1)
self.actions *= ~is_leaf.view(-1, 1)
# STEP 7 (Return to the root if we've reached a leaf)
self.env.load_node(is_leaf, saved)
self.cur_nodes *= ~is_leaf
self.cur_nodes += is_leaf
Recall that AlphaZero uses a variant of the PUCT (Polynomial Upper Confidence Trees) algorithm to choose which action to explore next: the action maximizing $Q(s,a) + U(s,a)$, where $U(s,a) = c_{puct}P(s,a)\frac{\sqrt{\sum_{b}{N(s,b)}}}{1+N(s,a)}$, $Q(s,a) = \frac{W(s,a)}{N(s,a)}$ is the mean value of taking action $a$, $P(s,a)$ is its prior probability, and $N(s,a)$ is its visit count.
This is fairly trivial to implement in a vectorized fashion:
def choose_action(self) -> torch.Tensor:
    visits = self.n_vals[self.env_indices, self.cur_nodes]
    zero_visits = (visits == 0)
    visits_augmented = visits + zero_visits
    q_values = self.w_vals[self.env_indices, self.cur_nodes] / visits_augmented
    n_sum = visits.sum(dim=1, keepdim=True)
    probs = self.p_vals[self.env_indices, self.cur_nodes]
    puct_scores = q_values + (self.puct_coeff * probs * torch.sqrt(1 + n_sum) / (1 + visits))
    puct_scores *= self.env.get_legal_actions()
    return torch.argmax(puct_scores, dim=1)
First, we extract the visit counts ($N(s,a)$) for each child of the current node:
visits = self.n_vals[self.env_indices, self.cur_nodes]
Next, we compute Q-values by dividing accumulated values by visit counts, adding 1 to the count of each unvisited child so we never divide by zero:
visits = self.n_vals[self.env_indices, self.cur_nodes]
zero_visits = (visits == 0)
visits_augmented = visits + zero_visits
q_values = self.w_vals[self.env_indices, self.cur_nodes] / visits_augmented
Next, we calculate the total visit count across all children, look up the prior probabilities, and get PUCT scores for each action, given by:
n_sum = visits.sum(dim=1, keepdim=True)
probs = self.p_vals[self.env_indices, self.cur_nodes]
puct_scores = q_values + (self.puct_coeff * probs * torch.sqrt(1 + n_sum) / (1 + visits))
Next, we zero out any illegal actions, so that they are not chosen.
puct_scores *= self.env.get_legal_actions()
Finally, we choose the legal action in each environment instance with the highest puct score:
return torch.argmax(puct_scores, dim=1)
Back to evaluation, we return the chosen actions and apply them to the environments. This will update each environment's internal state.
actions = self.choose_action()
self.env.step(actions)
Within a parent node, we use an action id to index into various pieces of information for a direct child node (its prior, visit count, and accumulated value). But each child also needs its own node id along the node dimension (its master index) so that it can later be indexed as a parent itself.
We can check if the child node has already been mapped to a master index via:
master_action_indices = self.next_idx[self.env_indices, self.cur_nodes, actions]
If it hasn't yet been mapped and is unvisited, its value in next_idx will be 0:
unvisited = master_action_indices == 0
Next, we need to make sure that we haven't already run out of space for new nodes. During initialization, self.next_empty is allocated with one counter per environment (the same shape as self.env_indices), tracking the next unused node slot:
self.next_empty = torch.full_like(
self.env_indices,
2, # initialize to 2, 0th index is garbage, 1st index is root
dtype=torch.int64,
device=self.device,
requires_grad=False
)
We go out of bounds if:
- Space for a new node is needed
- next_empty >= total # of slots
So we check if we stay in-bounds via:
in_bounds = ~((self.next_empty >= self.total_slots) & unvisited)
Then, if the next node is unvisited, we assign it a new master index (if there's space) and update next_empty:
master_action_indices += self.next_empty * in_bounds * unvisited
self.next_empty += 1 * in_bounds * unvisited
Note that for instances where the master index has already been assigned (i.e. the node has already been visited), these operations have no effect. Next, we update self.next_idx to reflect any newly assigned master indices.
self.next_idx[self.env_indices, self.cur_nodes, actions] = master_action_indices
self.depths is also allocated during initialization; it simply tracks the current depth from the root for each instance.
self.depths = torch.ones(
(self.parallel_envs,),
dtype=torch.int64,
device=self.device,
requires_grad=False
)
We use this to add the traversed edges to self.visits and self.actions at the appropriate indices:
self.visits[self.env_indices, self.depths] = master_action_indices
self.actions[self.env_indices, self.depths - 1] = actions
Finally, we update self.parents to store the mapping from each child back to its parent, and set self.cur_nodes to the new index.
self.parents[self.env_indices, master_action_indices] = self.cur_nodes
# cur nodes should now reflect the taken actions
self.cur_nodes = master_action_indices
Now we're ready to evaluate the new state. As we discussed, MCTS supports any kind of evaluation function, but the one we're using for AlphaZero will produce an approximate policy and state evaluation using a trained neural network:
# get (policy distribution, evaluation) from evaluation function
with torch.no_grad():
    policy_logits, values = evaluation_fn(self.env)
The returned policy logits are not normalized, so we'll need to pass them through a softmax prior to storing them:
self.p_vals[self.env_indices, self.cur_nodes] = torch.softmax(policy_logits, dim=1)
Now that we have an evaluation for the current state from our evaluation function, it's time to backpropagate values and update visits counts in whichever instances we have encountered a leaf node.
The values propagated should be:
- The final reward, if the environment is now terminated
- The evaluation from the evaluation function, otherwise

We'll assign these values to rewards:
terminated = self.env.is_terminal()
rewards = (self.env.get_rewards() * terminated) + (values.view(-1) * ~terminated)
We'd actually like to have the rewards from the perspective of the player whose turn it was at the root node, which we can achieve with:
rewards = (rewards * (cur_players == self.env.cur_players)) \
+ ((1-rewards) * (cur_players != self.env.cur_players))
Next, for all leaf (or terminated) nodes, we'd like to increment the visit counts and accumulated values along the path traversed from the root. self.actions and self.visits already hold the relevant indices, so all that's left to do is properly index into self.n_vals and self.w_vals.
is_leaf = unvisited | terminated
valid = torch.roll(self.visits, -1, 1) > 0
valid[:,-1] = 0
leaf_inc = valid * is_leaf.long().view(-1, 1)
self.n_vals[self.env_indices_expnd, self.visits, self.actions] += leaf_inc
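To see what this indexing actually does, here is a toy single-environment example (made-up numbers): the traversal visited nodes 1 -> 3 -> 5 via actions 2 and 0, and node 5 is the leaf. Only the two traversed edges get their visit counts incremented; the masked-out entries all land on the garbage slot.
import torch

parallel_envs, total_slots, policy_size, max_nodes = 1, 8, 4, 6
n_vals = torch.zeros(parallel_envs, total_slots, policy_size, dtype=torch.long)
env_indices_expnd = torch.arange(parallel_envs).view(-1, 1)

visits  = torch.tensor([[1, 3, 5, 0, 0, 0]])   # nodes on the current path
actions = torch.tensor([[2, 0, 0, 0, 0, 0]])   # actions taken along the path
is_leaf = torch.tensor([True])

valid = torch.roll(visits, -1, 1) > 0          # an edge only counts if it leads to a visited node
valid[:, -1] = 0
leaf_inc = valid * is_leaf.long().view(-1, 1)  # [[1, 1, 0, 0, 0, 0]]

n_vals[env_indices_expnd, visits, actions] += leaf_inc
print(n_vals[0, 1, 2], n_vals[0, 3, 0])        # tensor(1) tensor(1)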
To apply each reward from the perspective of the correct player at each depth along the path, we use self.reward_indices:
self.reward_indices = self.build_reward_indices(env.num_players)
def build_reward_indices(self, num_players: int) -> torch.Tensor:
    num_repeats = math.ceil(self.max_nodes / num_players)
    return torch.tensor([1] + [0] * (num_players - 1), dtype=torch.bool, device=self.device).repeat(num_repeats)[:self.max_nodes].view(1, -1)
which holds a boolean tensor whose indices reflect which nodes along the path from root to leaf should have their rewards flipped. This works well for single and two-player environments but may need to be adjusted/modified for environments with more than two players.
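For a concrete picture, here is what build_reward_indices produces for a hypothetical two-player game with max_nodes = 6: rewards are applied as-is at alternating depths and flipped in between.
import math
import torch

num_players, max_nodes = 2, 6
num_repeats = math.ceil(max_nodes / num_players)
reward_indices = torch.tensor([1] + [0] * (num_players - 1), dtype=torch.bool).repeat(num_repeats)[:max_nodes].view(1, -1)
print(reward_indices)   # tensor([[True, False, True, False, True, False]])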
We use this mask to flip the rewards along the traversal path accordingly.
self.w_vals[self.env_indices_expnd, self.visits, self.actions] += (rewards * leaf_inc * self.reward_indices) + ((1-rewards) * leaf_inc * ~self.reward_indices)
Keep in mind that we are only backpropagating values in envs that have reached a leaf/terminal node! These operations do nothing for instances where the current node is not a leaf or terminal.
Now that we've run backpropagation for any leaf nodes encountered, it's time to update our other state tensors.
First, we'll increment depths to reflect having traveled 1 node further, and reset it to 1 if we've reached a leaf node:
self.depths *= ~is_leaf
self.depths += 1
self.max_depths = torch.max(self.max_depths, self.depths)
(we maintain self.max_depths for persisting subtrees)
Similarly, we reset self.actions and self.visits if a leaf/terminal state has been reached:
self.visits[:, 1:] *= ~is_leaf.view(-1, 1)
self.actions *= ~is_leaf.view(-1, 1)
(instances that reached a leaf are multiplied by zero, clearing their path!)
The final step of one iteration is resetting self.cur_nodes to the root node (index = 1) for any leaf nodes:
self.env.load_node(is_leaf, saved)
self.cur_nodes *= ~is_leaf
self.cur_nodes += is_leaf
We use env.load_node to load the environment state that we cached earlier!
Once we've completed the specified number of iterations, we need to return the final policy distribution and evaluation.
The policy distribution is simply the visit counts of the root node's children (the root's slice of self.n_vals). The evaluation can be interpreted as the Q-value of the most-visited root action (its accumulated value divided by its visit count).
# return to the root node
self.cur_nodes.fill_(1)
# reload all envs to the root node
self.env.load_node(self.cur_nodes.bool(), saved)
# self.cur_nodes.bool() is a tensor filled with ones, so we reload every environment
# return visited counts at the root node
max_inds = self.n_vals[self.env_indices, self.cur_nodes].argmax(dim=1)
max_q_vals = self.w_vals[self.env_indices, self.cur_nodes, max_inds] / self.n_vals[self.env_indices, self.cur_nodes, max_inds]
return self.n_vals[self.env_indices, self.cur_nodes], max_q_vals
We just made it through every line of an MCTS evaluation! You can see the full source code here: mcts.py
AlphaZero itself is configured to sample from this policy distribution, with a temperature parameter $\tau$ applied first: $P_{f}(S) = P(S)^{1/\tau}$. This is a hyperparameter that can be adjusted. A value of 1.0 has no effect on the distribution. Values smaller than 1.0 make high-probability actions even more likely to be selected, while values larger than 1.0 soften the distribution s.t. all probabilities are closer together. Setting the temperature to 0 bypasses sampling and takes an argmax over visit counts instead:
def choose_actions(self, visits: torch.Tensor) -> torch.Tensor:
    if self.config.temperature > 0:
        return torch.multinomial(
            torch.pow(visits, 1/self.config.temperature),
            1,
            replacement=True
        ).flatten()
    else:
        return rand_argmax_2d(visits).flatten()
(from alphazero.py)
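A quick toy demonstration of the temperature's effect (made-up visit counts): lower temperatures sharpen the distribution toward the most-visited action, higher temperatures flatten it.
import torch

visits = torch.tensor([[10., 30., 60.]])

for temperature in (0.5, 1.0, 2.0):
    weights = torch.pow(visits, 1 / temperature)
    probs = weights / weights.sum(dim=1, keepdim=True)
    print(temperature, probs)
# 0.5 -> roughly [0.02, 0.20, 0.78] (sharper)
# 1.0 -> exactly [0.10, 0.30, 0.60]
# 2.0 -> roughly [0.19, 0.33, 0.47] (flatter)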
One of the main contributions of this project is a vectorized implementation of MCTS that supports persisting subtrees. I am not aware of another implementation that has this feature.
DeepMind's mctx omits this feature, because it focuses on algorithms like MuZero that plan with a learned model of the environment rather than the true simulator. MuZero can freely discard its search tree after each MCTS evaluation, as each search tree node is represented by an approximation of the actual state rather than the state itself. As soon as an action is taken, the true next state is available, which makes saving subtrees unnecessary.
With AlphaZero, we know the true state throughout the entire search process, so holding onto the relevant subtree after taking an action is very useful, and lets us avoid repeating a lot of work, which improves the performance of the algorithm by allowing it to visit new states rather than re-visit old ones.
Subtrees are persisted using the following two functions, which I'll explain line-by-line:
def propogate_root_subtrees(self):
    self.subtrees.zero_()
    self.subtrees += self.slots_aranged
    self.parents[:, 0] = 0
    for _ in range(self.max_depths.max() + 1):
        parent_subtrees = self.subtrees[self.env_indices_expnd, self.parents]
        self.subtrees = (parent_subtrees * (parent_subtrees > 1)) + (self.subtrees * (parent_subtrees <= 1))

def load_subtree(self, actions: torch.Tensor):
    self.propogate_root_subtrees()
    subtree_master_indices = self.next_idx[self.env_indices, 1, actions]
    is_real = subtree_master_indices > 1
    new_nodes = (self.subtrees == subtree_master_indices.view(-1, 1))
    translation = new_nodes * is_real.view(-1, 1) * new_nodes.long().cumsum(dim=1)
    old_subtree_idxs = self.slots_aranged * new_nodes
    self.next_empty = torch.amax(translation, dim=1) + 1
    self.next_empty.clamp_(min=2)
    erase = self.slots_aranged * (self.slots_aranged >= self.next_empty.view(-1, 1))
    self.w_vals[self.env_indices_expnd, translation] = self.w_vals[self.env_indices_expnd, old_subtree_idxs]
    self.w_vals[self.env_indices_expnd, erase] = 0
    self.n_vals[self.env_indices_expnd, translation] = self.n_vals[self.env_indices_expnd, old_subtree_idxs]
    self.n_vals[self.env_indices_expnd, erase] = 0
    self.p_vals[self.env_indices_expnd, translation] = self.p_vals[self.env_indices_expnd, old_subtree_idxs]
    self.p_vals[self.env_indices_expnd, erase] = 0
    self.next_idx[self.env_indices_expnd, translation] = translation[self.env_indices.view(-1, 1, 1), self.next_idx]
    self.next_idx[self.env_indices_expnd, erase] = 0
    self.parents[self.env_indices_expnd, translation] = translation[self.env_indices_expnd, self.parents]
    self.parents[self.env_indices_expnd, erase] = 0
    self.max_depths -= 1
    self.max_depths.clamp_(min=1)
We can divide persisting subtrees into a few steps:
- Assign each node to a root subtree
- Re-index subtree nodes
- Copy data to new node indices
- Erase any other data that's left
- Adjust MCTS state tensors
First, we need to figure out which root subtree each node in the search tree belongs to. We'd like the new tree to only include nodes contained within the subtree corresponding to a chosen action.
We create self.subtrees (shape (parallel_envs, total_slots), indexed by master node id) for this purpose:
self.subtrees = torch.zeros(
(self.parallel_envs, self.total_slots),
dtype=torch.long,
device=self.device,
requires_grad=False
) # index by master id
First, we zero out each subtree, and set each node's subtree equal to its own ID.
def propogate_root_subtrees(self):
    self.subtrees.zero_()
    self.subtrees += self.slots_aranged  # node index value = node index
    self.parents[:, 0] = 0
The plan is to iteratively set each node's subtree equal to its parent's subtree, until every node's subtree is assigned to one of the root node's children. We'll need to repeat this at most as many times as the deepest tree is tall, which we tracked in self.max_depths during MCTS iteration. So the number of repetitions we need is the maximum across all instances: self.max_depths.max()!
for _ in range(self.max_depths.max() + 1):
    parent_subtrees = self.subtrees[self.env_indices_expnd, self.parents]
    self.subtrees = (parent_subtrees * (parent_subtrees > 1)) + \
        (self.subtrees * (parent_subtrees <= 1))
After these iterations, self.subtrees contains, for each and every node, the immediate root subtree (root child) that it belongs to.
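As a sanity check, here is the same propagation loop run standalone on a single toy tree (node ids and parents made up for illustration): node 1 is the root, nodes 2 and 3 are its children, node 4 is a child of 2, and node 5 is a child of 3.
import torch

parents  = torch.tensor([0, 0, 1, 1, 2, 3])   # parent id for each node slot
subtrees = torch.arange(6)                    # each node starts in its own "subtree"

for _ in range(3):                            # >= the depth of the tree
    parent_subtrees = subtrees[parents]
    subtrees = (parent_subtrees * (parent_subtrees > 1)) + (subtrees * (parent_subtrees <= 1))

print(subtrees)   # tensor([0, 1, 2, 3, 2, 3]): nodes 4 and 5 map to root children 2 and 3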
Next, we need to select the subtree root indices corresponding to each action taken in the environments. We can find these by accessing the action-node_id mapping of the current tree root node.
subtree_master_indices = self.next_idx[self.env_indices, 1, actions]
In edge cases, it's possible the chosen action does not have a node assigned to it yet. We can check via:
is_real = subtree_master_indices > 1
Then, we can filter nodes that are included in the action's subtree:
new_nodes = (self.subtrees == subtree_master_indices.view(-1, 1))
This gives us a boolean tensor with False
at indices of nodes not included in the subtree, and True
at indices of nodes that do belong to the given subtree.
Next, we'd like to compress the subtree nodes so that they occupy the first slots of the new tree. We can use torch.cumsum to map each node's current index to its index relative to the other subtree nodes. It's a pretty slick line of code.
translation = new_nodes * is_real.view(-1, 1) * new_nodes.long().cumsum(dim=1)
We can also create a filtered index containing current node ids for nodes in the subtree:
old_subtree_idxs = self.slots_aranged * new_nodes
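Continuing the toy tree from the sketch above, suppose the chosen action leads to root child 2, so the new tree should consist of node 2 and its descendant, node 4:
import torch

subtrees = torch.tensor([0, 1, 2, 3, 2, 3])   # result of the propagation sketch above
slots_aranged = torch.arange(6)

new_nodes = subtrees == 2                                   # [F, F, T, F, T, F]
translation = new_nodes * new_nodes.long().cumsum(dim=0)    # [0, 0, 1, 0, 2, 0]
old_subtree_idxs = slots_aranged * new_nodes                # [0, 0, 2, 0, 4, 0]

# old node 2 becomes the new root (index 1), old node 4 becomes node 2;
# everything outside the subtree is routed to the garbage slot (index 0)
print(translation, old_subtree_idxs)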
Now we're ready to re-index all of the subtree node data: self.w_vals, self.n_vals, self.p_vals, self.next_idx, and self.parents. We can actually achieve this with a single line of code:
self.w_vals[self.env_indices_expnd, translation] = self.w_vals[self.env_indices_expnd, old_subtree_idxs]
We use the translation tensor to index into each data tensor (in this case, self.w_vals), writing each subtree node's data into its new, compressed index. We can do the same for each of the other data tensors:
self.n_vals[self.env_indices_expnd, translation] = self.n_vals[self.env_indices_expnd, old_subtree_idxs]
self.p_vals[self.env_indices_expnd, translation] = self.p_vals[self.env_indices_expnd, old_subtree_idxs]
self.parents[self.env_indices_expnd, translation] = translation[self.env_indices_expnd, self.parents]
self.next_idx[self.env_indices_expnd, translation] = translation[self.env_indices.view(-1, 1, 1), self.next_idx]
self.parents and self.next_idx require a little more thought, since the values they hold are themselves node ids that must also be translated; see if you can figure out the idea!
Now that the subtree data is in the correct location, any indices we haven't yet touched will contain old data, which we need to zero out. We just need to know where the new subtree ends for each MCTS instance, and then zero out the slice of data occurring after that index. We know that the new next empty tensor should be the maximum new index + 1:
self.next_empty = torch.amax(translation, dim=1) + 1
I'm sure there is a way to use this tensor directly as an index to slice our desired data, but I'm not sure how. Instead, we can gather all the indices not in our subtree occurring after this value:
erase = self.slots_aranged * (self.slots_aranged >= self.next_empty.view(-1, 1))
self.next_empty.clamp_(min=2)
Then, just set these to 0!
self.w_vals[self.env_indices_expnd, erase] = 0
self.n_vals[self.env_indices_expnd, erase] = 0
self.p_vals[self.env_indices_expnd, erase] = 0
self.next_idx[self.env_indices_expnd, erase] = 0
self.parents[self.env_indices_expnd, erase] = 0
State should now be in good shape: we've remapped all the subtree nodes and updated self.next_empty. All that's left to do is to update self.max_depths, which we can simply decrease by 1, since the new tree rooted at the chosen child is one level shallower than the old one:
self.max_depths -= 1
self.max_depths.clamp_(min=1)
And that's it! This allows us to save lots of work by persisting subtrees from previous MCTS evaluations! You can view the full source code here.
The MCTS evaluator and AlphaZero are configured with the following hyperparameters:
- num_iters: number of iterations/edges traversed before final evaluation
- max_nodes: memory for this many nodes is allocated -- if no memory is free, every iteration that would result in traversing to a new, unexplored node instead resets back to the root node
- puct_coeff: $c_{puct}$ in $U(s,a) = c_{puct}P(s,a)\frac{\sqrt{\sum_{b}{N(s,b)}}}{1+N(s,a)}$, higher values bias towards exploration instead of exploitation
- dirichlet_epsilon: $d_{e}$ in $P_{n}(S) = (P(S) * (1-d_{e})) + (Dirichlet(d_{a}) * d_{e})$, proportion of root policy composed of Dirichlet noise
- dirichlet_alpha: $d_{a}$ in $P_{n}(S) = (P(S) * (1-d_{e})) + (Dirichlet(d_{a}) * d_{e})$, magnitude of Dirichlet noise added to root policy
- temperature: $\tau$ in $P_{f}(S) = P(S)^{1/\tau}$, adjusts policy sampling