Commit 902ffcb7 authored by John Schulman's avatar John Schulman Committed by GitHub
Browse files

Merge pull request #120 from hamzamerzic/tensorflow_global_variable

Deprecated VARIABLES -> GLOBAL_VARIABLES.
parents 4e2a570e a7320b80
...@@ -49,7 +49,7 @@ class CnnPolicy(object): ...@@ -49,7 +49,7 @@ class CnnPolicy(object):
ac1, vpred1 = self._act(stochastic, ob[None]) ac1, vpred1 = self._act(stochastic, ob[None])
return ac1[0], vpred1[0] return ac1[0], vpred1[0]
def get_variables(self): def get_variables(self):
return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope) return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)
def get_trainable_variables(self): def get_trainable_variables(self):
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope) return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
def get_initial_state(self): def get_initial_state(self):
......
...@@ -51,7 +51,7 @@ class MlpPolicy(object): ...@@ -51,7 +51,7 @@ class MlpPolicy(object):
ac1, vpred1 = self._act(stochastic, ob[None]) ac1, vpred1 = self._act(stochastic, ob[None])
return ac1[0], vpred1[0] return ac1[0], vpred1[0]
def get_variables(self): def get_variables(self):
return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope) return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)
def get_trainable_variables(self): def get_trainable_variables(self):
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope) return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
def get_initial_state(self): def get_initial_state(self):
......
...@@ -49,7 +49,7 @@ class CnnPolicy(object): ...@@ -49,7 +49,7 @@ class CnnPolicy(object):
ac1, vpred1 = self._act(stochastic, ob[None]) ac1, vpred1 = self._act(stochastic, ob[None])
return ac1[0], vpred1[0] return ac1[0], vpred1[0]
def get_variables(self): def get_variables(self):
return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope) return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)
def get_trainable_variables(self): def get_trainable_variables(self):
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope) return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
def get_initial_state(self): def get_initial_state(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment