from math import log

import networkx as nx
import numpy as np

# Default exploration constant used when the caller does not supply one.
# 1/sqrt(2) is the value commonly paired with this form of the UCT bound
# for rewards scaled to [0, 1]; override it via ``tree(data, c_p=...)``.
_DEFAULT_C_P = 0.5 ** 0.5


def _UCT(mu_j, c_p, n_p, n_j):
    """Return the UCT score of a child node.

    The score is ``mu_j + 2 * c_p * sqrt(2 * ln(n_p) / n_j)``.

    Args:
        mu_j: Value stored under ``"mu"`` for the child.
        c_p: Exploration constant weighting the exploration term.
        n_p: Value stored under ``"N"`` for the parent.
        n_j: Value stored under ``"N"`` for the child.

    Returns:
        ``float("Inf")`` when ``n_j == 0`` (so such children always win the
        arg-max), otherwise the score above.

    Raises:
        ValueError: If ``n_j != 0`` and ``n_p <= 0`` (``log`` domain error).
    """
    if n_j == 0:
        return float("Inf")
    return mu_j + 2 * c_p * (2 * log(n_p) / n_j) ** 0.5


class tree:
    """Search tree stored as a ``networkx.DiGraph`` and descended with UCT.

    Attributes set by ``__init__``:
        data: The object passed to the constructor, stored unchanged.
        graph: Directed graph holding the tree; the root is node ``1``.
        reward_func: Initialised to ``None``; not used by the code shown here.
        available_actions: Initialised to ``None``; not used by the code
            shown here.
        c_p: Exploration constant passed to ``_UCT`` by ``_select``.
    """

    # Identifier of the root node in ``self.graph``.
    _ROOT = 1

    def __init__(self, data, c_p=_DEFAULT_C_P):
        """Create a tree containing only the root node.

        Args:
            data: Arbitrary payload kept on the instance as ``self.data``.
            c_p: Exploration constant for UCT selection. Optional; may also
                be reassigned on the instance after construction.
        """
        self.data = data
        self.graph = nx.DiGraph()
        self.reward_func = None
        self.available_actions = None
        # The original statement here was a bare ``self.c_p`` expression,
        # which raised AttributeError on every construction.
        self.c_p = c_p

        # Graph
        # ``_select`` reads "N" from a parent and "mu"/"N" from children, so
        # the root starts with zeroed statistics instead of no attributes.
        # NOTE(review): confirm these defaults match whatever the (not yet
        # written) expansion/back-propagation steps will store.
        self.graph.add_node(self._ROOT, N=0, mu=0.0)

    def _select(self, children):
        """Select the child node which maximises UCT.

        Args:
            children: Non-empty list of node ids that share a parent in
                ``self.graph``. Each node must carry ``"mu"`` and ``"N"``
                attributes, and the parent must carry ``"N"``.

        Returns:
            The element of ``children`` with the highest UCT score; on ties
            the earliest one in ``children`` (``np.argmax`` behaviour).
        """
        # N for parent: only the first child is inspected, since all
        # children are expected to hang off the same node.
        parent = list(self.graph.predecessors(children[0]))[0]
        n_p = self.graph.nodes[parent]["N"]

        # UCT values for children
        uct = []
        for child in children:
            attrs = self.graph.nodes[child]
            uct.append(_UCT(attrs["mu"], self.c_p, n_p, attrs["N"]))

        # Return child with highest UCT
        return children[np.argmax(uct)]

    def grow(self, nsims=10, ntime=None):
        """Grow the tree by one node.

        Only the selection phase is implemented: starting at the root, the
        tree is descended via ``_select`` until a node without successors is
        reached. Nothing is added to the graph yet and nothing is returned.

        Args:
            nsims: Currently unused.
            ntime: Currently unused.
        """
        ### SELECTION
        start_node = self._ROOT
        # Build the successor list once per level instead of twice.
        children = list(self.graph.successors(start_node))
        while children:
            start_node = self._select(children)
            children = list(self.graph.successors(start_node))

        ### EXPANSION
        # TODO: not implemented; ``start_node`` is the leaf to expand.

    def send_comms(self):
        """Placeholder: prints ``TODO``."""
        # Call form of print works on both Python 2 and Python 3.
        print("TODO")

    def receive_comms(self):
        """Placeholder: prints ``TODO``."""
        print("TODO")