import networkx as nx
from math import log
import numpy as np


def _UCT(mu_j, c_p, n_p, n_j):
    """
    UCT value of a child node: mean reward mu_j plus an exploration bonus
    based on the parent visit count n_p, child visit count n_j and c_p
    """
    if n_j == 0:
        return float("Inf")

    return mu_j + 2*c_p*(2*log(n_p)/n_j)**0.5
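
# Quick numerical check of the UCT formula above (made-up example values):
# a child with mean reward mu_j = 0.5, visited n_j = 4 times under a parent
# visited n_p = 20 times, with c_p = 1, scores
#     0.5 + 2*1*(2*log(20)/4)**0.5  ~=  2.95
# while an unvisited child (n_j == 0) scores inf and is therefore tried first.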
class ActionDistribution:
    """
    Action Distribution
    Working with action sequences and their respective probability

    To initialise, Inputs:
    - X: list of action sequences
    - q: probability of each action sequence (normalised in initialisation)

    """
    
    def __init__(self, X, q):

        # Action sequences as provided; must match the number of probabilities
        assert len(X) == len(q)
        self.X = X

        # Normalise the probabilities
        self.q = (np.array(q)/sum(q)).tolist()

    def best_action(self):
        """
        Most likely action sequence
        """
        return self.X[np.argmax(self.q)]

    
    def random_action(self):
        """
        Weighted random out of possible action sequences
        """
        # Sample an index; np.random.choice needs a 1-D array of choices
        return self.X[np.random.choice(len(self.X), p=self.q)]
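
# Illustrative use of ActionDistribution (example values only):
#
#   dist = ActionDistribution(X=[["a1"], ["a1", "a2"]], q=[3, 1])
#   dist.q                # -> [0.75, 0.25] after normalisation
#   dist.best_action()    # -> ["a1"], the most likely sequence
#   dist.random_action()  # weighted random draw from X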


class Tree:
    """
    DecMCTS tree
    To Initiate, Inputs:
    - data
        - data required to calculate reward, available options 
    - reward
        - This is a function which has inputs (data, state) and
            returns the global reward to be maximised
        - It may be the case that the reward function can only be
            calculated once the simulation is complete, in which
            case it should return "None" while the simulation is
            incomplete
    - available_actions
        - This is a function which has inputs (data, state) and
            returns the possible actions which can be taken
    - c_p
        - exploration multiplier (number between 0 and 1)

    Usage:
    - grow
        - grow MCTS tree by 1 node
    - send_comms
        - get state of this tree to communicate to others
    - receive_comms
        - Input the state of other trees for use in calculating
            reward/available actions in coordination with others
    """

    def __init__(self, data, reward_func, avail_actions_func, comm_n, c_p=1):

        self.data = data
        self.graph = nx.DiGraph()
        self.reward = reward_func
        self.available_actions = avail_actions_func
        self.c_p = c_p
        self.comms = {}  # No plans from other robots yet
        self.comm_n = comm_n  # number of action distributions to communicate

        # Set Action sequence as nothing for now
        self.my_act_dist = ActionDistribution([[]],[1])
        # Add the root node of the tree to the graph
        self.graph.add_node(1,
                mu=0,
                N=0,
                action_seq=[],
                cost=0
                )

    def _parent(self, node_id):
        """
        wrapper for code readability
        """

        return list(self.graph.predecessors(node_id))[0]


    def _select(self, children):
        """
        Select Child node which maximises UCT
        """
        
        # N for parent
        n_p = self.graph.nodes[self._parent(children[0])]["N"]

        # UCT values for children
        uct = [_UCT(node["mu"], self.c_p, n_p, node["N"]) 
                for node in map(self.graph.nodes.__getitem__, children)]

        # Return Child with highest UCT
        return children[np.argmax(uct)]


    def _childNodes(self, node_id):
        """
        wrapper for code readability
        """

        return list(self.graph.successors(node_id))


    def _update_distribution(self):
        """
        Get the top n Action sequences and their "probabilities"
            and store them for communication
        """

        # For now, just using q = mu**2
        temp = nx.get_node_attributes(self.graph, "mu")

        top_n_nodes = sorted(temp, key=temp.get, reverse=True)[:self.comm_n]
        X = [self.graph.nodes[n]["action_seq"] for n in top_n_nodes]
        q = [self.graph.nodes[n]["mu"]**2 for n in top_n_nodes]
        self.my_act_dist = ActionDistribution(X, q)
        return True
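
        # Illustrative example (made-up numbers): with comm_n = 2 and node mu
        # values {1: 0.1, 2: 0.9, 3: 0.4}, nodes 2 and 3 are kept, their
        # action sequences form X, and q = [0.81, 0.16] before being
        # normalised inside ActionDistribution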


    def _get_state(self, node_id):
        """
        Randomly select 1 path taken by every other robot & the path taken
            by this robot to get to this node

        Returns a tuple where the first element is the path of the current
            robot, and the second element is a dictionary of the other paths
        """

        node_path = self.graph.nodes[node_id]["action_seq"]
        other_paths = {k: self.comms[k].random_action() for k in self.comms}

        return (node_path, other_paths)
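
        # For example (illustrative values only): for a node whose action_seq
        # is ["a1", "a2"], with comms received from robots 2 and 3, this might
        # return (["a1", "a2"], {2: ["b1"], 3: ["c1", "c2"]})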
    def grow(self, nsims=10, ntime=None):
        """
        Grow Tree by one node
        """
        
        ### SELECTION
        start_node = 1
        while len(self._childNodes(start_node)) > 0:
            start_node = self._select(self._childNodes(start_node))

        ### EXPANSION
        state = self._get_state(start_node)
        options = self.available_actions(self.data, state)

    def send_comms(self):
        """
        Get the state of this tree to communicate to other robots
        """
        return self.my_act_dist

    def receive_comms(self, comms_in, robot_id):
        """
        Save data which has been communicated to this tree
        Only receives from one robot at a time; call once
        for each robot

        Inputs:
        - comms_in
            - An ActionDistribution object
        - robot_id
            - Robot number/id - used as key for comms
        """
        self.comms[robot_id] = comms_in
        return True
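

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It wires up the
# grow / send_comms / receive_comms cycle described in the Tree docstring for
# two robots. The reward and avail_actions callbacks below are made-up
# placeholders with the (data, state) signature Tree expects; they are
# assumptions for illustration, not the author's functions. Note that grow()
# above stops after computing the available options, so this loop exercises
# the API rather than producing a full plan.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    def reward(data, state):
        # Placeholder reward: total number of actions planned by all robots
        node_path, other_paths = state
        return len(node_path) + sum(len(p) for p in other_paths.values())

    def avail_actions(data, state):
        # Placeholder options: actions this robot has not yet planned
        node_path, _ = state
        return [a for a in data["actions"] if a not in node_path]

    data = {"actions": ["a", "b", "c"]}
    robots = {1: Tree(data, reward, avail_actions, comm_n=2),
              2: Tree(data, reward, avail_actions, comm_n=2)}

    # Alternate between growing each tree and exchanging action distributions
    for _ in range(10):
        for tree in robots.values():
            tree.grow()
        for i, tree in robots.items():
            for j, other in robots.items():
                if i != j:
                    other.receive_comms(tree.send_comms(), i)

    # Most likely action sequence for robot 1 (empty here, since grow() does
    # not yet add nodes or update the distribution in this version)
    print(robots[1].my_act_dist.best_action())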