diff --git a/DecMCTS.py b/DecMCTS.py index e5ca2d1f357f41504acc51c23d51d297a9a1eeb8..a9ab84ac7b66931cafa2e7bde8bef50b0c980265 100644 --- a/DecMCTS.py +++ b/DecMCTS.py @@ -162,8 +162,8 @@ class Tree: temp.pop(1, None) top_n_nodes = sorted(temp, key=temp.get, reverse=True)[:self.comm_n] - X = [self.graph.node[n]["state"] for n in top_n_nodes] - q = [self.graph.node[n]["mu"]**2 for n in top_n_nodes] + X = [self.graph.node[n]["best_rollout"] for n in top_n_nodes if self.graph.node[n]["N"]>0] + q = [self.graph.node[n]["mu"]**2 for n in top_n_nodes if self.graph.node[n]["N"]>0] self.my_act_dist = ActionDistribution(X,q) return True