
Commit 8123b35

Author: krishnan

Remove step from SAC._train_step arg list; fix redundant wrapping of the HER replay_buffer

1 parent: b7e0f40

File tree

3 files changed: +8 -5 lines

docs/misc/changelog.rst
stable_baselines/her/her.py
stable_baselines/sac/sac.py

docs/misc/changelog.rst

Lines changed: 2 additions & 1 deletion

@@ -65,7 +65,8 @@ Bug Fixes:
 - Fixed a bug in CloudPickleWrapper's (used by VecEnvs) ``__setstate___`` where loading was incorrectly using ``pickle.loads`` (@shwang).
 - Fixed a bug in ``SAC`` and ``TD3`` where the log timesteps was not correct (@YangRui2015)
 - Fixed a bug where the environment was reset twice when using ``evaluate_policy``
-- Fixed a bug where ``SAC`` uses wrong step to log to tensorboard after multiple calls to ``SAC.learn(..., reset_num_timesteps=True)``
+- Fixed a bug where ``SAC`` uses wrong step to log to tensorboard after multiple calls to ``SAC.learn(..., reset_num_timesteps=True)`` (@krishpop)
+- Fixed issue where HER replay buffer wrapper is used multiple times after multiple calls to ``HER.learn`` (@krishpop)
 
 Deprecations:
 ^^^^^^^^^^^^^
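
Both @krishpop entries stem from the same usage pattern: calling ``learn`` more than once on the same model. A minimal repro sketch, assuming a goal-based Gym environment (FetchReach-v1 is only an illustrative choice; any GoalEnv works):

import gym
from stable_baselines import HER, SAC

# Any GoalEnv-style environment works here; FetchReach-v1 is just an example.
env = gym.make('FetchReach-v1')
model = HER('MlpPolicy', env, SAC, n_sampled_goal=4,
            goal_selection_strategy='future')

# First call: the replay buffer gets wrapped for hindsight replay.
model.learn(total_timesteps=1000)

# Second call: previously this re-wrapped the already-wrapped buffer;
# after this commit the wrapper is applied exactly once.
model.learn(total_timesteps=1000)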

stable_baselines/her/her.py

Lines changed: 4 additions & 1 deletion

@@ -70,6 +70,7 @@ def _create_replay_wrapper(self, env):
                                                 n_sampled_goal=self.n_sampled_goal,
                                                 goal_selection_strategy=self.goal_selection_strategy,
                                                 wrapped_env=self.env)
+        self.wrapped_buffer = False
 
     def set_env(self, env):
         assert not isinstance(env, VecEnvWrapper), "HER does not support VecEnvWrapper"
@@ -108,9 +109,11 @@ def setup_model(self):
 
     def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="HER",
               reset_num_timesteps=True):
+        replay_wrapper = self.replay_wrapper if not self.wrapped_buffer else None
+        self.wrapped_buffer = True
         return self.model.learn(total_timesteps, callback=callback, log_interval=log_interval,
                                 tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps,
-                                replay_wrapper=self.replay_wrapper)
+                                replay_wrapper=replay_wrapper)
 
     def _check_obs(self, observation):
         if isinstance(observation, dict):
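
The fix is a one-shot guard: ``wrapped_buffer`` starts out False when the wrapper is created, and the first ``learn`` call flips it to True, so ``replay_wrapper`` reaches the underlying model exactly once and later calls pass ``None``. A stripped-down sketch of the same pattern (illustrative only, not the library code):

class OneShotWrapper:
    """Hand a wrapper factory out at most once, mirroring the guard above."""

    def __init__(self, wrapper_factory):
        self.wrapper_factory = wrapper_factory
        self.wrapped = False

    def take(self):
        # First call returns the factory; every later call returns None.
        if self.wrapped:
            return None
        self.wrapped = True
        return self.wrapper_factory

guard = OneShotWrapper(lambda buffer: buffer)  # stand-in factory
assert guard.take() is not None   # first learn(): buffer gets wrapped
assert guard.take() is None       # later learn() calls: no re-wrapping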

stable_baselines/sac/sac.py

Lines changed: 2 additions & 3 deletions

@@ -313,8 +313,7 @@ def setup_model(self):
 
                 self.summary = tf.summary.merge_all()
 
-    def _train_step(self, step, writer, learning_rate):
-        del step
+    def _train_step(self, writer, learning_rate):
         # Sample a batch from the replay buffer
         batch = self.replay_buffer.sample(self.batch_size, env=self._vec_normalize_env)
         batch_obs, batch_actions, batch_rewards, batch_next_obs, batch_dones = batch
@@ -461,7 +460,7 @@ def learn(self, total_timesteps, callback=None,
                         frac = 1.0 - step / total_timesteps
                         current_lr = self.learning_rate(frac)
                         # Update policy and critics (q functions)
-                        mb_infos_vals.append(self._train_step(step, writer, current_lr))
+                        mb_infos_vals.append(self._train_step(writer, current_lr))
                         # Update target network
                         if (step + grad_step) % self.target_update_interval == 0:
                             # Update target network
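
The ``step`` argument was dead weight: the old body discarded it immediately with ``del step``, so nothing in ``_train_step`` read it, presumably because (per the tensorboard fix noted in the changelog) logging keys off the model-owned ``num_timesteps`` counter rather than the loop-local ``step``. A toy sketch of that shape, with hypothetical names:

class Trainer:
    """Hypothetical stand-in, not the real SAC class."""

    def __init__(self):
        self.num_timesteps = 0  # model-owned counter, advanced by learn()

    def train_step(self, writer, learning_rate):
        # Log against the model-owned counter instead of a caller-supplied
        # step, which is what made the old `step` argument removable.
        writer.append((self.num_timesteps, learning_rate))

    def learn(self, total_timesteps, writer):
        for _ in range(total_timesteps):
            self.num_timesteps += 1
            self.train_step(writer, learning_rate=3e-4)

log = []
Trainer().learn(3, log)
assert [t for t, _ in log] == [1, 2, 3]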
