Commit da997060 authored by John Schulman's avatar John Schulman
Browse files

ppo and trpo

parent 80f94f8e
......@@ -7,12 +7,11 @@
# Setuptools distribution and build folders.
# Virtualenv
# Python egg metadata, regenerated from source files by setuptools.
......@@ -26,4 +25,8 @@ ghostdriver.log
\ No newline at end of file
<img src="data/logo.jpg" width=25% align="right" />
# Baselines
We're releasing OpenAI Baselines, a set of high-quality implementations of reinforcement learning algorithms. To start, we're making available an open source version of Deep Q-Learning and three of its variants.
We're releasing OpenAI Baselines, a set of high-quality implementations of reinforcement learning algorithms.
These algorithms will make it easier for the research community to replicate, refine, and identify new ideas, and will create good baselines to build research on top of. Our DQN implementation and its variants are roughly on par with the scores in published papers. We expect they will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones.
......@@ -12,56 +12,6 @@ You can install it by typing:
pip install baselines
## If you are curious.
##### Train a Cartpole agent and watch it play once it converges!
Here's a list of commands to run to quickly get a working example:
<img src="data/cartpole.gif" width="25%" />
# Train model and save the results to cartpole_model.pkl
python -m baselines.deepq.experiments.train_cartpole
# Load the model saved in cartpole_model.pkl and visualize the learned policy
python -m baselines.deepq.experiments.enjoy_cartpole
Be sure to check out the source code of [both](baselines/deepq/experiments/ [files](baselines/deepq/experiments/!
## If you wish to apply DQN to solve a problem.
Check out our simple agent trained with one stop shop `deepq.learn` function.
- `baselines/deepq/experiments/` - train a Cartpole agent.
- `baselines/deepq/experiments/` - train a Pong agent using convolutional neural networks.
In particular notice that once `deepq.learn` finishes training it returns `act` function which can be used to select actions in the environment. Once trained you can easily save it and load at later time. For both of the files listed above there are complimentary files `` and `` respectively, that load and visualize the learned policy.
## If you wish to experiment with the algorithm
##### Check out the examples
- `baselines/deepq/experiments/` - Cartpole training with more fine grained control over the internals of DQN algorithm.
- `baselines/deepq/experiments/atari/` - more robust setup for training at scale.
##### Download a pretrained Atari agent
For some research projects it is sometimes useful to have an already trained agent handy. There's a variety of models to choose from. You can list them all by running:
python -m baselines.deepq.experiments.atari.download_model
Once you pick a model, you can download it and visualize the learned policy. Be sure to pass `--dueling` flag to visualization script when using dueling models.
python -m baselines.deepq.experiments.atari.download_model --blob model-atari-duel-pong-1 --model-dir /tmp/models
python -m baselines.deepq.experiments.atari.enjoy --model-dir /tmp/models/model-atari-duel-pong-1 --env Pong --dueling
- [DQN](baselines/deepq)
- [PPO](baselines/pposgd)
- [TRPO](baselines/trpo_mpi)
from baselines.bench.benchmarks import *
from baselines.bench.monitor import *
_atari7 = ['BeamRider', 'Breakout', 'Enduro', 'Pong', 'Qbert', 'Seaquest', 'SpaceInvaders']
_atariexpl7 = ['Freeway', 'Gravitar', 'MontezumaRevenge', 'Pitfall', 'PrivateEye', 'Solaris', 'Venture']
def register_benchmark(benchmark):
for b in _BENCHMARKS:
if b['name'] == benchmark['name']:
raise ValueError('Benchmark with name %s already registered!'%b['name'])
def list_benchmarks():
return [b['name'] for b in _BENCHMARKS]
def get_benchmark(benchmark_name):
for b in _BENCHMARKS:
if b['name'] == benchmark_name:
return b
raise ValueError('%s not found! Known benchmarks: %s' % (benchmark_name, list_benchmarks()))
def get_task(benchmark, env_id):
"""Get a task by env_id. Return None if the benchmark doesn't have the env"""
return next(filter(lambda task: task['env_id'] == env_id, benchmark['tasks']), None)
_ATARI_SUFFIX = 'NoFrameskip-v4'
'name' : 'Atari200M',
'description' :'7 Atari games from Mnih et al. (2013), with pixel observations, 200M frames',
'tasks' : [{'env_id' : _game + _ATARI_SUFFIX, 'trials' : 2, 'num_timesteps' : int(200e6)} for _game in _atari7]
'name' : 'Atari40M',
'description' :'7 Atari games from Mnih et al. (2013), with pixel observations, 40M frames',
'tasks' : [{'env_id' : _game + _ATARI_SUFFIX, 'trials' : 2, 'num_timesteps' : int(40e6)} for _game in _atari7]
'name' : 'Atari1Hr',
'description' :'7 Atari games from Mnih et al. (2013), with pixel observations, 1 hour of walltime',
'tasks' : [{'env_id' : _game + _ATARI_SUFFIX, 'trials' : 2, 'num_seconds' : 60*60} for _game in _atari7]
'name' : 'AtariExploration40M',
'description' :'7 Atari games emphasizing exploration, with pixel observations, 40M frames',
'tasks' : [{'env_id' : _game + _ATARI_SUFFIX, 'trials' : 2, 'num_timesteps' : int(40e6)} for _game in _atariexpl7]
_mujocosmall = [
'InvertedDoublePendulum-v1', 'InvertedPendulum-v1',
'HalfCheetah-v1', 'Hopper-v1', 'Walker2d-v1',
'Reacher-v1', 'Swimmer-v1']
'name' : 'Mujoco1M',
'description' : 'Some small 2D MuJoCo tasks, run for 1M timesteps',
'tasks' : [{'env_id' : _envid, 'trials' : 3, 'num_timesteps' : int(1e6)} for _envid in _mujocosmall]
_roboschool_mujoco = [
'RoboschoolInvertedDoublePendulum-v0', 'RoboschoolInvertedPendulum-v0', # cartpole
'RoboschoolHalfCheetah-v0', 'RoboschoolHopper-v0', 'RoboschoolWalker2d-v0', # forward walkers
'name' : 'RoboschoolMujoco2M',
'description' : 'Same small 2D tasks, still improving up to 2M',
'tasks' : [{'env_id' : _envid, 'trials' : 3, 'num_timesteps' : int(2e6)} for _envid in _roboschool_mujoco]
_atari50 = [ # actually 49
'Alien', 'Amidar', 'Assault', 'Asterix', 'Asteroids',
'Atlantis', 'BankHeist', 'BattleZone', 'BeamRider', 'Bowling',
'Boxing', 'Breakout', 'Centipede', 'ChopperCommand', 'CrazyClimber',
'DemonAttack', 'DoubleDunk', 'Enduro', 'FishingDerby', 'Freeway',
'Frostbite', 'Gopher', 'Gravitar', 'IceHockey', 'Jamesbond',
'Kangaroo', 'Krull', 'KungFuMaster', 'MontezumaRevenge', 'MsPacman',
'NameThisGame', 'Pitfall', 'Pong', 'PrivateEye', 'Qbert',
'Riverraid', 'RoadRunner', 'Robotank', 'Seaquest', 'SpaceInvaders',
'StarGunner', 'Tennis', 'TimePilot', 'Tutankham', 'UpNDown',
'Venture', 'VideoPinball', 'WizardOfWor', 'Zaxxon',
'name' : 'Atari50_40M',
'description' :'7 Atari games from Mnih et al. (2013), with pixel observations, 40M frames',
'tasks' : [{'env_id' : _game + _ATARI_SUFFIX, 'trials' : 3, 'num_timesteps' : int(40e6)} for _game in _atari50]
__all__ = ['Monitor', 'get_monitor_files', 'load_results']
import gym
from gym.core import Wrapper
from os import path
import time
from glob import glob
import ujson as json # Not necessary for monitor writing, but very useful for monitor loading
except ImportError:
import json
class Monitor(Wrapper):
EXT = "monitor.json"
f = None
def __init__(self, env, filename, allow_early_resets=False):
Wrapper.__init__(self, env=env)
self.tstart = time.time()
if filename is None:
self.f = None
self.logger = None
if not filename.endswith(Monitor.EXT):
filename = filename + "." + Monitor.EXT
self.f = open(filename, "wt")
self.logger = JSONLogger(self.f)
self.logger.writekvs({"t_start": self.tstart, "gym_version": gym.__version__,
"env_id": if env.spec else 'Unknown'})
self.allow_early_resets = allow_early_resets
self.rewards = None
self.needs_reset = True
self.episode_rewards = []
self.episode_lengths = []
self.total_steps = 0
self.current_metadata = {} # extra info that gets injected into each log entry
# Useful for metalearning where we're modifying the environment externally
# But want our logs to know about these modifications
def __getstate__(self): # XXX
d = self.__dict__.copy()
if self.f:
del d['f'], d['logger']
d['_filename'] =
d['_num_episodes'] = len(self.episode_rewards)
d['_filename'] = None
return d
def __setstate__(self, d):
filename = d.pop('_filename')
self.__dict__ = d
if filename is not None:
nlines = d.pop('_num_episodes') + 1
self.f = open(filename, "r+t")
for _ in range(nlines):
self.logger = JSONLogger(self.f)
def reset(self):
if not self.allow_early_resets and not self.needs_reset:
raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, wrap your env with Monitor(env, path, allow_early_resets=True)")
self.rewards = []
self.needs_reset = False
return self.env.reset()
def step(self, action):
if self.needs_reset:
raise RuntimeError("Tried to step environment that needs reset")
ob, rew, done, info = self.env.step(action)
if done:
self.needs_reset = True
eprew = sum(self.rewards)
eplen = len(self.rewards)
epinfo = {"r": eprew, "l": eplen, "t": round(time.time() - self.tstart, 6)}
if self.logger:
info['episode'] = epinfo
self.total_steps += 1
return (ob, rew, done, info)
def close(self):
if self.f is not None:
def get_total_steps(self):
return self.total_steps
def get_episode_rewards(self):
return self.episode_rewards
def get_episode_lengths(self):
return self.episode_lengths
class JSONLogger(object):
def __init__(self, file):
self.file = file
def writekvs(self, kvs):
for k,v in kvs.items():
if hasattr(v, 'dtype'):
v = v.tolist()
kvs[k] = float(v)
self.file.write(json.dumps(kvs) + '\n')
class LoadMonitorResultsError(Exception):
def get_monitor_files(dir):
return glob(path.join(dir, "*" + Monitor.EXT))
def load_results(dir):
fnames = get_monitor_files(dir)
if not fnames:
raise LoadMonitorResultsError("no monitor files of the form *%s found in %s" % (Monitor.EXT, dir))
episodes = []
headers = []
for fname in fnames:
with open(fname, 'rt') as fh:
lines = fh.readlines()
header = json.loads(lines[0])
for line in lines[1:]:
episode = json.loads(line)
episode['abstime'] = header['t_start'] + episode['t']
del episode['t']
header0 = headers[0]
for header in headers[1:]:
assert header['env_id'] == header0['env_id'], "mixing data from two envs"
episodes = sorted(episodes, key=lambda e: e['abstime'])
return {
'env_info': {'env_id': header0['env_id'], 'gym_version': header0['gym_version']},
'episode_end_times': [e['abstime'] for e in episodes],
'episode_lengths': [e['l'] for e in episodes],
'episode_rewards': [e['r'] for e in episodes],
'initial_reset_time': min([min(header['t_start'] for header in headers)])
from baselines.common.console_util import *
from baselines.common.dataset import Dataset
from baselines.common.math_util import *
from baselines.common.misc_util import *
import numpy as np
from collections import deque
from PIL import Image
import gym
from gym import spaces
class NoopResetEnv(gym.Wrapper):
def __init__(self, env, noop_max=30):
"""Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
gym.Wrapper.__init__(self, env)
self.noop_max = noop_max
self.override_num_noops = None
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
def _reset(self):
""" Do no-op action for a number of steps in [1, noop_max]."""
if self.override_num_noops is not None:
noops = self.override_num_noops
noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101
assert noops > 0
obs = None
for _ in range(noops):
obs, _, done, _ = self.env.step(0)
if done:
obs = self.env.reset()
return obs
class FireResetEnv(gym.Wrapper):
def __init__(self, env):
"""Take action on reset for environments that are fixed until firing."""
gym.Wrapper.__init__(self, env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def _reset(self):
obs, _, done, _ = self.env.step(1)
if done:
obs, _, done, _ = self.env.step(2)
if done:
return obs
class EpisodicLifeEnv(gym.Wrapper):
def __init__(self, env):
"""Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
gym.Wrapper.__init__(self, env)
self.lives = 0
self.was_real_done = True
def _step(self, action):
obs, reward, done, info = self.env.step(action)
self.was_real_done = done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if lives < self.lives and lives > 0:
# for Qbert somtimes we stay in lives == 0 condtion for a few frames
# so its important to keep lives > 0, so that we only reset once
# the environment advertises done.
done = True
self.lives = lives
return obs, reward, done, info
def _reset(self):
"""Reset only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
if self.was_real_done:
obs = self.env.reset()
# no-op step to advance from terminal/lost life state
obs, _, _, _ = self.env.step(0)
self.lives = self.env.unwrapped.ale.lives()
return obs
class MaxAndSkipEnv(gym.Wrapper):
def __init__(self, env, skip=4):
"""Return only every `skip`-th frame"""
gym.Wrapper.__init__(self, env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = deque(maxlen=2)
self._skip = skip
def _step(self, action):
"""Repeat action, sum reward, and max over last observations."""
total_reward = 0.0
done = None
for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
total_reward += reward
if done:
max_frame = np.max(np.stack(self._obs_buffer), axis=0)
return max_frame, total_reward, done, info
def _reset(self):
"""Clear past frame buffer and init. to first obs. from inner env."""
obs = self.env.reset()
return obs
class ClipRewardEnv(gym.RewardWrapper):
def _reward(self, reward):
"""Bin reward to {+1, 0, -1} by its sign."""
return np.sign(reward)
class WarpFrame(gym.ObservationWrapper):
def __init__(self, env):
"""Warp frames to 84x84 as done in the Nature paper and later work."""
gym.ObservationWrapper.__init__(self, env)
self.res = 84
self.observation_space = spaces.Box(low=0, high=255, shape=(self.res, self.res, 1))
def _observation(self, obs):
frame ='float32'), np.array([0.299, 0.587, 0.114], 'float32'))
frame = np.array(Image.fromarray(frame).resize((self.res, self.res),
resample=Image.BILINEAR), dtype=np.uint8)
return frame.reshape((self.res, self.res, 1))
class FrameStack(gym.Wrapper):
def __init__(self, env, k):
"""Buffer observations and stack across channels (last axis)."""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
assert shp[2] == 1 # can only stack 1-channel frames
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k))
def _reset(self):
"""Clear buffer and re-fill by duplicating the first observation."""
ob = self.env.reset()
for _ in range(self.k): self.frames.append(ob)
return self._observation()
def _step(self, action):
ob, reward, done, info = self.env.step(action)
return self._observation(), reward, done, info
def _observation(self):
assert len(self.frames) == self.k
return np.concatenate(self.frames, axis=2)
def wrap_deepmind(env, episode_life=True, clip_rewards=True):
"""Configure environment for DeepMind-style Atari.
Note: this does not include frame stacking!"""
assert 'NoFrameskip' in # required for DeepMind-style skip
if episode_life:
env = EpisodicLifeEnv(env)
# env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = WarpFrame(env)
if clip_rewards:
env = ClipRewardEnv(env)
return env
import numpy as np
def cg(f_Ax, b, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10):
Demmel p 312
p = b.copy()
r = b.copy()
x = np.zeros_like(b)
rdotr =
fmtstr = "%10i %10.3g %10.3g"
titlestr = "%10s %10s %10s"
if verbose: print(titlestr % ("iter", "residual norm", "soln norm"))
for i in range(cg_iters):
if callback is not None:
if verbose: print(fmtstr % (i, rdotr, np.linalg.norm(x)))
z = f_Ax(p)
v = rdotr /
x += v*p
r -= v*z
newrdotr =
mu = newrdotr/rdotr
p = r + mu*p
rdotr = newrdotr
if rdotr < residual_tol:
if callback is not None:
if verbose: print(fmtstr % (i+1, rdotr, np.linalg.norm(x))) # pylint: disable=W0631
return x
\ No newline at end of file
from __future__ import print_function
from contextlib import contextmanager
import numpy as np
import time
# ================================================================
# Misc
# ================================================================
def fmt_row(width, row, header=False):
out = " | ".join(fmt_item(x, width) for x in row)
if header: out = out + "\n" + "-"*len(out)
return out
def fmt_item(x, l):
if isinstance(x, np.ndarray):
assert x.ndim==0
x = x.item()
if isinstance(x, float): rep = "%g"%x
else: rep = str(x)
return " "*(l - len(rep)) + rep
color2num = dict(
def colorize(string, color, bold=False, highlight=False):
attr = []
num = color2num[color]
if highlight: num += 10
if bold: attr.append('1')
return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string)