using ReinforcementLearning
using Flux
using Flux.Losses

using Random
using Dojo

"""
    RL.Experiment(::Val{:JuliaRL}, ::Val{:DDPG}, ::Val{:DojoAnt}, ::Nothing,
                  save_dir = nothing, seed = 42)

Build a DDPG experiment on the Dojo `"ant"` environment.

The actor and critic are three-layer `Dense` networks with 30 hidden units each,
initialised with `glorot_uniform` drawn from a `MersenneTwister(seed)`. The agent
stores transitions in a `CircularArraySARTTrajectory` of capacity 10 000 and the
experiment stops after 10 000 steps, recording `TotalRewardPerEpisode`.

# Arguments
- `save_dir`: accepted for signature compatibility; not used by this method.
- `seed`: seeds both the `MersenneTwister` used for network initialisation / the
  policies and the environment (via `Random.seed!(env, seed)`).

# Returns
An `Experiment` wrapping the agent, environment, stop condition and hook.
"""
function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:DDPG},
    ::Val{:DojoAnt},
    ::Nothing,
    save_dir = nothing,
    seed = 42,
)
    rng = MersenneTwister(seed)
    env = Dojo.DojoRLEnv("ant")
    Random.seed!(env, seed)

    A = action_space(env)
    ns, na = length(state(env)), length(A)
    # Diagnostic only; visible when debug logging is enabled.
    @debug "DojoAnt dimensions" ns na

    init = glorot_uniform(rng)
    hidden = 30

    # Actor maps a state to an action; `tanh` bounds every output to [-1, 1].
    create_actor() = Chain(
        Dense(ns, hidden, relu; init = init),
        Dense(hidden, hidden, relu; init = init),
        Dense(hidden, na, tanh; init = init),
    )
    # Critic takes an input of size `ns + na` and produces a single scalar output.
    create_critic() = Chain(
        Dense(ns + na, hidden, relu; init = init),
        Dense(hidden, hidden, relu; init = init),
        Dense(hidden, 1; init = init),
    )

    agent = Agent(
        policy = DDPGPolicy(
            behavior_actor = NeuralNetworkApproximator(
                model = create_actor(),
                optimizer = ADAM(),
            ),
            behavior_critic = NeuralNetworkApproximator(
                model = create_critic(),
                optimizer = ADAM(),
            ),
            target_actor = NeuralNetworkApproximator(
                model = create_actor(),
                optimizer = ADAM(),
            ),
            target_critic = NeuralNetworkApproximator(
                model = create_critic(),
                optimizer = ADAM(),
            ),
            γ = 0.99f0,
            ρ = 0.995f0,
            na = na,
            batch_size = 64,
            start_steps = 1000,
            start_policy = RandomPolicy(A; rng = rng),
            update_after = 1000,
            update_freq = 1,
            act_limit = 1.0,
            act_noise = 0.1,
            rng = rng,
        ),
        trajectory = CircularArraySARTTrajectory(
            capacity = 10000,
            state = Vector{Float32} => (ns,),
            action = Float32 => (na,),
        ),
    )

    # The progress display is turned off when the `CI` environment variable is set.
    stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
    hook = TotalRewardPerEpisode()
    return Experiment(agent, env, stop_condition, hook, "# Dojo Ant with DDPG")
end
# Construct the experiment through the E`...` command literal and run it.
# NOTE(review): presumably `JuliaRL_DDPG_DojoAnt` resolves to the
# `RL.Experiment(::Val{:JuliaRL}, ::Val{:DDPG}, ::Val{:DojoAnt}, ::Nothing, ...)`
# method defined in this file — confirm against the macro's documentation.
ex = E`JuliaRL_DDPG_DojoAnt`
run(ex)