Commit 37cb43af authored by brian.lee's avatar brian.lee

fix strange behaviours with best_reward

parent bbde9a45
from __future__ import print_function
import networkx as nx
from copy import copy
from math import log
import numpy as np
def _UCT(mu_j, c_p, n_p, n_j):
if n_j ==0:
return float("Inf")
......@@ -111,6 +112,7 @@ class Tree:
self.graph.add_node(1,
mu=0,
N=0,
best_reward = 0,
state=self.state_store(self.data, None, None, self.id)
)
......@@ -212,6 +214,7 @@ class Tree:
self.graph.add_node(len(self.graph)+1,
mu = 0,
best_reward = 0,
N = 0,
state=self.state_store(self.data, self.graph.node[start_node]["state"], o, self.id)
)
......@@ -240,7 +243,6 @@ class Tree:
### EXPANSION
# check if _expansion changes start_node to the node after jumping
self._expansion(start_node)
print(self._childNodes(start_node))
### SIMULATION
avg_reward = 0
......@@ -274,15 +276,17 @@ class Tree:
state[self.id] = temp_state
# calculate the reward at the end of simulation
rew = self.reward(self.data, state) \
- self.reward(self.data, self._null_state(state))
rew = self.reward(self.data, state)
avg_reward += rew
# if best reward so far, store the rollout in the new node
if rew > best_reward:
best_reward = rew
best_rollout = copy(temp_state)
self.graph.node[start_node]["mu"] = avg_reward
avg_reward = avg_reward / nsims
self.graph.node[start_node]["mu"] = avg_reward
self.graph.node[start_node]["best_reward"] = best_reward
self.graph.node[start_node]["N"] = 1
self.graph.node[start_node]["best_rollout"] = copy(best_rollout)
......@@ -299,6 +303,10 @@ class Tree:
self.graph.node[start_node]["N"] = \
gamma * self.graph.node[start_node]["N"] + 1
if best_reward > self.graph.node[start_node]["best_reward"]:
self.graph.node[start_node]["best_reward"] = best_reward
self.graph.node[start_node]["best_rollout"] = copy(best_rollout)
self._update_distribution()
return avg_reward
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment