Skip to content

Commit f8ea299

Browse files
authored
Doc update: custom envs, IsaacLab, Brax and dm_control (#2072)
* Add note about start!=0 for Discrete spaces * Update doc for IsaacLab and dm_control * Fix test due to rounding error
1 parent d055a2e commit f8ea299

File tree

6 files changed

+47
-25
lines changed

6 files changed

+47
-25
lines changed

docs/guide/custom_env.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,24 @@ That is to say, your environment must implement the following methods (and inher
2424
Under the hood, when a channel-last image is passed, SB3 uses a ``VecTransposeImage`` wrapper to re-order the channels.
2525

2626

27+
.. note::
28+
29+
SB3 doesn't support ``Discrete`` and ``MultiDiscrete`` spaces with ``start!=0``. However, you can update your environment or use a wrapper to make your env compatible with SB3:
30+
31+
.. code-block:: python
32+
33+
import gymnasium as gym
34+
35+
class ShiftWrapper(gym.Wrapper):
36+
"""Allow to use Discrete() action spaces with start!=0"""
37+
def __init__(self, env: gym.Env) -> None:
38+
super().__init__(env)
39+
assert isinstance(env.action_space, gym.spaces.Discrete)
40+
self.action_space = gym.spaces.Discrete(env.action_space.n, start=0)
41+
42+
def step(self, action: int):
43+
return self.env.step(action + self.env.action_space.start)
44+
2745
2846
.. code-block:: python
2947

docs/guide/examples.rst

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -735,41 +735,40 @@ A2C policy gradient updates on the model.
735735
print(f"Best fitness: {top_candidates[0][1]:.2f}")
736736
737737
738-
SB3 and ProcgenEnv
739-
------------------
738+
SB3 with Isaac Lab, Brax, Procgen, EnvPool
739+
------------------------------------------
740740

741-
Some environments like `Procgen <https://github.com/openai/procgen>`_ already produce a vectorized
742-
environment (see discussion in `issue #314 <https://github.com/DLR-RM/stable-baselines3/issues/314>`_). In order to use it with SB3, you must wrap it in a ``VecMonitor`` wrapper which will also allow
743-
to keep track of the agent progress.
741+
Some massively parallel simulations such as `EnvPool <https://github.com/sail-sg/envpool>`_, `Isaac Lab <https://github.com/isaac-sim/IsaacLab>`_, `Brax <https://github.com/google/brax>`_ or `ProcGen <https://github.com/Farama-Foundation/Procgen2>`_ already produce a vectorized environment to speed up data collection (see discussion in `issue #314 <https://github.com/DLR-RM/stable-baselines3/issues/314>`_).
744742

745-
.. code-block:: python
743+
To use SB3 with these tools, you need to wrap the env with tool-specific ``VecEnvWrapper`` that pre-processes the data for SB3,
744+
you can find links to some of these wrappers in `issue #772 <https://github.com/DLR-RM/stable-baselines3/issues/772#issuecomment-1048657002>`_.
746745

747-
from procgen import ProcgenEnv
746+
- Isaac Lab wrapper: `link <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/utils/wrappers/sb3.py>`__
747+
- Brax: `link <https://gist.github.com/araffin/a7a576ec1453e74d9bb93120918ef7e7>`__
748+
- EnvPool: `link <https://github.com/sail-sg/envpool/blob/main/examples/sb3_examples/ppo.py>`__
748749

749-
from stable_baselines3 import PPO
750-
from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
751750

752-
# ProcgenEnv is already vectorized
753-
venv = ProcgenEnv(num_envs=2, env_name="starpilot")
751+
SB3 with DeepMind Control (dm_control)
752+
--------------------------------------
754753

755-
# To use only part of the observation:
756-
# venv = VecExtractDictObs(venv, "rgb")
754+
If you want to use SB3 with `dm_control <https://github.com/google-deepmind/dm_control>`_, you need to use two wrappers (one from `shimmy <https://github.com/Farama-Foundation/Shimmy>`_, one pre-built one) to convert it to a Gymnasium compatible environment:
757755

758-
# Wrap with a VecMonitor to collect stats and avoid errors
759-
venv = VecMonitor(venv=venv)
756+
.. code-block:: python
760757
761-
model = PPO("MultiInputPolicy", venv, verbose=1)
762-
model.learn(10_000)
758+
import shimmy
759+
import stable_baselines3 as sb3
760+
from dm_control import suite
761+
from gymnasium.wrappers import FlattenObservation
763762
763+
# Available envs:
764+
# suite._DOMAINS and suite.dog.SUITE
764765
765-
SB3 with EnvPool or Isaac Gym
766-
-----------------------------
766+
env = suite.load(domain_name="dog", task_name="run")
767+
gym_env = FlattenObservation(shimmy.DmControlCompatibilityV0(env))
767768
768-
Just like Procgen (see above), `EnvPool <https://github.com/sail-sg/envpool>`_ and `Isaac Gym <https://github.com/NVIDIA-Omniverse/IsaacGymEnvs>`_ accelerate the environment by
769-
already providing a vectorized implementation.
769+
model = sb3.PPO("MlpPolicy", gym_env, verbose=1)
770+
model.learn(10_000, progress_bar=True)
770771
771-
To use SB3 with those tools, you must wrap the env with tool's specific ``VecEnvWrapper`` that will pre-process the data for SB3,
772-
you can find links to those wrappers in `issue #772 <https://github.com/DLR-RM/stable-baselines3/issues/772#issuecomment-1048657002>`_.
773772
774773
775774
Record a Video

docs/guide/sbx.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Implemented algorithms:
1818
- Twin Delayed DDPG (TD3)
1919
- Deep Deterministic Policy Gradient (DDPG)
2020
- Batch Normalization in Deep Reinforcement Learning (CrossQ)
21+
- Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning (SimBa)
2122

2223

2324
As SBX follows SB3 API, it is also compatible with the `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_.

docs/misc/changelog.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ Documentation:
4242
- Add FootstepNet Envs to the project page (@cgaspard3333)
4343
- Added FRASA to the project page (@MarcDcls)
4444
- Fixed atari example (@chrisgao99)
45+
- Add a note about ``Discrete`` action spaces with ``start!=0``
46+
- Update doc for massively parallel simulators (Isaac Lab, Brax, ...)
47+
- Add dm_control example
4548

4649
Release 2.4.1 (2024-12-20)
4750
--------------------------

stable_baselines3/common/env_checker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def _check_non_zero_start(space: spaces.Space, space_type: str = "observation",
3737
warnings.warn(
3838
f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) "
3939
"is not supported by Stable-Baselines3. "
40-
f"You can use a wrapper or update your {space_type} space."
40+
"You can use a wrapper (see https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html) "
41+
f"or update your {space_type} space."
4142
)
4243

4344

tests/test_vec_normalize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def test_get_original():
315315
assert not np.array_equal(orig_obs, obs)
316316
assert not np.array_equal(orig_rewards, rewards)
317317
np.testing.assert_allclose(venv.normalize_obs(orig_obs), obs)
318-
np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards)
318+
np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards, atol=1e-6)
319319

320320

321321
def test_get_original_dict():

0 commit comments

Comments
 (0)