Skip to content

Commit b606aa1

Browse files
committed
cartpole is meaningless as no reward is defined
1 parent dd66fa5 commit b606aa1

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed

examples/deeprl/ant_ddpg.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
using ReinforcementLearning
2+
using Flux
3+
using Flux.Losses
4+
5+
using Random
6+
using Dojo
7+
8+
"""
    RL.Experiment(::Val{:JuliaRL}, ::Val{:DDPG}, ::Val{:DojoAnt}, ::Nothing,
                  save_dir = nothing, seed = 42)

Construct a DDPG experiment on the Dojo `"ant"` environment.

Builds actor/critic networks (two 30-unit hidden layers each), wraps them in a
`DDPGPolicy` with target networks, and pairs the agent with a circular SART
replay buffer. Runs for 10_000 steps, recording total reward per episode.

# Arguments
- `save_dir`: accepted for interface compatibility with other experiments.
  NOTE(review): currently unused in this body — confirm whether results should
  be written here.
- `seed`: seeds both the local RNG and the environment for reproducibility.

# Returns
An `Experiment` ready to be `run`.
"""
function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:DDPG},
    ::Val{:DojoAnt},
    ::Nothing,
    save_dir = nothing,
    seed = 42,
)
    rng = MersenneTwister(seed)
    env = Dojo.DojoRLEnv("ant")
    Random.seed!(env, seed)

    A = action_space(env)
    # Reuse the already-bound action space `A` instead of querying the
    # environment a second time (original called `action_space(env)` twice).
    ns, na = length(state(env)), length(A)

    init = glorot_uniform(rng)

    # Actor: state -> action, bounded to (-1, 1) by the final tanh layer
    # (matches `act_limit = 1.0` below).
    create_actor() = Chain(
        Dense(ns, 30, relu; init = init),
        Dense(30, 30, relu; init = init),
        Dense(30, na, tanh; init = init),
    )

    # Critic: (state, action) -> scalar Q-value.
    create_critic() = Chain(
        Dense(ns + na, 30, relu; init = init),
        Dense(30, 30, relu; init = init),
        Dense(30, 1; init = init),
    )

    agent = Agent(
        policy = DDPGPolicy(
            behavior_actor = NeuralNetworkApproximator(
                model = create_actor(),
                optimizer = ADAM(),
            ),
            behavior_critic = NeuralNetworkApproximator(
                model = create_critic(),
                optimizer = ADAM(),
            ),
            # Target networks start as fresh copies of the same architectures
            # and are updated via Polyak averaging with rate ρ.
            target_actor = NeuralNetworkApproximator(
                model = create_actor(),
                optimizer = ADAM(),
            ),
            target_critic = NeuralNetworkApproximator(
                model = create_critic(),
                optimizer = ADAM(),
            ),
            γ = 0.99f0,            # discount factor
            ρ = 0.995f0,           # Polyak averaging coefficient for targets
            na = na,
            batch_size = 64,
            start_steps = 1000,    # explore randomly before using the actor
            start_policy = RandomPolicy(A; rng = rng),
            update_after = 1000,   # wait for buffer warm-up before learning
            update_freq = 1,
            act_limit = 1.0,
            act_noise = 0.1,       # Gaussian exploration noise scale
            rng = rng,
        ),
        trajectory = CircularArraySARTTrajectory(
            capacity = 10000,
            state = Vector{Float32} => (ns,),
            action = Float32 => (na,),
        ),
    )

    stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
    hook = TotalRewardPerEpisode()
    return Experiment(agent, env, stop_condition, hook, "# Dojo Ant with DDPG")
end
78+
79+
# Look up the experiment registered above via ReinforcementLearning's
# `E` string macro (dispatches to the `RL.Experiment` method for
# :JuliaRL / :DDPG / :DojoAnt), then execute it.
ex = E`JuliaRL_DDPG_DojoAnt`
run(ex)

0 commit comments

Comments
 (0)