Commit c821bb6

Tabular Approximator fixes (pre v0.11 changes) (#1040)
* Fix tabular approx
* Expand tests for approx / learners
* Add target network docstring
* add more target_network tests
* add docstring, fix gpu toggle
* Fix naming
* Fix approximator tests
* Fix tests
* target network expanded tests pass
* final target network test passes
* delete excess file
* Drop gpu code from generic forward function
* Tab approx forward to env fixes
* Add missing test and methods for approximator
* Add missing env import
* Add RLEnv back to RLCore test dependencies
* only run gpu check when gpu is functional
* Fix dqn
* Try dqn fix
* Update src/ReinforcementLearningCore/src/policies/learners/target_network.jl

Co-authored-by: Henri Dehaybe <47037088+HenriDeh@users.noreply.github.com>

* Fix var naming style

---------

Co-authored-by: Henri Dehaybe <47037088+HenriDeh@users.noreply.github.com>
1 parent a173359 commit c821bb6

12 files changed: 310 additions & 60 deletions

src/ReinforcementLearningCore/Project.toml

Lines changed: 2 additions & 1 deletion
@@ -55,9 +55,10 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
 Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

 [targets]
-test = ["CommonRLInterface", "CUDA", "cuDNN", "DomainSets", "Metal", "Preferences", "Test", "UUIDs"]
+test = ["CommonRLInterface", "CUDA", "cuDNN", "DomainSets", "Metal", "Preferences", "ReinforcementLearningEnvironments", "Test", "UUIDs"]

src/ReinforcementLearningCore/src/policies/learners/abstract_learner.jl

Lines changed: 3 additions & 3 deletions
@@ -5,11 +5,11 @@ using Functors: @functor

 abstract type AbstractLearner end

-Base.show(io::IO, m::MIME"text/plain", L::AbstractLearner) = show(io, m, convert(AnnotatedStructTree, L))
+Base.show(io::IO, m::MIME"text/plain", learner::AbstractLearner) = show(io, m, convert(AnnotatedStructTree, learner))

 # Take Learner and Environment, get state, send to RLCore.forward(Learner, State)
-function forward(L::Le, env::E) where {Le <: AbstractLearner, E <: AbstractEnv}
-    env |> state |> Flux.gpu |> (x -> forward(L, x)) |> Flux.cpu
+function forward(learner::L, env::E) where {L <: AbstractLearner, E <: AbstractEnv}
+    env |> state |> (x -> forward(learner, x))
 end

 function RLBase.optimise!(::AbstractLearner, ::AbstractStage, ::Trajectory) end
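
To illustrate the reworked generic method: `forward(learner, env)` now pipes `state(env)` straight into `forward(learner, state)` on the CPU instead of round-tripping through `Flux.gpu`/`Flux.cpu`. A minimal sketch, with a hypothetical env/learner pair used only for illustration:

using ReinforcementLearningBase, ReinforcementLearningCore

# Hypothetical single-state environment and learner.
struct OneStateEnv <: AbstractEnv end
RLBase.state(::OneStateEnv) = 1

struct ConstantLearner <: AbstractLearner end
RLCore.forward(::ConstantLearner, s::Int) = [0.5, 0.5]

# The generic method above dispatches env |> state |> forward:
RLCore.forward(ConstantLearner(), OneStateEnv())  # == [0.5, 0.5], no device transfers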
Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using Flux
2+
13
"""
24
Approximator(model, optimiser)
35
@@ -9,20 +11,37 @@ struct Approximator{M,O} <: AbstractLearner
911
optimiser_state::O
1012
end
1113

12-
function Approximator(; model, optimiser, gpu=false)
14+
15+
"""
16+
Approximator(; model, optimiser, usegpu=false)
17+
18+
Constructs an `Approximator` object for reinforcement learning.
19+
20+
# Arguments
21+
- `model`: The model used for approximation.
22+
- `optimiser`: The optimizer used for updating the model.
23+
- `usegpu`: A boolean indicating whether to use GPU for computation. Default is `false`.
24+
25+
# Returns
26+
An `Approximator` object.
27+
"""
28+
function Approximator(; model, optimiser::Flux.Optimise.AbstractOptimiser, use_gpu=false)
1329
optimiser_state = Flux.setup(optimiser, model)
14-
if gpu # Pass model to GPU (if available) upon creation
30+
if use_gpu # Pass model to GPU (if available) upon creation
1531
return Approximator(gpu(model), gpu(optimiser_state))
1632
else
1733
return Approximator(model, optimiser_state)
1834
end
1935
end
2036

37+
Approximator(model, optimiser::Flux.Optimise.AbstractOptimiser; use_gpu=false) = Approximator(model=model, optimiser=optimiser, use_gpu=use_gpu)
38+
2139
Base.show(io::IO, m::MIME"text/plain", A::Approximator) = show(io, m, convert(AnnotatedStructTree, A))
2240

2341
@functor Approximator (model,)
2442

2543
forward(A::Approximator, args...; kwargs...) = A.model(args...; kwargs...)
44+
forward(A::Approximator, env::E) where {E <: AbstractEnv} = env |> state |> (x -> forward(A, x))
2645

27-
RLBase.optimise!(A::Approximator, grad) =
28-
Flux.Optimise.update!(A.optimiser_state, A.model, grad)
46+
RLBase.optimise!(A::Approximator, grad::NamedTuple) =
47+
Flux.Optimise.update!(A.optimiser_state, A.model, grad.model)
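
A sketch of how the revised constructor and the `NamedTuple` gradient plumbing fit together, mirroring the tests added below (model and sizes are illustrative):

using Flux
using ReinforcementLearningCore

model = Chain(Dense(4, 8, relu), Dense(8, 2))
approx = Approximator(model, Adam())  # new positional form; use_gpu defaults to false

x = rand(Float32, 4)
# Differentiating w.r.t. the Approximator itself yields a NamedTuple with a
# `model` field, which is exactly what optimise!(A, grad::NamedTuple) consumes.
grad = Flux.Zygote.gradient(a -> sum(RLCore.forward(a, x)), approx)
RLCore.optimise!(approx, grad[1])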

src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl

Lines changed: 8 additions & 9 deletions
@@ -26,22 +26,21 @@ TabularQApproximator(; n_state, n_action, init = 0.0, opt = InvDecay(1.0)) =
     TabularApproximator(fill(init, n_action, n_state), opt)

 # Take Learner and Environment, get state, send to RLCore.forward(Learner, State)
-function forward(L::Approximator{A, Any}, env::E) where {A <: AbstractArray, E <: AbstractEnv}
-    env |> state |> (x -> forward(L, x))
-end
+forward(L::TabularVApproximator, env::E) where {E <: AbstractEnv} = env |> state |> (x -> forward(L, x))
+forward(L::TabularQApproximator, env::E) where {E <: AbstractEnv} = env |> state |> (x -> forward(L, x))

 RLCore.forward(
-    app::Approximator{R,O},
+    app::TabularVApproximator{R,O},
     s::I,
-) where {R<:AbstractVector,O} = @views app.model[s]
+) where {R<:AbstractVector,O,I} = @views app.model[s]

 RLCore.forward(
-    app::Approximator{R,O},
+    app::TabularQApproximator{R,O},
     s::I,
-) where {R<:AbstractArray,O} = @views app.model[:, s]
+) where {R<:AbstractArray,O,I} = @views app.model[:, s]

 RLCore.forward(
-    app::Approximator{R,O},
+    app::TabularQApproximator{R,O},
     s::I1,
     a::I2,
-) where {R<:AbstractArray,O} = @views app.model[a, s]
+) where {R<:AbstractArray,O,I1,I2} = @views app.model[a, s]

src/ReinforcementLearningCore/src/policies/learners/target_network.jl

Lines changed: 30 additions & 7 deletions
@@ -1,6 +1,6 @@
 export Approximator, TargetNetwork, target, model

-using Flux
+using Flux: gpu


 target(ap::Approximator) = ap.model #see TargetNetwork
@@ -33,11 +33,32 @@ mutable struct TargetNetwork{M}
     n_optimise::Int
 end

-function TargetNetwork(network; sync_freq = 1, ρ = 0f0)
+"""
+    TargetNetwork(network; sync_freq = 1, ρ = 0f0, use_gpu = false)
+
+Constructs a target network for reinforcement learning.
+
+# Arguments
+- `network`: The main network used for training.
+- `sync_freq`: The frequency (in number of calls to `optimise!`) at which the target network is synchronised with the main network. Default is 1.
+- `ρ`: The interpolation factor used for updating the target network. Must be in the range [0, 1]. Default is 0 (the old weights are completely replaced by the new ones).
+- `use_gpu`: Specifies whether to use the GPU for the target network. Default is `false`.
+
+# Returns
+A `TargetNetwork` object.
+"""
+function TargetNetwork(network::Approximator; sync_freq = 1, ρ = 0f0, use_gpu = false)
     @assert 0 <= ρ <= 1 "ρ must be in [0,1]"
-    # NOTE: model is pushed to gpu in Approximator, need to transfer to cpu before deepcopy, then push target model to gpu
-    target = gpu(deepcopy(cpu(network.model)))
-    TargetNetwork(network, target, sync_freq, ρ, 0)
+    ρ = Float32(ρ)
+
+    if use_gpu
+        @assert typeof(gpu(network.model)) == typeof(network.model) "`Approximator` model is not on GPU. Please set `use_gpu=false` or ensure the model is on the GPU by setting `use_gpu=true` when constructing the `Approximator`."
+        # NOTE: model is pushed to gpu in Approximator, need to transfer to cpu before deepcopy, then push target model to gpu
+        target = gpu(deepcopy(cpu(network.model)))
+    else
+        target = deepcopy(network.model)
+    end
+    return TargetNetwork(network, target, sync_freq, ρ, 0)
 end

 @functor TargetNetwork (network, target)
@@ -49,9 +70,9 @@ forward(tn::TargetNetwork, args...) = forward(tn.network, args...)
 model(tn::TargetNetwork) = model(tn.network)
 target(tn::TargetNetwork) = tn.target

-function RLBase.optimise!(tn::TargetNetwork, grad)
+function RLBase.optimise!(tn::TargetNetwork, grad::NamedTuple)
     A = tn.network
-    optimise!(A, grad)
+    optimise!(A, grad.network)

     tn.n_optimise += 1

@@ -62,4 +83,6 @@ function RLBase.optimise!(tn::TargetNetwork, grad)
         end
         tn.n_optimise = 0
     end
+
+    return
 end
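
A usage sketch of the new constructor (the model is illustrative): with the default `ρ = 0`, the target is hard-synced every `sync_freq` calls to `optimise!`; a `ρ` near 1 with `sync_freq = 1` instead blends old and new weights, Polyak style.

using Flux
using ReinforcementLearningCore

model = Chain(Dense(4, 8, relu), Dense(8, 2))
approx = Approximator(model = model, optimiser = Adam())

tn = TargetNetwork(approx, sync_freq = 100)  # hard sync every 100 optimise! calls

x = rand(Float32, 4)
RLCore.forward(tn, x)  # forwards through the live network
target(tn)             # the lagged copy used for bootstrap targets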
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+using Test
+using Flux
+
+@testset "AbstractLearner Tests" begin
+    @testset "Forward" begin
+        # Mock environment and learner
+        struct MockEnv <: AbstractEnv end
+        struct MockLearner <: AbstractLearner end
+
+        function RLCore.forward(::MockLearner, ::AbstractState)
+            return rand(2)
+        end
+
+        env = MockEnv()
+        learner = MockLearner()
+
+        output = forward(learner, env)
+
+        @test typeof(output) == Array{Float64,1}
+        @test length(output) == 2
+    end
+
+    @testset "Plan" begin
+        # Mock explorer, environment, and learner
+        struct MockExplorer <: AbstractExplorer end
+        struct MockEnv <: AbstractEnv end
+        struct MockLearner <: AbstractLearner end
+
+        function RLBase.plan!(::MockExplorer, ::AbstractState, ::AbstractActionSpace)
+            return rand(2)
+        end
+
+        env = MockEnv()
+        learner = MockLearner()
+        explorer = MockExplorer()
+
+        output = RLBase.plan!(explorer, learner, env)
+
+        @test typeof(output) == Array{Float64,1}
+        @test length(output) == 2
+    end
+
+    @testset "Plan with Player" begin
+        # Mock explorer, environment, and learner
+        struct MockExplorer <: AbstractExplorer end
+        struct MockEnv <: AbstractEnv end
+        struct MockLearner <: AbstractLearner end
+
+        function RLBase.plan!(::MockExplorer, ::AbstractState, ::AbstractActionSpace)
+            return rand(2)
+        end
+
+        env = MockEnv()
+        learner = MockLearner()
+        explorer = MockExplorer()
+        player = :player1
+
+        output = RLBase.plan!(explorer, learner, env, player)
+
+        @test typeof(output) == Array{Float64,1}
+        @test length(output) == 2
+    end
+
+    @testset "optimise!" begin
+        struct MockLearner <: AbstractLearner end
+        tr = Trajectory(
+            CircularArraySARTSTraces(; capacity = 1_000),
+            BatchSampler(1),
+            InsertSampleRatioController(n_inserted = -1),
+        )
+        @test optimise!(MockLearner(), PreActStage(), tr) isa Nothing
+    end
+end
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using Test
2+
using Flux
3+
using ReinforcementLearningEnvironments
4+
5+
@testset "Approximator Tests" begin
6+
@testset "Creation, with use_gpu = true toggle" begin
7+
model = Chain(Dense(10, 5, relu), Dense(5, 2))
8+
optimiser = Adam()
9+
approximator = Approximator(model=model, optimiser=optimiser, use_gpu=true)
10+
11+
@test approximator isa Approximator
12+
@test typeof(approximator.model) == typeof(gpu(model))
13+
@test approximator.optimiser_state isa NamedTuple
14+
end
15+
16+
@testset "Forward" begin
17+
model = Chain(Dense(10, 5, relu), Dense(5, 2))
18+
optimiser = Adam()
19+
approximator = Approximator(model=model, optimiser=optimiser, use_gpu=false)
20+
21+
input = rand(Float32, 10)
22+
output = RLCore.forward(approximator, input)
23+
24+
@test typeof(output) == Array{Float32,1}
25+
@test length(output) == 2
26+
end
27+
28+
@testset "Forward to environment" begin
29+
model = Chain(Dense(4, 5, relu), Dense(5, 2))
30+
optimiser = Adam()
31+
approximator = Approximator(model=model, optimiser=optimiser, use_gpu=false)
32+
33+
env = CartPoleEnv()
34+
output = RLCore.forward(approximator, env)
35+
@test typeof(output) == Array{Float32,1}
36+
@test length(output) == 2
37+
end
38+
39+
@testset "Optimise" begin
40+
model = Chain(Dense(10, 5, relu), Dense(5, 2))
41+
optimiser = Adam()
42+
approximator = Approximator(model=model, optimiser=optimiser)
43+
44+
input = rand(Float32, 10)
45+
46+
47+
grad = Flux.Zygote.gradient(approximator) do model
48+
sum(RLCore.forward(model, input))
49+
end
50+
51+
@test approximator.model.layers[2].bias == [0, 0]
52+
RLCore.optimise!(approximator, grad[1])
53+
54+
@test approximator.model.layers[2].bias != [0, 0]
55+
end
56+
end

src/ReinforcementLearningCore/test/policies/learners/approximators.jl

Lines changed: 0 additions & 31 deletions
This file was deleted.
Lines changed: 5 additions & 2 deletions
@@ -1,2 +1,5 @@
-include("approximator.jl")
-include("tabular_approximator.jl")
+@testset "approximators.jl" begin
+    include("approximator.jl")
+    include("tabular_approximator.jl")
+    include("target_network.jl")
+end

src/ReinforcementLearningCore/test/policies/learners/tabular_approximator.jl

Lines changed: 6 additions & 1 deletion
@@ -1,6 +1,7 @@

 using Test
 using ReinforcementLearningCore
+using ReinforcementLearningEnvironments
 using Flux

 @testset "Constructors" begin
@@ -15,7 +16,11 @@ end
     v_approx = TabularVApproximator(n_state = 10)
     @test RLCore.forward(v_approx, 1) == 0.0

+    env = RockPaperScissorsEnv()
+    @test RLCore.forward(v_approx, env) == 0.0
+
     q_approx = TabularQApproximator(n_state = 5, n_action = 10)
     @test RLCore.forward(q_approx, 1) == zeros(Float64, 10)
     @test RLCore.forward(q_approx, 1, 5) == 0.0
-end
+    @test RLCore.forward(q_approx, env) == zeros(10)
+end
