Newer
Older
import networkx as nx
from math import log
import numpy as np
def _UCT(mu_j, c_p, n_p, n_j):
    """
    UCT score of a child node

    Inputs:
    - mu_j: average reward ("mu") of the child
    - c_p: exploration multiplier
    - n_p: "N" of the parent node
    - n_j: "N" of the child node

    Returns mu_j + 2*c_p*sqrt(2*ln(n_p)/n_j), or +infinity when
    n_j is 0.
    """
    if n_j != 0:
        exploration = 2 * c_p * (2 * log(n_p) / n_j) ** 0.5
        return mu_j + exploration
    # A child with n_j == 0 outranks every other child
    return float("Inf")
"""
DecMCTS tree
To Initiate, Inputs:
- data
- data required to calculate reward, available options
- reward
- This is a function which has inputs (data, state) and
returns reward (averaged in "mu") to be maximised
- It may be the case that the reward can only be
calculated once the simulation is complete, in which
case the function should return "None" while the
simulation is incomplete
- available_actions
- This is a function which has inputs (data, state) and
returns the possible actions which can be taken
- c_p
- exploration multiplier (number between 0 and 1)
Usage:
- grow
- grow MCTS tree by 1 node
- send_comms
- get state of this tree to communicate to others
- receive_comms
- Input the state of other trees for use in calculating
reward/available actions in coordination with others
"""
def __init__(self, data, reward_func, avail_actions_func, c_p=1):
    """
    Create a tree that contains only the root node (id 1)

    Inputs:
    - data: handed unchanged as first argument to the reward and
      available-actions functions
    - reward_func: function with inputs (data, state) returning
      the reward
    - avail_actions_func: function with inputs (data, state)
      returning the possible actions
    - c_p: exploration multiplier used in the UCT score
    """
    # Store every argument: grow() reads self.data and calls
    # self.available_actions, _select() reads self.c_p
    self.data = data
    # NOTE(review): attribute name follows the class docstring
    # ("reward"); confirm against the code that consumes it
    self.reward = reward_func
    self.available_actions = avail_actions_func
    self.c_p = c_p

    # Graph
    # Directed, because _parent()/_childNodes() rely on
    # predecessors()/successors()
    self.graph = nx.DiGraph()
    # Root gets the "mu"/"N" attributes that _select() reads
    self.graph.add_node(1, mu=0, N=0)
def _parent(self, node_id):
    """
    Readability helper: first predecessor of node_id in the graph
    """
    predecessors = list(self.graph.predecessors(node_id))
    return predecessors[0]
def _select(self, children):
    """
    Pick the child whose UCT score is highest

    children is a list of node ids; they are scored against the
    "N" of the parent of the first entry.
    """
    nodes = self.graph.nodes

    # N for parent
    parent_n = nodes[self._parent(children[0])]["N"]

    # One UCT score per child, in the same order as children
    scores = []
    for child_id in children:
        attrs = nodes[child_id]
        scores.append(_UCT(attrs["mu"], self.c_p, parent_n, attrs["N"]))

    # Child with highest UCT
    best = np.argmax(scores)
    return children[best]
def _childNodes(self, node_id):
    """
    Readability helper: list of the successors of node_id
    """
    return [*self.graph.successors(node_id)]
def _get_state(self, node_id):
    """
    Placeholder, not implemented yet

    Intended behaviour: randomly select 1 path taken by every
    other robot and calculate the path taken by this robot so
    far. For now node_id is ignored and a marker string is
    returned.
    """
    placeholder = "TODO"
    return placeholder
def grow(self, nsims=10, ntime=None):
"""
Grow Tree by one node
"""
### SELECTION
start_node = 1
while len(self._childNodes(start_node))>0:
start_node = self._select(self._childNodes(start_node))
state = self._get_state(start_node)
options = self.available_actions(self.data, state)