bigprint / pyDecMCTS

Commit 37cb43af, authored Jan 13, 2020 by brian.lee (parent bbde9a45)

    fix strange behaviours with best_reward

Showing 1 changed file with 14 additions and 6 deletions: DecMCTS.py (+14 -6)
from __future__ import print_function
import networkx as nx
from copy import copy
from math import log
import numpy as np

def _UCT(mu_j, c_p, n_p, n_j):
    if n_j == 0:
        return float("Inf")
    ...
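The body of _UCT is collapsed in this view past the unvisited-child guard. For reference, here is a minimal sketch of a standard UCB1-style completion; the exact constant and exploration term used in DecMCTS.py may differ, and _UCT_sketch is a hypothetical name:

from math import log

def _UCT_sketch(mu_j, c_p, n_p, n_j):
    # unvisited children get infinite priority, as in the real _UCT above
    if n_j == 0:
        return float("Inf")
    # exploitation term (empirical mean mu_j) plus a UCB1 exploration
    # bonus that grows with the parent's visit count n_p and shrinks
    # with the child's visit count n_j; c_p scales the exploration
    return mu_j + 2.0 * c_p * (2.0 * log(n_p) / n_j) ** 0.5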
@@ -111,6 +112,7 @@ class Tree:
         self.graph.add_node(1,
                 mu=0,
                 N=0,
+                best_reward=0,
                 state=self.state_store(self.data, None, None, self.id))
...
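After this change every tree node carries a best_reward attribute alongside mu (mean reward) and N (visit count). A small standalone sketch of the resulting node layout, using modern networkx (the diff uses the pre-2.0 graph.node accessor; newer releases spell it graph.nodes):

import networkx as nx

g = nx.DiGraph()
# mirror the root-node initialisation from the diff: mean reward mu,
# visit count N, and the newly added best_reward, all starting at 0
g.add_node(1, mu=0, N=0, best_reward=0)

print(g.nodes[1]["best_reward"])   # -> 0
g.nodes[1]["best_reward"] = 5      # attributes are plain dict entries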
@@ -212,6 +214,7 @@ class Tree:
         self.graph.add_node(len(self.graph)+1,
                 mu=0,
+                best_reward=0,
                 N=0,
                 state=self.state_store(self.data,
                     self.graph.node[start_node]["state"],
                     o,
                     self.id))
...
@@ -240,7 +243,6 @@ class Tree:
### EXPANSION
# check if _expansion changes start_node to the node after jumping
self
.
_expansion
(
start_node
)
print
(
self
.
_childNodes
(
start_node
))
### SIMULATION
avg_reward
=
0
...
...
@@ -274,15 +276,17 @@ class Tree:
             state[self.id] = temp_state

             # calculate the reward at the end of simulation
-            rew = self.reward(self.data, state) \
-                    - self.reward(self.data, self._null_state(state))
+            rew = self.reward(self.data, state)
             avg_reward += rew

             # if best reward so far, store the rollout in the new node
             if rew > best_reward:
                 best_reward = rew
                 best_rollout = copy(temp_state)

-        self.graph.node[start_node]["mu"] = avg_reward
+        avg_reward = avg_reward / nsims
+        self.graph.node[start_node]["mu"] = avg_reward
+        self.graph.node[start_node]["best_reward"] = best_reward
         self.graph.node[start_node]["N"] = 1
         self.graph.node[start_node]["best_rollout"] = copy(best_rollout)
...
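Two behaviours change in the simulation step: the rollout reward is no longer baselined against self._null_state, and mu is now written after the accumulated reward is divided by nsims, so it stores the mean rollout reward rather than the running sum. A toy illustration of the averaging fix (variable names mirror the diff; the rewards list is made up):

rewards = [3.0, 1.0, 2.0]   # pretend rewards from nsims = 3 rollouts
nsims = len(rewards)

avg_reward = 0.0
best_reward = 0.0
for rew in rewards:
    avg_reward += rew
    if rew > best_reward:   # track the best single rollout, as in the diff
        best_reward = rew

avg_reward = avg_reward / nsims   # divide *before* storing mu
assert avg_reward == 2.0          # the mean, not the pre-fix sum of 6.0
assert best_reward == 3.0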
@@ -299,6 +303,10 @@ class Tree:
         self.graph.node[start_node]["N"] = \
                 gamma * self.graph.node[start_node]["N"] + 1

+        if best_reward > self.graph.node[start_node]["best_reward"]:
+            self.graph.node[start_node]["best_reward"] = best_reward
+            self.graph.node[start_node]["best_rollout"] = copy(best_rollout)
+
         self._update_distribution()
         return avg_reward
...
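The added block propagates a simulation's best reward into a node only when it beats the value already stored there, next to the gamma-discounted visit count. A self-contained sketch of this update, assuming node_attrs stands in for self.graph.node[start_node] and gamma is the discount used in the surrounding method:

from copy import copy

def backprop_update(node_attrs, best_reward, best_rollout, gamma=0.9):
    # discounted visit count, as in the line just above the added block
    node_attrs["N"] = gamma * node_attrs["N"] + 1
    # only overwrite the stored best when the new rollout improves on it;
    # this is the guard this commit introduces
    if best_reward > node_attrs["best_reward"]:
        node_attrs["best_reward"] = best_reward
        node_attrs["best_rollout"] = copy(best_rollout)

# usage on a plain dict standing in for a networkx node's attribute dict
node = {"N": 0, "best_reward": 0, "best_rollout": None}
backprop_update(node, best_reward=4.2, best_rollout=[1, 2, 3])
assert node["best_reward"] == 4.2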