
Commit 000544c

araffin and qgallouedec authored
Add support for pre and post linear modules in create_mlp (#1975)
* Add support for pre and post linear modules in `create_mlp`
* Disable mypy for python 3.8
* Reformat toml file
* Update docstring
* Add some comments

Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 1a69fc8 commit 000544c

File tree

6 files changed: +157 -61 lines changed

.github/workflows/ci.yml

Lines changed: 37 additions & 35 deletions
@@ -5,9 +5,9 @@ name: CI
 
 on:
   push:
-    branches: [ master ]
+    branches: [master]
   pull_request:
-    branches: [ master ]
+    branches: [master]
 
 jobs:
   build:
@@ -23,38 +23,40 @@ jobs:
         python-version: ["3.8", "3.9", "3.10", "3.11"]
 
     steps:
-    - uses: actions/checkout@v3
-    - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v4
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        # cpu version of pytorch
-        pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          # cpu version of pytorch
+          pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
 
-        # Install Atari Roms
-        pip install autorom
-        wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
-        base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
-        AutoROM --accept-license --source-file Roms.tar.gz
+          # Install Atari Roms
+          pip install autorom
+          wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
+          base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
+          AutoROM --accept-license --source-file Roms.tar.gz
 
-        pip install .[extra_no_roms,tests,docs]
-        # Use headless version
-        pip install opencv-python-headless
-    - name: Lint with ruff
-      run: |
-        make lint
-    - name: Build the doc
-      run: |
-        make doc
-    - name: Check codestyle
-      run: |
-        make check-codestyle
-    - name: Type check
-      run: |
-        make type
-    - name: Test with pytest
-      run: |
-        make pytest
+          pip install .[extra_no_roms,tests,docs]
+          # Use headless version
+          pip install opencv-python-headless
+      - name: Lint with ruff
+        run: |
+          make lint
+      - name: Build the doc
+        run: |
+          make doc
+      - name: Check codestyle
+        run: |
+          make check-codestyle
+      - name: Type check
+        run: |
+          make type
+        # Do not run for python 3.8 (mypy internal error)
+        if: matrix.python-version != '3.8'
+      - name: Test with pytest
+        run: |
+          make pytest

docs/misc/changelog.rst

Lines changed: 2 additions & 1 deletion
@@ -3,14 +3,15 @@
 Changelog
 ==========
 
-Release 2.4.0a5 (WIP)
+Release 2.4.0a6 (WIP)
 --------------------------
 
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 
 New Features:
 ^^^^^^^^^^^^^
+- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
 
 Bug Fixes:
 ^^^^^^^^^^
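For context, a minimal sketch of what this changelog entry enables (illustrative, not part of the commit; it assumes the default ReLU activation):

    import torch.nn as nn

    from stable_baselines3.common.torch_layers import create_mlp

    # Pre-linear modules also precede the output layer; post-linear modules
    # only follow the hidden layers. For net_arch=[8] this returns:
    # [BatchNorm1d(4), Linear(4, 8), LayerNorm(8), ReLU(),
    #  BatchNorm1d(8), Linear(8, 2)]
    layers = create_mlp(
        4, 2, net_arch=[8],
        pre_linear_modules=[nn.BatchNorm1d],
        post_linear_modules=[nn.LayerNorm],
    )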

pyproject.toml

Lines changed: 17 additions & 15 deletions
@@ -13,10 +13,10 @@ ignore = ["B028", "RUF013"]
 
 [tool.ruff.lint.per-file-ignores]
 # Default implementation in abstract methods
-"./stable_baselines3/common/callbacks.py"= ["B027"]
-"./stable_baselines3/common/noise.py"= ["B027"]
+"./stable_baselines3/common/callbacks.py" = ["B027"]
+"./stable_baselines3/common/noise.py" = ["B027"]
 # ClassVar, implicit optional check not needed for tests
-"./tests/*.py"= ["RUF012", "RUF013"]
+"./tests/*.py" = ["RUF012", "RUF013"]
 
 
 [tool.ruff.lint.mccabe]
@@ -37,33 +37,35 @@ exclude = """(?x)(
 
 [tool.pytest.ini_options]
 # Deterministic ordering for tests; useful for pytest-xdist.
-env = [
-    "PYTHONHASHSEED=0"
-]
+env = ["PYTHONHASHSEED=0"]
 
 filterwarnings = [
     # Tensorboard warnings
     "ignore::DeprecationWarning:tensorboard",
     # Gymnasium warnings
     "ignore::UserWarning:gymnasium",
     # tqdm warning about rich being experimental
-    "ignore:rich is experimental"
+    "ignore:rich is experimental",
 ]
 markers = [
-    "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
+    "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')",
 ]
 
 [tool.coverage.run]
 disable_warnings = ["couldnt-parse"]
 branch = false
 omit = [
-  "tests/*",
-  "setup.py",
-  # Require graphical interface
-  "stable_baselines3/common/results_plotter.py",
-  # Require ffmpeg
-  "stable_baselines3/common/vec_env/vec_video_recorder.py",
+    "tests/*",
+    "setup.py",
+    # Require graphical interface
+    "stable_baselines3/common/results_plotter.py",
+    # Require ffmpeg
+    "stable_baselines3/common/vec_env/vec_video_recorder.py",
 ]
 
 [tool.coverage.report]
-exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"]
+exclude_lines = [
+    "pragma: no cover",
+    "raise NotImplementedError()",
+    "if typing.TYPE_CHECKING:",
+]

stable_baselines3/common/torch_layers.py

Lines changed: 44 additions & 9 deletions
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple, Type, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import gymnasium as gym
 import torch as th
@@ -14,7 +14,7 @@ class BaseFeaturesExtractor(nn.Module):
     """
     Base class that represents a features extractor.
 
-    :param observation_space:
+    :param observation_space: The observation space of the environment
     :param features_dim: Number of features extracted.
     """
 
@@ -26,6 +26,7 @@ def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None:
 
     @property
     def features_dim(self) -> int:
+        """The number of features that the extractor outputs."""
        return self._features_dim
 
 
@@ -34,7 +35,7 @@ class FlattenExtractor(BaseFeaturesExtractor):
     Feature extract that flatten the input.
     Used as a placeholder when feature extraction is not needed.
 
-    :param observation_space:
+    :param observation_space: The observation space of the environment
     """
 
     def __init__(self, observation_space: gym.Space) -> None:
@@ -52,7 +53,7 @@ class NatureCNN(BaseFeaturesExtractor):
     "Human-level control through deep reinforcement learning."
     Nature 518.7540 (2015): 529-533.
 
-    :param observation_space:
+    :param observation_space: The observation space of the environment
     :param features_dim: Number of features extracted.
         This corresponds to the number of unit for the last layer.
     :param normalized_image: Whether to assume that the image is already normalized
@@ -113,13 +114,15 @@ def create_mlp(
     activation_fn: Type[nn.Module] = nn.ReLU,
     squash_output: bool = False,
     with_bias: bool = True,
+    pre_linear_modules: Optional[List[Type[nn.Module]]] = None,
+    post_linear_modules: Optional[List[Type[nn.Module]]] = None,
 ) -> List[nn.Module]:
     """
     Create a multi layer perceptron (MLP), which is
     a collection of fully-connected layers each followed by an activation function.
 
     :param input_dim: Dimension of the input vector
-    :param output_dim:
+    :param output_dim: Dimension of the output (last layer, for instance, the number of actions)
     :param net_arch: Architecture of the neural net
         It represents the number of units per layer.
         The length of this list is the number of layers.
@@ -128,20 +131,52 @@ def create_mlp(
     :param squash_output: Whether to squash the output using a Tanh
         activation function
     :param with_bias: If set to False, the layers will not learn an additive bias
-    :return:
+    :param pre_linear_modules: List of nn.Module to add before the linear layers.
+        These modules should maintain the input tensor dimension (e.g. BatchNorm).
+        The number of input features is passed to the module's constructor.
+        Compared to post_linear_modules, they are used before the output layer (output_dim > 0).
+    :param post_linear_modules: List of nn.Module to add after the linear layers
+        (and before the activation function). These modules should maintain the input
+        tensor dimension (e.g. Dropout, LayerNorm). They are not used after the
+        output layer (output_dim > 0). The number of input features is passed to
+        the module's constructor.
+    :return: The list of layers of the neural network
     """
 
+    pre_linear_modules = pre_linear_modules or []
+    post_linear_modules = post_linear_modules or []
+
+    modules = []
     if len(net_arch) > 0:
-        modules = [nn.Linear(input_dim, net_arch[0], bias=with_bias), activation_fn()]
-    else:
-        modules = []
+        # BatchNorm maintains input dim
+        for module in pre_linear_modules:
+            modules.append(module(input_dim))
+
+        modules.append(nn.Linear(input_dim, net_arch[0], bias=with_bias))
+
+        # LayerNorm, Dropout maintain output dim
+        for module in post_linear_modules:
+            modules.append(module(net_arch[0]))
+
+        modules.append(activation_fn())
 
     for idx in range(len(net_arch) - 1):
+        for module in pre_linear_modules:
+            modules.append(module(net_arch[idx]))
+
         modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1], bias=with_bias))
+
+        for module in post_linear_modules:
+            modules.append(module(net_arch[idx + 1]))
+
         modules.append(activation_fn())
 
     if output_dim > 0:
         last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim
+        # Only add BatchNorm before output layer
+        for module in pre_linear_modules:
+            modules.append(module(last_layer_dim))
+
         modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias))
         if squash_output:
             modules.append(nn.Tanh())
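A usage sketch of the new arguments (illustrative, not part of the commit; the dimensions and the `nn.Sequential` wrapping are examples). Each pre-linear module receives the number of input features of the following linear layer, each post-linear module the number of its output features; modules such as `nn.Dropout`, whose first constructor argument is not a feature count, would need a small wrapper (e.g. a lambda that ignores the dimension).

    import torch as th
    import torch.nn as nn

    from stable_baselines3.common.torch_layers import create_mlp

    # Normalized critic body in the spirit of CrossQ: BatchNorm before every
    # linear layer, LayerNorm after each hidden linear layer (the output
    # layer gets no post-linear module)
    layers = create_mlp(
        input_dim=6,
        output_dim=1,
        net_arch=[256, 256],
        pre_linear_modules=[nn.BatchNorm1d],
        post_linear_modules=[nn.LayerNorm],
    )
    q_net = nn.Sequential(*layers)

    # BatchNorm1d requires more than one sample per batch in training mode
    dummy_batch = th.randn(32, 6)
    assert q_net(dummy_batch).shape == (32, 1)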

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-2.4.0a5
+2.4.0a6

tests/test_custom_policy.py

Lines changed: 56 additions & 0 deletions
@@ -1,8 +1,10 @@
 import pytest
 import torch as th
+import torch.nn as nn
 
 from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
 from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike
+from stable_baselines3.common.torch_layers import create_mlp
 
 
 @pytest.mark.parametrize(
@@ -62,3 +64,57 @@ def test_tf_like_rmsprop_optimizer():
 def test_dqn_custom_policy():
     policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32])
     _ = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, learning_starts=100).learn(300)
+
+
+def test_create_mlp():
+    net = create_mlp(4, 2, net_arch=[16, 8], squash_output=True)
+    # We cannot compare the network directly because the modules have different ids
+    # assert net == [nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2),
+    #   nn.Tanh()]
+    assert len(net) == 6
+    assert isinstance(net[0], nn.Linear)
+    assert net[0].in_features == 4
+    assert net[0].out_features == 16
+    assert isinstance(net[1], nn.ReLU)
+    assert isinstance(net[2], nn.Linear)
+    assert isinstance(net[4], nn.Linear)
+    assert net[4].in_features == 8
+    assert net[4].out_features == 2
+    assert isinstance(net[5], nn.Tanh)
+
+    # Linear network
+    net = create_mlp(4, -1, net_arch=[])
+    assert net == []
+
+    # No output layer, with custom activation function
+    net = create_mlp(6, -1, net_arch=[8], activation_fn=nn.Tanh)
+    # assert net == [nn.Linear(6, 8), nn.Tanh()]
+    assert len(net) == 2
+    assert isinstance(net[0], nn.Linear)
+    assert net[0].in_features == 6
+    assert net[0].out_features == 8
+    assert isinstance(net[1], nn.Tanh)
+
+    # Using pre-linear and post-linear modules
+    pre_linear = [nn.BatchNorm1d]
+    post_linear = [nn.LayerNorm]
+    net = create_mlp(6, 2, net_arch=[8, 12], pre_linear_modules=pre_linear, post_linear_modules=post_linear)
+    # assert net == [nn.BatchNorm1d(6), nn.Linear(6, 8), nn.LayerNorm(8), nn.ReLU(),
+    #   nn.BatchNorm1d(8), nn.Linear(8, 12), nn.LayerNorm(12), nn.ReLU(),
+    #   nn.BatchNorm1d(12), nn.Linear(12, 2)]  # Last layer does not have post_linear
+    assert len(net) == 10
+    assert isinstance(net[0], nn.BatchNorm1d)
+    assert net[0].num_features == 6
+    assert isinstance(net[1], nn.Linear)
+    assert isinstance(net[2], nn.LayerNorm)
+    assert isinstance(net[3], nn.ReLU)
+    assert isinstance(net[4], nn.BatchNorm1d)
+    assert isinstance(net[5], nn.Linear)
+    assert net[5].in_features == 8
+    assert net[5].out_features == 12
+    assert isinstance(net[6], nn.LayerNorm)
+    assert isinstance(net[7], nn.ReLU)
+    assert isinstance(net[8], nn.BatchNorm1d)
+    assert isinstance(net[-1], nn.Linear)
+    assert net[-1].in_features == 12
+    assert net[-1].out_features == 2
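One hedged addendum to the test above (illustrative, not part of the commit): `create_mlp` returns a plain list of layers, so the caller wraps them, and the BatchNorm layers make the resulting network mode-sensitive.

    # Wrap the final `net` from the test into a runnable module
    model = nn.Sequential(*net)
    # Use running statistics: BatchNorm1d in training mode would reject
    # a batch of size 1
    model.eval()
    out = model(th.randn(1, 6))
    assert out.shape == (1, 2)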
