# Newer
# Older
import networkx as nx
from math import log
import numpy as np
def _UCT(mu_j, c_p, n_p, n_j):
    """
    Upper Confidence bound for Trees (UCT) score of one child node.

    mu_j : value estimate of the child
    c_p  : exploration constant
    n_p  : visit count of the parent
    n_j  : visit count of the child

    Returns positive infinity for a child that has never been visited,
    otherwise mu_j + 2 * c_p * sqrt(2 * ln(n_p) / n_j).
    """
    # Unvisited children get an infinite score; this also avoids dividing by zero
    if n_j == 0:
        return float("Inf")
    # Exploration term: shrinks as the child is visited more often
    exploration = (2 * log(n_p) / n_j) ** 0.5
    return mu_j + 2 * c_p * exploration
class tree:
def __init__(self, data, c_p=2 ** -0.5):
    """
    Create a search tree that contains only the root node.

    data : stored unchanged on the instance as ``self.data``
    c_p  : exploration constant passed to ``_UCT`` by ``_select``.
           Defaults to 1/sqrt(2), a common choice when rewards lie
           in [0, 1] -- NOTE(review): confirm this suits the reward scale.
    """
    self.data = data
    self.available_actions = None
    # Must be assigned here: a bare ``self.c_p`` read raises AttributeError,
    # and ``_select`` needs the value for every UCT evaluation.
    self.c_p = c_p
    # Graph
    # Directed, because the class relies on predecessors()/successors().
    self.graph = nx.DiGraph()
    # Node 1 is the root; ``grow`` starts its selection walk from it.
    self.graph.add_node(1)
def _select(self, children):
    """
    Return the entry of ``children`` with the largest UCT score.

    ``children`` is a non-empty list of node ids. The parent is taken to be
    the first predecessor of ``children[0]``. When several children share the
    top score, the earliest one in the list is returned.
    """
    nodes = self.graph.nodes
    # Visit count of the parent, reached through the first child
    parent = list(self.graph.predecessors(children[0]))[0]
    parent_visits = nodes[parent]["N"]
    # Score each child from its stored value estimate and visit count
    scores = []
    for child in children:
        attrs = nodes[child]
        scores.append(_UCT(attrs["mu"], self.c_p, parent_visits, attrs["N"]))
    # argmax gives the position of the first maximum
    best = np.argmax(scores)
    return children[best]
def grow(self, nsims=10, ntime=None):
"""
Grow Tree by one node
"""
### SELECTION
start_node = 1
while len(list(self.graph.successors(start_node)))>0:
start_node = self._select(list(self.graph.successors(start_node)))
### EXPANSION