
Commit 94caf2c

add an example

1 parent 50381ce commit 94caf2c

File tree

4 files changed: +77 -4 lines changed


environments/environment.jl
Lines changed: 3 additions & 2 deletions

@@ -179,8 +179,8 @@ function MeshCat.render(env::Environment,
     return nothing
 end

-function seed(env::Environment; s=0)
-    env.rng[1] = MersenneTwister(seed)
+function seed(env::Environment, s=0)
+    env.rng[1] = MersenneTwister(s)
     return nothing
 end

@@ -227,6 +227,7 @@ function contains(s::BoxSpace{T,N}, v::AbstractVector{T}) where {T,N}
 end

 # For compat with RLBase
+Base.length(s::BoxSpace) = s.n
 Base.in(v::AbstractVector{T}, s::BoxSpace{T,N}) where {T,N} = all(v .>= s.low) && all(v .<= s.high)
 Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = return rand(rng, T,N) .* (s.high .- s.low) .+ s.low
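
For reference, a self-contained sketch of the BoxSpace interface this hunk extends, using a toy struct with the same low/high/n fields (the real type is defined elsewhere in environment.jl, so the constructor here is illustrative only). Note also the seed fix above: the old method passed the function seed itself to MersenneTwister instead of the argument s.

using Random

struct BoxSpace{T,N}
    low::Vector{T}
    high::Vector{T}
    n::Int
end

# The three RLBase-compat overloads from the diff: length, membership, sampling.
Base.length(s::BoxSpace) = s.n
Base.in(v::AbstractVector{T}, s::BoxSpace{T,N}) where {T,N} = all(v .>= s.low) && all(v .<= s.high)
Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = rand(rng, T, N) .* (s.high .- s.low) .+ s.low

space = BoxSpace{Float64,3}([-1.0, -1.0, -1.0], [1.0, 1.0, 1.0], 3)
@assert length(space) == 3                        # uses the new Base.length
@assert rand(MersenneTwister(0), space) in space  # samples stay within bounds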

environments/rlenv.jl
Lines changed: 2 additions & 2 deletions

@@ -23,7 +23,7 @@ RLBase.is_terminated(env::DojoRLEnv) = env.done

 RLBase.reset!(env::DojoRLEnv) = reset(env.dojoenv)

-RLBase.reward(env::DojoRLEnv) = error()
+RLBase.reward(env::DojoRLEnv) = env.reward
 RLBase.state(env::DojoRLEnv) = env.state

 Random.seed!(env::DojoRLEnv, seed) = Dojo.seed(env.dojoenv, seed)

@@ -33,7 +33,7 @@ Random.seed!(env::DojoRLEnv, seed) = Dojo.seed(env.dojoenv, seed)

 function (env::DojoRLEnv)(a)
     s, r, d, i = step(env.dojoenv, a)
-    env.state = s
+    env.state .= s
     env.reward = r
     env.done = d
     env.info = i
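
With reward now read from the struct instead of raising an error, a DojoRLEnv can be driven through the plain RLBase interface. A minimal sketch, assuming the "ant" name is registered as used in the new example below:

using ReinforcementLearning
using Random
using Dojo

env = Dojo.DojoRLEnv("ant")
Random.seed!(env, 42)        # forwards to Dojo.seed via the method above
RLBase.reset!(env)

a = rand(action_space(env))  # sample a random action
env(a)                       # steps dojoenv and caches state, reward, done, info
RLBase.state(env), RLBase.reward(env), RLBase.is_terminated(env)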

examples/deeprl/Project.toml
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+[deps]
+Dojo = "ac60b53e-8d92-4c83-b960-e78698fa1916"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+ReinforcementLearning = "158674fc-8238-5cab-b5ba-03dfc80d1318"
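
To run the new example against exactly these dependencies, activate and instantiate the project first; a standard Pkg workflow (paths assume the repository root as the working directory):

using Pkg
Pkg.activate("examples/deeprl")  # use this Project.toml
Pkg.instantiate()                # fetch Dojo, Flux, Random, ReinforcementLearning
include("examples/deeprl/ant_ppo.jl")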

examples/deeprl/ant_ppo.jl
Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+using ReinforcementLearning
+using Flux
+using Flux.Losses
+
+using Random
+using Dojo
+
+function RL.Experiment(
+    ::Val{:JuliaRL},
+    ::Val{:PPO},
+    ::Val{:DojoAnt},
+    ::Nothing,
+    save_dir = nothing,
+    seed = 42
+)
+    rng = MersenneTwister(seed)
+    N_ENV = 6
+    UPDATE_FREQ = 32
+    env_vec = [Dojo.DojoRLEnv("ant") for i in 1:N_ENV]
+    for i in 1:N_ENV
+        Random.seed!(env_vec[i], hash(seed + i))
+    end
+    env = MultiThreadEnv(env_vec)
+
+    ns, na = length(state(env[1])), length(action_space(env[1]))
+    RLBase.reset!(env; is_force = true)
+
+    agent = Agent(
+        policy = PPOPolicy(
+            approximator = ActorCritic(
+                actor = Chain(
+                    Dense(ns, 256, relu; init = glorot_uniform(rng)),
+                    Dense(256, na; init = glorot_uniform(rng)),
+                ),
+                critic = Chain(
+                    Dense(ns, 256, relu; init = glorot_uniform(rng)),
+                    Dense(256, 1; init = glorot_uniform(rng)),
+                ),
+                optimizer = ADAM(1e-3),
+            ),
+            γ = 0.99f0,
+            λ = 0.95f0,
+            clip_range = 0.1f0,
+            max_grad_norm = 0.5f0,
+            n_epochs = 4,
+            n_microbatches = 4,
+            actor_loss_weight = 1.0f0,
+            critic_loss_weight = 0.5f0,
+            entropy_loss_weight = 0.001f0,
+            update_freq = UPDATE_FREQ,
+        ),
+        trajectory = PPOTrajectory(;
+            capacity = UPDATE_FREQ,
+            state = Matrix{Float32} => (ns, N_ENV),
+            action = Vector{Int} => (N_ENV,),
+            action_log_prob = Vector{Float32} => (N_ENV,),
+            reward = Vector{Float32} => (N_ENV,),
+            terminal = Vector{Bool} => (N_ENV,),
+        ),
+    )
+    stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
+    hook = TotalBatchRewardPerEpisode(N_ENV)
+    Experiment(agent, env, stop_condition, hook, "# PPO with Dojo Ant")
+end
+
+ex = E`JuliaRL_PPO_DojoAnt`
+run(ex)
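
Once run(ex) finishes, the hook retains the collected returns. A sketch of inspecting them, assuming TotalBatchRewardPerEpisode stores one per-episode reward list per environment in its rewards field (as in ReinforcementLearning.jl):

returns = ex.hook.rewards  # Vector with one reward list per environment
for (i, rs) in enumerate(returns)
    println("env $i: $(length(rs)) episodes, total reward $(sum(rs))")
end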
