# DecMCTS.py
# Committed by Jayant Khatkar.
import networkx as nx
from math import log
import numpy as np


def _UCT(mu_j, c_p, n_p, n_j):
    if n_j ==0:
        return float("Inf")

    return mu_j + 2*c_p *(2*log(n_p)/n_j)**0.5
# (commit marker: Jayant Khatkar)


class tree:
    """
    Search tree stored as a networkx DiGraph.

    Nodes are identified by integer ids and the root is node 1.
    _select reads two attributes from the nodes it looks at:
        "mu" -- passed to _UCT as mu_j
        "N"  -- passed to _UCT as n_p (parent) or n_j (child)
    """

    def __init__(self, data, c_p=1):
        """
        Create a tree that contains only the root node.

        data -- stored unchanged as self.data
        c_p  -- exploration constant that _select hands to _UCT
        """

        self.data = data
        self.graph = nx.DiGraph()

        # Placeholders; nothing in this class assigns them yet
        self.reward_func = None
        self.available_actions = None

        # Must be assigned here: _select reads self.c_p on every call
        self.c_p = c_p

        # Graph
        # The root carries the same attributes _select reads from any node
        self.graph.add_node(1, mu=0, N=0)


    def _select(self, children):
        """
        Select Child node which maximises UCT

        children -- non-empty list of node ids; the parent is looked up
                    through the first entry only

        Returns the entry of children with the highest UCT value
        (np.argmax picks the first one on ties).
        """

        # N for parent
        parent = list(self.graph.predecessors(children[0]))[0]
        n_p = self.graph.nodes[parent]["N"]

        # UCT values for children
        uct = []
        for child in children:
            attrs = self.graph.nodes[child]
            uct.append(_UCT(attrs["mu"], self.c_p, n_p, attrs["N"]))

        # Return Child with highest UCT
        return children[np.argmax(uct)]


    def grow(self, nsims=10, ntime=None):
        """
        Grow Tree by one node

        Only the selection phase exists so far; nsims and ntime are
        accepted but not used yet.
        """

        ### SELECTION
        # Walk down from the root, taking the best-UCT child each time,
        # until a node without successors is reached
        start_node = 1
        children = list(self.graph.successors(start_node))
        while children:
            start_node = self._select(children)
            children = list(self.graph.successors(start_node))

        ### EXPANSION
        # TODO: not implemented


    def send_comms(self):
        """Placeholder: prints "TODO"."""
        # Call form prints the same text on Python 2 and Python 3
        print("TODO")


    def receive_comms(self):
        """Placeholder: prints "TODO"."""
        print("TODO")