Skip to content
Snippets Groups Projects
Commit 37cb43af authored by brian.lee's avatar brian.lee
Browse files

fix strange behaviours with best_reward

parent bbde9a45
1 merge request: !1 "Fix best rollout"
from __future__ import print_function
import networkx as nx import networkx as nx
from copy import copy from copy import copy
from math import log from math import log
import numpy as np import numpy as np
def _UCT(mu_j, c_p, n_p, n_j): def _UCT(mu_j, c_p, n_p, n_j):
if n_j ==0: if n_j ==0:
return float("Inf") return float("Inf")
...@@ -111,6 +112,7 @@ class Tree: ...@@ -111,6 +112,7 @@ class Tree:
self.graph.add_node(1, self.graph.add_node(1,
mu=0, mu=0,
N=0, N=0,
best_reward = 0,
state=self.state_store(self.data, None, None, self.id) state=self.state_store(self.data, None, None, self.id)
) )
...@@ -212,6 +214,7 @@ class Tree: ...@@ -212,6 +214,7 @@ class Tree:
self.graph.add_node(len(self.graph)+1, self.graph.add_node(len(self.graph)+1,
mu = 0, mu = 0,
best_reward = 0,
N = 0, N = 0,
state=self.state_store(self.data, self.graph.node[start_node]["state"], o, self.id) state=self.state_store(self.data, self.graph.node[start_node]["state"], o, self.id)
) )
...@@ -240,7 +243,6 @@ class Tree: ...@@ -240,7 +243,6 @@ class Tree:
### EXPANSION ### EXPANSION
# check if _expansion changes start_node to the node after jumping # check if _expansion changes start_node to the node after jumping
self._expansion(start_node) self._expansion(start_node)
print(self._childNodes(start_node))
### SIMULATION ### SIMULATION
avg_reward = 0 avg_reward = 0
...@@ -274,15 +276,17 @@ class Tree: ...@@ -274,15 +276,17 @@ class Tree:
state[self.id] = temp_state state[self.id] = temp_state
# calculate the reward at the end of simulation # calculate the reward at the end of simulation
rew = self.reward(self.data, state) \ rew = self.reward(self.data, state)
- self.reward(self.data, self._null_state(state)) avg_reward += rew
# if best reward so far, store the rollout in the new node # if best reward so far, store the rollout in the new node
if rew > best_reward: if rew > best_reward:
best_reward = rew best_reward = rew
best_rollout = copy(temp_state) best_rollout = copy(temp_state)
self.graph.node[start_node]["mu"] = avg_reward
avg_reward = avg_reward / nsims
self.graph.node[start_node]["mu"] = avg_reward
self.graph.node[start_node]["best_reward"] = best_reward
self.graph.node[start_node]["N"] = 1 self.graph.node[start_node]["N"] = 1
self.graph.node[start_node]["best_rollout"] = copy(best_rollout) self.graph.node[start_node]["best_rollout"] = copy(best_rollout)
...@@ -299,6 +303,10 @@ class Tree: ...@@ -299,6 +303,10 @@ class Tree:
self.graph.node[start_node]["N"] = \ self.graph.node[start_node]["N"] = \
gamma * self.graph.node[start_node]["N"] + 1 gamma * self.graph.node[start_node]["N"] + 1
if best_reward > self.graph.node[start_node]["best_reward"]:
self.graph.node[start_node]["best_reward"] = best_reward
self.graph.node[start_node]["best_rollout"] = copy(best_rollout)
self._update_distribution() self._update_distribution()
return avg_reward return avg_reward
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment