Skip to content
Snippets Groups Projects
Commit c6719297 authored by Jayant Khatkar's avatar Jayant Khatkar
Browse files

add random and weighted random rollout options (fix #65)

parent 36366f1f
Branches
No related merge requests found
......@@ -10,6 +10,7 @@ import numpy as np
import pickle
import copy
import warnings
import random
home = [0, -pi/2, pi/4, -pi/4, -pi/2, -pi/4]
......@@ -39,11 +40,22 @@ def joint_diff(js1,js2):
return sum([(j1-j2)**2 for j1,j2 in zip(js1,js2)])
def sim_select(data, options, sim_state):
j = sim_state[-1].positions[-1]
costs = [joint_diff(j,o.positions[0]) for o in options]
return options[np.argmin(costs)] # greedy
# return np.random.choice(options) # Random (much slower for some reason)
def sim_select(data, options, sim_state, method='greedy'):
"""
How to select option during simulation
"""
if method == 'greedy':
j = sim_state[-1].positions[-1]
costs = [joint_diff(j,o.positions[0]) for o in options]
return options[np.argmin(costs)]
elif method == 'random':
return random.choice(options)
elif method == 'weighted':
j = sim_state[-1].positions[-1]
weights = [1/joint_diff(j,o.positions[0]) for o in options]
return random.choices(options, weights=weights)[0]
else:
print("ERROR: Invalid sim_select method: '{}'".format(method))
class CombinedSchedule:
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment