Skip to content

Commit 1043fc3

Browse files
committed
Towards ReinforcementLearning.jl integration
1 parent efab5f2 commit 1043fc3

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
1919
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2020
Polyhedra = "67491407-f73d-577b-9b50-8179a7c68029"
2121
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
22+
ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
2223
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
2324
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
2425
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

environments/rlenv.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using ReinforcementLearningBase: RLBase
2+
3+
mutable struct DojoRLEnv <: RLBase.AbstractEnv
4+
dojoenv
5+
action_space
6+
observation_space
7+
state
8+
reward
9+
done::Bool
10+
info::Dict
11+
end
12+
13+
function DojoRLEnv(dojoenv::Environment)
14+
action_space = convert(RLBase.Space, dojoenv.input_space)
15+
observation_space = convert(RLBase.Space, dojoenv.observation_space)
16+
state = reset(dojoenv)
17+
return DojoRLEnv(dojoenv, action_space, observation_space, state, 0.0, false, Dict())
18+
end
19+
20+
RLBase.action_space(env::DojoRLEnv) = env.action_space
21+
RLBase.state_space(env::DojoRLEnv) = env.observation_space
22+
RLBase.is_terminated(env::DojoRLEnv) = env.done
23+
24+
RLBase.reset!(env::DojoRLEnv) = reset(env.dojoenv)
25+
26+
RLBase.reward(env::DojoRLEnv) = error()
27+
RLBase.state(env::DojoRLEnv) = env.state
28+
29+
Random.seed!(env::DojoRLEnv, seed) = Dojo.seed(env.dojoenv, seed)
30+
31+
function (env::DojoRLEnv)(a)
32+
s, r, d, i = step(env.dojoenv, a)
33+
env.state = s
34+
env.reward = r
35+
env.done = d
36+
env.info = i
37+
return nothing
38+
end

src/Dojo.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ include(joinpath("..", "environments", "environment.jl"))
146146
include(joinpath("..", "environments", "dynamics.jl"))
147147
include(joinpath("..", "environments", "utilities.jl"))
148148
include(joinpath("..", "environments", "include.jl"))
149+
include(joinpath("..", "environments", "rlenv.jl"))
149150

150151
# Bodies
151152
export

0 commit comments

Comments
 (0)