Skip to content
Snippets Groups Projects
Commit 2c509b67 authored by Jayant Khatkar's avatar Jayant Khatkar
Browse files

fix python3 nodes bug, and sim_state bug

parent 03852849
Branches
Tags
No related merge requests found
......@@ -5,6 +5,7 @@ from copy import copy
from math import log
import numpy as np
def _UCT(mu_j, c_p, n_p, n_j):
if n_j ==0:
return float("Inf")
......@@ -117,8 +118,7 @@ class Tree:
)
# Set Action sequence as nothing for now
self.my_act_dist = ActionDistribution([self.graph.node[1]["state"]],[1])
self.my_act_dist = ActionDistribution([self.graph.nodes[1]["state"]],[1])
self._expansion(1)
......@@ -127,7 +127,6 @@ class Tree:
"""
wrapper for code readability
"""
return list(self.graph.predecessors(node_id))[0]
......@@ -137,7 +136,7 @@ class Tree:
"""
# N for parent
n_p = self.graph.node[self._parent(children[0])]["N"]
n_p = self.graph.nodes[self._parent(children[0])]["N"]
# UCT values for children
uct = [_UCT(node["mu"], self.c_p, n_p, node["N"])
......@@ -166,8 +165,8 @@ class Tree:
temp.pop(1, None)
top_n_nodes = sorted(temp, key=temp.get, reverse=True)[:self.comm_n]
X = [self.graph.node[n]["best_rollout"] for n in top_n_nodes if self.graph.node[n]["N"]>0]
q = [self.graph.node[n]["mu"]**2 for n in top_n_nodes if self.graph.node[n]["N"]>0]
X = [self.graph.nodes[n]["best_rollout"] for n in top_n_nodes if self.graph.nodes[n]["N"]>0]
q = [self.graph.nodes[n]["mu"]**2 for n in top_n_nodes if self.graph.nodes[n]["N"]>0]
self.my_act_dist = ActionDistribution(X,q)
return True
......@@ -182,14 +181,14 @@ class Tree:
"""
system_state = {k:self.comms[k].random_action() for k in self.comms}
system_state[self.id] = self.graph.node[node_id]["state"]
system_state[self.id] = self.graph.nodes[node_id]["state"]
return system_state
def _null_state(self, state):
temp = copy(state)
temp[self.id] = self.graph.node[1]["state"] # Null state is if robot still at root node
temp[self.id] = self.graph.nodes[1]["state"] # Null state is if robot still at root node
return temp
......@@ -202,7 +201,7 @@ class Tree:
options = self.available_actions(
self.data,
self.graph.node[start_node]["state"],
self.graph.nodes[start_node]["state"],
self.id
)
......@@ -216,7 +215,7 @@ class Tree:
mu = 0,
best_reward = 0,
N = 0,
state=self.state_store(self.data, self.graph.node[start_node]["state"], o, self.id)
state=self.state_store(self.data, self.graph.nodes[start_node]["state"], o, self.id)
)
self.graph.add_edge(start_node, len(self.graph))
......@@ -249,7 +248,7 @@ class Tree:
best_reward = float("-Inf")
best_rollout = None
for i in range(nsims):
temp_state = self.graph.node[start_node]["state"]
temp_state = self.graph.nodes[start_node]["state"]
state[self.id] = temp_state
d = 0 # depth
......@@ -259,7 +258,7 @@ class Tree:
# Get the available actions
options = self.sim_available_actions(
self.data,
state,
state[self.id],
self.id
)
......@@ -285,27 +284,27 @@ class Tree:
avg_reward = avg_reward / nsims
self.graph.node[start_node]["mu"] = avg_reward
self.graph.node[start_node]["best_reward"] = best_reward
self.graph.node[start_node]["N"] = 1
self.graph.node[start_node]["best_rollout"] = copy(best_rollout)
self.graph.nodes[start_node]["mu"] = avg_reward
self.graph.nodes[start_node]["best_reward"] = best_reward
self.graph.nodes[start_node]["N"] = 1
self.graph.nodes[start_node]["best_rollout"] = copy(best_rollout)
### BACKPROPOGATION
while start_node!=1: #while not root node
start_node = self._parent(start_node)
self.graph.node[start_node]["mu"] = \
(gamma * self.graph.node[start_node]["mu"] * \
self.graph.node[start_node]["N"] + avg_reward) \
/(self.graph.node[start_node]["N"] + 1)
self.graph.nodes[start_node]["mu"] = \
(gamma * self.graph.nodes[start_node]["mu"] * \
self.graph.nodes[start_node]["N"] + avg_reward) \
/(self.graph.nodes[start_node]["N"] + 1)
self.graph.node[start_node]["N"] = \
gamma * self.graph.node[start_node]["N"] + 1
self.graph.nodes[start_node]["N"] = \
gamma * self.graph.nodes[start_node]["N"] + 1
if best_reward > self.graph.node[start_node]["best_reward"]:
self.graph.node[start_node]["best_reward"] = best_reward
self.graph.node[start_node]["best_rollout"] = copy(best_rollout)
if best_reward > self.graph.nodes[start_node]["best_reward"]:
self.graph.nodes[start_node]["best_reward"] = best_reward
self.graph.nodes[start_node]["best_rollout"] = copy(best_rollout)
self._update_distribution()
......
......@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup(
name='pydecmcts',
version='0.4',
version='0.5',
packages=find_packages(include=['pydecmcts']),
py_modules=['pydecmcts']
)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment