Skip to content

Commit dc065e2

Browse files
authored
Merge pull request #1 from findmyway/fix_jkg_rl_envs
fix space related definitions
2 parents b606aa1 + f79e9e3 commit dc065e2

File tree

2 files changed

+30
-32
lines changed

2 files changed

+30
-32
lines changed

environments/environment.jl

Lines changed: 23 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ mutable struct Environment{X,T,M,A,O,I}
3131
dynamics_jacobian_state::Matrix{T}
3232
dynamics_jacobian_input::Matrix{T}
3333
input_previous::Vector{T}
34-
control_map::Matrix{T}
34+
control_map::Matrix{T}
3535
num_states::Int
3636
num_inputs::Int
3737
num_observations::Int
@@ -66,33 +66,33 @@ end
6666
attitude_decompress: flag for pre- and post-concatenating Jacobians with attitude Jacobians
6767
"""
6868
function Base.step(env::Environment, x, u;
69-
gradients=false,
70-
attitude_decompress=false)
69+
gradients = false,
70+
attitude_decompress = false)
7171

7272
mechanism = env.mechanism
73-
timestep= mechanism.timestep
73+
timestep = mechanism.timestep
7474

7575
x0 = x
7676
# u = clip(env.input_space, u) # control limits
7777
env.input_previous .= u # for rendering in Gym
78-
u_scaled = env.control_map * u
78+
u_scaled = env.control_map * u
7979

8080
z0 = env.representation == :minimal ? minimal_to_maximal(mechanism, x0) : x0
81-
z1 = step!(mechanism, z0, u_scaled; opts=env.opts_step)
81+
z1 = step!(mechanism, z0, u_scaled; opts = env.opts_step)
8282
env.state .= env.representation == :minimal ? maximal_to_minimal(mechanism, z1) : z1
8383

8484
# Compute cost
8585
costs = cost(env, x, u)
8686

87-
# Check termination
88-
done = is_done(env, x)
87+
# Check termination
88+
done = is_done(env, x)
8989

9090
# Gradients
9191
if gradients
9292
if env.representation == :minimal
93-
fx, fu = get_minimal_gradients!(env.mechanism, z0, u_scaled, opts=env.opts_grad)
93+
fx, fu = get_minimal_gradients!(env.mechanism, z0, u_scaled, opts = env.opts_grad)
9494
elseif env.representation == :maximal
95-
fx, fu = get_maximal_gradients!(env.mechanism, z0, u_scaled, opts=env.opts_grad)
95+
fx, fu = get_maximal_gradients!(env.mechanism, z0, u_scaled, opts = env.opts_grad)
9696
if attitude_decompress
9797
A0 = attitude_jacobian(z0, length(env.mechanism.bodies))
9898
A1 = attitude_jacobian(z1, length(env.mechanism.bodies))
@@ -109,11 +109,11 @@ function Base.step(env::Environment, x, u;
109109
end
110110

111111
function Base.step(env::Environment, u;
112-
gradients=false,
113-
attitude_decompress=false)
114-
step(env, env.state, u;
115-
gradients=gradients,
116-
attitude_decompress=attitude_decompress)
112+
gradients = false,
113+
attitude_decompress = false)
114+
step(env, env.state, u;
115+
gradients = gradients,
116+
attitude_decompress = attitude_decompress)
117117
end
118118

119119
"""
@@ -156,7 +156,7 @@ is_done(env::Environment, x) = false
156156
x: state
157157
"""
158158
function Base.reset(env::Environment{X};
159-
x=nothing) where X
159+
x = nothing) where {X}
160160

161161
initialize!(env.mechanism, type2symbol(X))
162162
if x != nothing
@@ -172,14 +172,14 @@ function Base.reset(env::Environment{X};
172172
return get_observation(env)
173173
end
174174

175-
function MeshCat.render(env::Environment,
176-
mode="human")
175+
function MeshCat.render(env::Environment,
176+
mode = "human")
177177
z = env.representation == :minimal ? minimal_to_maximal(env.mechanism, env.state) : env.state
178-
set_robot(env.vis, env.mechanism, z, name=:robot)
178+
set_robot(env.vis, env.mechanism, z, name = :robot)
179179
return nothing
180180
end
181181

182-
function seed(env::Environment, s=0)
182+
function seed(env::Environment, s = 0)
183183
env.rng[1] = MersenneTwister(s)
184184
return nothing
185185
end
@@ -214,26 +214,20 @@ mutable struct BoxSpace{T,N} <: Space{T,N}
214214
dtype::DataType # this is always T, it's needed to interface with Stable-Baselines
215215
end
216216

217-
function BoxSpace(n::Int; low::AbstractVector{T} = -ones(n), high::AbstractVector{T} = ones(n)) where T
217+
function BoxSpace(n::Int; low::AbstractVector{T} = -ones(n), high::AbstractVector{T} = ones(n)) where {T}
218218
return BoxSpace{T,n}(n, low, high, (n,), T)
219219
end
220220

221221
function sample(s::BoxSpace{T,N}) where {T,N}
222-
return rand(T,N) .* (s.high .- s.low) .+ s.low
222+
return rand(T, N) .* (s.high .- s.low) .+ s.low
223223
end
224224

225225
function contains(s::BoxSpace{T,N}, v::AbstractVector{T}) where {T,N}
226226
all(v .>= s.low) && all(v .<= s.high)
227227
end
228228

229-
# For compat with RLBase
230-
Base.length(s::BoxSpace) = s.n
231-
Base.in(v::AbstractVector{T}, s::BoxSpace{T,N}) where {T,N} = all(v .>= s.low) && all(v .<= s.high)
232-
Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = return rand(rng, T,N) .* (s.high .- s.low) .+ s.low
233-
234229
function clip(s::BoxSpace, u)
235230
clamp.(u, s.low, s.high)
236231
end
237232

238-
239-
233+
Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = return rand(rng, T, N) .* (s.high .- s.low) .+ s.low

environments/rlenv.jl

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,12 @@ function DojoRLEnv(name::String; kwargs...)
1717
DojoRLEnv(Dojo.get_environment(name; kwargs...))
1818
end
1919

20-
RLBase.action_space(env::DojoRLEnv) = env.dojoenv.input_space
21-
RLBase.state_space(env::DojoRLEnv) = env.dojoenv.observation_space
20+
function Base.convert(::Type{RLBase.Space}, s::BoxSpace)
21+
RLBase.Space([BoxSpace(1; low = s.low[i:i], high = s.high[i:i]) for i in 1:s.n])
22+
end
23+
24+
RLBase.action_space(env::DojoRLEnv) = convert(RLBase.Space, env.dojoenv.input_space)
25+
RLBase.state_space(env::DojoRLEnv) = convert(RLBase.Space, env.dojoenv.observation_space)
2226
RLBase.is_terminated(env::DojoRLEnv) = env.done
2327

2428
RLBase.reset!(env::DojoRLEnv) = reset(env.dojoenv)
@@ -39,4 +43,4 @@ function (env::DojoRLEnv)(a)
3943
env.info = i
4044
return nothing
4145
end
42-
(env::DojoRLEnv)(a::AbstractFloat) = env([a])
46+
(env::DojoRLEnv)(a::Number) = env([a])

0 commit comments

Comments
 (0)