
Commit dd66fa5

add simple cartpole example; issue is with RL.jl
1 parent 94caf2c commit dd66fa5

3 files changed: +157 −1 lines changed


environments/rlenv.jl

Lines changed: 2 additions & 1 deletion
@@ -38,4 +38,5 @@ function (env::DojoRLEnv)(a)
     env.done = d
     env.info = i
     return nothing
-end
+end
+(env::DojoRLEnv)(a::AbstractFloat) = env([a])
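The added method forwards a scalar action to the existing vector-action call, so a single-input environment like the cart-pole can be stepped with a plain float. A minimal sketch of the same dispatch pattern in isolation (ToyEnv is a hypothetical stand-in, not the actual DojoRLEnv internals):

# Sketch of the scalar-to-vector forwarding pattern used in the diff above.
# `ToyEnv` is illustrative only.
mutable struct ToyEnv
    last_action::Vector{Float64}
end

# Vector-action call: the "real" step method.
function (env::ToyEnv)(a::AbstractVector)
    env.last_action = collect(a)
    return nothing
end

# Scalar convenience method, mirroring the committed change:
(env::ToyEnv)(a::AbstractFloat) = env([a])

env = ToyEnv(Float64[])
env(0.5)    # equivalent to env([0.5])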

examples/deeprl/cartpole_ddpg.jl

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
using ReinforcementLearning
using Flux
using Flux.Losses

using Random
using Dojo

function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:DDPG},
    ::Val{:DojoCartpole},
    ::Nothing,
    save_dir = nothing,
    seed = 42
)

    rng = MersenneTwister(seed)
    inner_env = Dojo.DojoRLEnv("cartpole")
    Random.seed!(inner_env, seed)
    # TODO
    low = -5.0
    high = 5.0
    ns, na = length(state(inner_env)), length(action_space(inner_env))
    @show na
    A = Dojo.BoxSpace(na)
    env = ActionTransformedEnv(
        inner_env;
        action_mapping = x -> low .+ (x .+ 1) .* 0.5 .* (high .- low),
        action_space_mapping = _ -> A
    )

    init = glorot_uniform(rng)

    create_actor() = Chain(
        Dense(ns, 30, relu; init = init),
        Dense(30, 30, relu; init = init),
        Dense(30, na, tanh; init = init),
    )
    create_critic() = Chain(
        Dense(ns + na, 30, relu; init = init),
        Dense(30, 30, relu; init = init),
        Dense(30, 1; init = init),
    )

    agent = Agent(
        policy = DDPGPolicy(
            behavior_actor = NeuralNetworkApproximator(
                model = create_actor(),
                optimizer = ADAM(),
            ),
            behavior_critic = NeuralNetworkApproximator(
                model = create_critic(),
                optimizer = ADAM(),
            ),
            target_actor = NeuralNetworkApproximator(
                model = create_actor(),
                optimizer = ADAM(),
            ),
            target_critic = NeuralNetworkApproximator(
                model = create_critic(),
                optimizer = ADAM(),
            ),
            γ = 0.99f0,
            ρ = 0.995f0,
            na = na,
            batch_size = 64,
            start_steps = 1000,
            start_policy = RandomPolicy(A; rng = rng),
            update_after = 1000,
            update_freq = 1,
            act_limit = 1.0,
            act_noise = 0.1,
            rng = rng,
        ),
        trajectory = CircularArraySARTTrajectory(
            capacity = 10000,
            state = Vector{Float32} => (ns,),
            action = Float32 => (na,),
        ),
    )

    stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
    hook = TotalRewardPerEpisode()
    Experiment(agent, env, stop_condition, hook, "# Dojo Cartpole with DDPG")
end

ex = E`JuliaRL_DDPG_DojoCartpole`
run(ex)
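In this script the ActionTransformedEnv wrapper rescales the actor's tanh output from [-1, 1] into the placeholder torque range [low, high] marked TODO above. A quick standalone check of that affine map, with the same illustrative bounds:

# Affine rescale from [-1, 1] to [low, high], as in action_mapping above.
low, high = -5.0, 5.0
rescale(x) = low .+ (x .+ 1) .* 0.5 .* (high .- low)

rescale(-1.0)   # -5.0 (lower bound)
rescale(0.0)    #  0.0 (midpoint)
rescale(1.0)    #  5.0 (upper bound)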

examples/deeprl/cartpole_ppo.jl

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
using ReinforcementLearning
using Flux
using Flux.Losses

using Random
using Dojo

function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:PPO},
    ::Val{:DojoCartpole},
    ::Nothing,
    save_dir = nothing,
    seed = 42
)
    rng = MersenneTwister(seed)
    N_ENV = 6
    UPDATE_FREQ = 32
    env_vec = [Dojo.DojoRLEnv("cartpole") for i in 1:N_ENV]
    for i in 1:N_ENV
        Random.seed!(env_vec[i], hash(seed+i))
    end
    env = MultiThreadEnv(env_vec)

    ns, na = length(state(env[1])), length(action_space(env[1]))
    RLBase.reset!(env; is_force=true)

    agent = Agent(
        policy = PPOPolicy(
            approximator = ActorCritic(
                actor = Chain(
                    Dense(ns, 256, relu; init = glorot_uniform(rng)),
                    Dense(256, na; init = glorot_uniform(rng)),
                ),
                critic = Chain(
                    Dense(ns, 256, relu; init = glorot_uniform(rng)),
                    Dense(256, 1; init = glorot_uniform(rng)),
                ),
                optimizer = ADAM(1e-3),
            ),
            γ = 0.99f0,
            λ = 0.95f0,
            clip_range = 0.1f0,
            max_grad_norm = 0.5f0,
            n_epochs = 4,
            n_microbatches = 4,
            actor_loss_weight = 1.0f0,
            critic_loss_weight = 0.5f0,
            entropy_loss_weight = 0.001f0,
            update_freq = UPDATE_FREQ,
        ),
        trajectory = PPOTrajectory(;
            capacity = UPDATE_FREQ,
            state = Matrix{Float32} => (ns, N_ENV),
            action = Vector{Int} => (N_ENV,),
            action_log_prob = Vector{Float32} => (N_ENV,),
            reward = Vector{Float32} => (N_ENV,),
            terminal = Vector{Bool} => (N_ENV,),
        ),
    )
    stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
    hook = TotalBatchRewardPerEpisode(N_ENV)
    Experiment(agent, env, stop_condition, hook, "# PPO with Dojo Cartpole")
end

ex = E`JuliaRL_PPO_DojoCartpole`
run(ex)
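Both scripts rely on ReinforcementLearning.jl's E-string literal, which splits the experiment name on underscores and dispatches to the RL.Experiment method defined above. If that lookup is where the RL.jl issue mentioned in the commit message surfaces, a roughly equivalent explicit call can help isolate it (sketch, assuming the macro's usual Val-based dispatch):

# Roughly equivalent to ex = E`JuliaRL_PPO_DojoCartpole`; useful for
# debugging when the string-macro lookup fails.
ex = RL.Experiment(
    Val(:JuliaRL),
    Val(:PPO),
    Val(:DojoCartpole),
    nothing,
)
run(ex)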
