Commit 396d598d authored by Jayant Khatkar

Interface changed to make state simpler

parent 2d190af8
import networkx as nx
from copy import deepcopy
from math import log
import numpy as np
......@@ -77,7 +78,14 @@ class Tree:
reward/available actions in coordination with others
"""
def __init__(self, data, reward_func, avail_actions_func, comm_n, c_p=1, time_func=None):
def __init__(self,
data,
reward_func,
avail_actions_func,
comm_n,
robot_id,
c_p=1,
time_func=None):
self.data = data
self.graph = nx.DiGraph()
......@@ -85,6 +93,7 @@ class Tree:
self.available_actions = avail_actions_func
self.time_func = time_func
self.c_p = c_p
self.id = robot_id
self.comms = {} # Plan with no robots initially
self.comm_n = comm_n # number of action dists to communicate
......@@ -160,13 +169,16 @@ class Tree:
"""
node_path = self.graph.node[node_id]["action_seq"]
other_paths = {k:self.comms[k].random_action() for k in self.comms}
all_paths = {k:self.comms[k].random_action() for k in self.comms}
all_paths[self.id] = node_path
return (node_path, other_paths)
return all_paths
def _null_state(self, state):
return ([], state[1])
temp = deepcopy(state)
temp[self.id] = []
return temp
def _expansion(self, start_node):
......
......@@ -30,7 +30,7 @@ def avail_actions(data, state):
# State is a dictionary whose keys are robot IDs and whose values
# are lists of the actions each robot has taken from the starting position
def reward(dat, state):
each_robot_sum = [sum(state[1][a]) for a in state[1]]
each_robot_sum = [sum(state[a]) for a in state]
return sum(each_robot_sum)
# Number of Action Sequences to communicate
......
......@@ -7,17 +7,15 @@ def avail_actions(data, state):
return [1,2,3,4,5]
def reward(dat, state):
other_robots = [sum(state[1][a]) for a in state[1]]
#if sum(other_robots) + sum(state[0]) >25:
# return 0
return sum(state[0]) + sum(other_robots)
each_robot_sum= [sum(state[a]) for a in state]
return sum(each_robot_sum)
comm_n = 5
tree1 = Tree(data, reward, avail_actions, comm_n)
tree1 = Tree(data, reward, avail_actions, comm_n, 1)
tree2 = Tree(data, reward, avail_actions, comm_n)
tree2 = Tree(data, reward, avail_actions, comm_n, 2)
for i in range(350):
tree1.grow()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment