|
15 | 15 | from stable_baselines3.common.monitor import Monitor |
16 | 16 | from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise |
17 | 17 | from stable_baselines3.common.utils import ( |
| 18 | + check_shape_equal, |
18 | 19 | get_parameters_by_name, |
19 | 20 | get_system_info, |
20 | 21 | is_vectorized_observation, |
@@ -509,3 +510,23 @@ def test_is_vectorized_observation(): |
509 | 510 | discrete_obs = np.ones((1, 1), dtype=np.int8) |
510 | 511 | dict_obs = {"box": box_obs, "discrete": discrete_obs} |
511 | 512 | is_vectorized_observation(dict_obs, dict_space) |
| 513 | + |
| 514 | + |
| 515 | +def test_check_shape_equal(): |
| 516 | + space1 = spaces.Box(low=0, high=1, shape=(2, 2)) |
| 517 | + space2 = spaces.Box(low=-1, high=1, shape=(2, 2)) |
| 518 | + check_shape_equal(space1, space2) |
| 519 | + |
| 520 | + space1 = spaces.Box(low=0, high=1, shape=(2, 2)) |
| 521 | + space2 = spaces.Box(low=-1, high=2, shape=(3, 3)) |
| 522 | + with pytest.raises(AssertionError): |
| 523 | + check_shape_equal(space1, space2) |
| 524 | + |
| 525 | + space1 = spaces.Dict({"key1": spaces.Box(low=0, high=1, shape=(2, 2)), "key2": spaces.Box(low=0, high=1, shape=(2, 2))}) |
| 526 | + space2 = spaces.Dict({"key1": spaces.Box(low=-1, high=2, shape=(2, 2)), "key2": spaces.Box(low=-1, high=2, shape=(2, 2))}) |
| 527 | + check_shape_equal(space1, space2) |
| 528 | + |
| 529 | + space1 = spaces.Dict({"key1": spaces.Box(low=0, high=1, shape=(2, 2)), "key2": spaces.Box(low=0, high=1, shape=(2, 2))}) |
| 530 | + space2 = spaces.Dict({"key1": spaces.Box(low=-1, high=2, shape=(3, 3)), "key2": spaces.Box(low=-1, high=2, shape=(2, 2))}) |
| 531 | + with pytest.raises(AssertionError): |
| 532 | + check_shape_equal(space1, space2) |
0 commit comments