Commit 49932862 authored by John Schulman, committed by GitHub
Browse files

Merge pull request #160 from mkarutz/fixFrameStackingA2C

Fixes frame stacking in A2C and ACKTR for multi-channel observations
parents 3eb71a0e cc8818f4
......@@ -98,6 +98,7 @@ class Runner(object):
nenv = env.num_envs
self.batch_ob_shape = (nenv*nsteps, nh, nw, nc*nstack)
self.obs = np.zeros((nenv, nh, nw, nc*nstack), dtype=np.uint8)
+self.nc = nc
obs = env.reset()
self.update_obs(obs)
self.gamma = gamma
......@@ -108,8 +109,8 @@ class Runner(object):
def update_obs(self, obs):
# Do frame-stacking here instead of the FrameStack wrapper to reduce
# IPC overhead
-self.obs = np.roll(self.obs, shift=-1, axis=3)
-self.obs[:, :, :, -1] = obs[:, :, :, 0]
+self.obs = np.roll(self.obs, shift=-self.nc, axis=3)
+self.obs[:, :, :, -self.nc:] = obs
def run(self):
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
......
......@@ -113,6 +113,7 @@ class Runner(object):
nenv = env.num_envs
self.batch_ob_shape = (nenv*nsteps, nh, nw, nc*nstack)
self.obs = np.zeros((nenv, nh, nw, nc*nstack), dtype=np.uint8)
+self.nc = nc
obs = env.reset()
self.update_obs(obs)
self.gamma = gamma
......@@ -121,8 +122,8 @@ class Runner(object):
self.dones = [False for _ in range(nenv)]
def update_obs(self, obs):
-self.obs = np.roll(self.obs, shift=-1, axis=3)
-self.obs[:, :, :, -1] = obs[:, :, :, 0]
+self.obs = np.roll(self.obs, shift=-self.nc, axis=3)
+self.obs[:, :, :, -self.nc:] = obs
def run(self):
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment