diff --git a/src/decmcts.py b/src/decmcts.py index 63e97d1c50af019af4c4918a80d1bc87a3870691..d962c73c2f1ff1a960af72c7e7031f1aa6acaafa 100644 --- a/src/decmcts.py +++ b/src/decmcts.py @@ -10,6 +10,7 @@ import numpy as np import pickle import copy import warnings +import random home = [0, -pi/2, pi/4, -pi/4, -pi/2, -pi/4] @@ -39,11 +40,22 @@ def joint_diff(js1,js2): return sum([(j1-j2)**2 for j1,j2 in zip(js1,js2)]) -def sim_select(data, options, sim_state): - j = sim_state[-1].positions[-1] - costs = [joint_diff(j,o.positions[0]) for o in options] - return options[np.argmin(costs)] # greedy -# return np.random.choice(options) # Random (much slower for some reason) +def sim_select(data, options, sim_state, method='greedy'): + """ + How to select option during simulation + """ + if method == 'greedy': + j = sim_state[-1].positions[-1] + costs = [joint_diff(j,o.positions[0]) for o in options] + return options[np.argmin(costs)] + elif method == 'random': + return random.choice(options) + elif method == 'weighted': + j = sim_state[-1].positions[-1] + weights = [1/joint_diff(j,o.positions[0]) for o in options] + return random.choices(options, weights=weights)[0] + else: + print("ERROR: Invalid sim_select method: '{}'".format(method)) class CombinedSchedule: