Skip to content

Commit 12e9917

Browse files
qgallouedecaraffin
andauthored
Fix image-based normalized env loading (#1321)
* Fix * Add test * Update changelog * fix memory error avoidance * Update version * image env test * black * check_shape_equal * check shape equal in vecnormalize * Allow spaces not to be box or dict * rm `test_save_load_vecnormalized_image` in favor of `test_vec_env` * Remove unused imports --------- Co-authored-by: Antonin RAFFIN <[email protected]> Co-authored-by: Antonin Raffin <[email protected]>
1 parent 7a1e429 commit 12e9917

File tree

6 files changed

+55
-7
lines changed

6 files changed

+55
-7
lines changed

docs/misc/changelog.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Changelog
44
==========
55

66

7-
Release 1.8.0a4 (WIP)
7+
Release 1.8.0a5 (WIP)
88
--------------------------
99

1010

@@ -29,6 +29,7 @@ Bug Fixes:
2929
- Fixed Atari wrapper that missed the reset condition (@luizapozzobon)
3030
- Added the argument ``dtype`` (default to ``float32``) to the noise for consistency with gym action (@sidney-tio)
3131
- Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly)
32+
- Fixed loading of normalized image-based environments
3233

3334
Deprecations:
3435
^^^^^^^^^^^^^
@@ -212,7 +213,7 @@ Bug Fixes:
212213
- Fixed missing verbose parameter passing in the ``EvalCallback`` constructor (@burakdmb)
213214
- Fixed the issue that when updating the target network in DQN, SAC, TD3, the ``running_mean`` and ``running_var`` properties of batch norm layers are not updated (@honglu2875)
214215
- Fixed incorrect type annotation of the replay_buffer_class argument in ``common.OffPolicyAlgorithm`` initializer, where an instance instead of a class was required (@Rocamonde)
215-
- Fixed loading saved model with different number of envrionments
216+
- Fixed loading saved model with different number of environments
216217
- Removed ``forward()`` abstract method declaration from ``common.policies.BaseModel`` (already defined in ``torch.nn.Module``) to fix type errors in subclasses (@Rocamonde)
217218
- Fixed the return type of ``.load()`` and ``.learn()`` methods in ``BaseAlgorithm`` so that they now use ``TypeVar`` (@Rocamonde)
218219
- Fixed an issue where keys with different tags but the same key raised an error in ``common.logger.HumanOutputFormat`` (@Rocamonde and @AdamGleave)

stable_baselines3/common/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,24 @@ def check_for_correct_spaces(env: GymEnv, observation_space: spaces.Space, actio
230230
raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}")
231231

232232

233+
def check_shape_equal(space1: spaces.Space, space2: spaces.Space) -> None:
234+
"""
235+
If the spaces are Box, check that they have the same shape.
236+
237+
If the spaces are Dict, it recursively checks the subspaces.
238+
239+
:param space1: Space
240+
:param space2: Other space
241+
"""
242+
if isinstance(space1, spaces.Dict):
243+
assert isinstance(space2, spaces.Dict), "spaces must be of the same type"
244+
assert space1.spaces.keys() == space2.spaces.keys(), "spaces must have the same keys"
245+
for key in space1.spaces.keys():
246+
check_shape_equal(space1.spaces[key], space2.spaces[key])
247+
elif isinstance(space1, spaces.Box):
248+
assert space1.shape == space2.shape, "spaces must have the same shape"
249+
250+
233251
def is_vectorized_box_observation(observation: np.ndarray, observation_space: spaces.Box) -> bool:
234252
"""
235253
For box observation type, detects and validates the shape,

stable_baselines3/common/vec_env/vec_normalize.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import pickle
23
from copy import deepcopy
34
from typing import Any, Dict, List, Optional, Union
@@ -159,10 +160,12 @@ def set_venv(self, venv: VecEnv) -> None:
159160
"""
160161
if self.venv is not None:
161162
raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.")
162-
VecEnvWrapper.__init__(self, venv)
163+
self.venv = venv
164+
self.num_envs = venv.num_envs
165+
self.class_attributes = dict(inspect.getmembers(self.__class__))
163166

164-
# Check only that the observation_space match
165-
utils.check_for_correct_spaces(venv, self.observation_space, venv.action_space)
167+
# Check that the observation_space shape match
168+
utils.check_shape_equal(self.observation_space, venv.observation_space)
166169
self.returns = np.zeros(self.num_envs)
167170

168171
def step_wait(self) -> VecEnvStepReturn:

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.8.0a4
1+
1.8.0a5

tests/test_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from stable_baselines3.common.monitor import Monitor
1616
from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
1717
from stable_baselines3.common.utils import (
18+
check_shape_equal,
1819
get_parameters_by_name,
1920
get_system_info,
2021
is_vectorized_observation,
@@ -509,3 +510,23 @@ def test_is_vectorized_observation():
509510
discrete_obs = np.ones((1, 1), dtype=np.int8)
510511
dict_obs = {"box": box_obs, "discrete": discrete_obs}
511512
is_vectorized_observation(dict_obs, dict_space)
513+
514+
515+
def test_check_shape_equal():
516+
space1 = spaces.Box(low=0, high=1, shape=(2, 2))
517+
space2 = spaces.Box(low=-1, high=1, shape=(2, 2))
518+
check_shape_equal(space1, space2)
519+
520+
space1 = spaces.Box(low=0, high=1, shape=(2, 2))
521+
space2 = spaces.Box(low=-1, high=2, shape=(3, 3))
522+
with pytest.raises(AssertionError):
523+
check_shape_equal(space1, space2)
524+
525+
space1 = spaces.Dict({"key1": spaces.Box(low=0, high=1, shape=(2, 2)), "key2": spaces.Box(low=0, high=1, shape=(2, 2))})
526+
space2 = spaces.Dict({"key1": spaces.Box(low=-1, high=2, shape=(2, 2)), "key2": spaces.Box(low=-1, high=2, shape=(2, 2))})
527+
check_shape_equal(space1, space2)
528+
529+
space1 = spaces.Dict({"key1": spaces.Box(low=0, high=1, shape=(2, 2)), "key2": spaces.Box(low=0, high=1, shape=(2, 2))})
530+
space2 = spaces.Dict({"key1": spaces.Box(low=-1, high=2, shape=(3, 3)), "key2": spaces.Box(low=-1, high=2, shape=(2, 2))})
531+
with pytest.raises(AssertionError):
532+
check_shape_equal(space1, space2)

tests/test_vec_normalize.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from gym import spaces
88

99
from stable_baselines3 import SAC, TD3, HerReplayBuffer
10+
from stable_baselines3.common.envs import FakeImageEnv
1011
from stable_baselines3.common.monitor import Monitor
1112
from stable_baselines3.common.running_mean_std import RunningMeanStd
1213
from stable_baselines3.common.vec_env import (
@@ -118,6 +119,10 @@ def make_dict_env():
118119
return Monitor(DummyDictEnv())
119120

120121

122+
def make_image_env():
123+
return Monitor(FakeImageEnv())
124+
125+
121126
def check_rms_equal(rmsa, rmsb):
122127
if isinstance(rmsa, dict):
123128
for key in rmsa.keys():
@@ -244,7 +249,7 @@ def test_obs_rms_vec_normalize():
244249
assert np.allclose(env.ret_rms.mean, 5.688, atol=1e-3)
245250

246251

247-
@pytest.mark.parametrize("make_env", [make_env, make_dict_env])
252+
@pytest.mark.parametrize("make_env", [make_env, make_dict_env, make_image_env])
248253
def test_vec_env(tmp_path, make_env):
249254
"""Test VecNormalize Object"""
250255
clip_obs = 0.5

0 commit comments

Comments
 (0)