1 change: 1 addition & 0 deletions .gitignore
@@ -1,6 +1,7 @@
# Files generated by invoking Julia with --code-coverage
*.jl.cov
*.jl.*.cov
*.jld2*

# Files generated by invoking Julia with --track-allocation
*.jl.mem
6 changes: 6 additions & 0 deletions emulate_sample/Project.toml
@@ -0,0 +1,6 @@
[deps]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
150 changes: 150 additions & 0 deletions emulate_sample/emulate_sample_catke.jl
@@ -0,0 +1,150 @@
using JLD2, Enzyme, ProgressBars, Random, Statistics, AdvancedHMC, GLMakie

Random.seed!(1234)
tic = time()
include("simple_networks.jl")
include("hmc_interface.jl")
include("optimization_utils.jl")

data_directory = ""
data_file = "catke_parameters.jld2"

jlfile = jldopen(data_directory * data_file, "r")
θ = jlfile["parameters"]
y = jlfile["objectives"]
close(jlfile)
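# θ is assumed to be a three-dimensional array whose last dimension indexes the physical parameters,
# and y the matching two-dimensional array of objective values; the reshapes below flatten the
# leading dimensions into a single sample dimension.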

θ̄ = mean(θ)
θ̃ = std(θ)
ymax = maximum(y)
ymin = minimum(y)
yshift = ymin # ymin / 2 #
Δy = ymax - ymin # 2 * std(y) #
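# Normalize: parameters are centered on the global mean θ̄ and scaled by 2θ̃;
# objectives are shifted by yshift and scaled by Δy so that they lie roughly in [0, 1].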
θr = (reshape(θ, (size(θ)[1] * size(θ)[2], size(θ)[3])) .- θ̄ ) ./ (2θ̃)
yr = (reshape(y, (size(y)[1] * size(y)[2])) .- yshift ) ./ Δy
M = size(θr)[1]
Mᴾ = size(θr)[2]

# Define Network
Nθ = size(θr, 2)
Nθᴴ = Nθ ÷ 2
W1 = randn(Nθᴴ, Nθ)
b1 = randn(Nθᴴ)
W2 = randn(1, Nθᴴ)
b2 = randn(1)
W3 = randn(1, Nθ)
b3 = randn(1)

network = OneLayerNetworkWithLinearByPass(W1, b1, W2, b2, W3, b3)
dnetwork = deepcopy(network)
smoothed_network = deepcopy(network)
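# network is trained below with Enzyme gradients accumulated in dnetwork; smoothed_network holds
# an exponential moving average of the trained parameters, updated at the end of each epoch.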

## Emulate
adam = Adam(network)
batchsize = 100
loss_list = Float64[]
test_loss_list = Float64[]
epochs = 100
network_parameters = copy(parameters(network))
for i in ProgressBar(1:epochs)
shuffled_list = chunk_list(shuffle(1:2:M), batchsize)
shuffled_test_list = chunk_list(shuffle(2:2:M), batchsize)
loss_value = 0.0
N = length(shuffled_list)
# Batched Gradient Descent and Loss Evaluation
for permuted_list in ProgressBar(shuffled_list)
θbatch = [θr[x, :] for x in permuted_list]
ybatch = yr[permuted_list]
zero!(dnetwork)
autodiff(Enzyme.Reverse, loss, Active, DuplicatedNoNeed(network, dnetwork), Const(θbatch), Const(ybatch))
update!(adam, network, dnetwork)
loss_value += loss(network, θbatch, ybatch) / N
end
push!(loss_list, loss_value)
# Test Loss
loss_value = 0.0
N = length(shuffled_test_list)
for permuted_list in shuffled_test_list
θbatch = [θr[x, :] for x in permuted_list]
ybatch = yr[permuted_list]
loss_value += loss(network, θbatch, ybatch) / N
end
push!(test_loss_list, loss_value)
# Weighted averaging of the network parameters (exponential moving average with momentum m)
m = 0.9
network_parameters .= m * network_parameters + (1-m) * parameters(network)
set_parameters!(smoothed_network, network_parameters)
end

loss_fig = Figure()
ax = Axis(loss_fig[1, 1]; title = "Log10 Loss", xlabel = "Epoch", ylabel = "log10(loss)")
scatter!(ax, log10.(loss_list); color = :blue, label = "Training Loss")
scatter!(ax, log10.(test_loss_list); color = :red, label = "Test Loss")
axislegend(ax, position = :rt)
display(loss_fig)

## Sample
# Define logp and ∇logp and regularizer

initial_θ = copy(θr[argmin(yr), :])
mintheta = minimum(θr, dims = 1)[:]
maxtheta = maximum(θr, dims = 1)[:]
reg = Regularizer([mintheta, maxtheta, initial_θ])

function (regularizer::Regularizer)(x)
if any(x .≤ regularizer.parameters[1])
return -Inf
elseif any(x .> regularizer.parameters[2])
return -Inf
else
return -sum(abs.(x - regularizer.parameters[3]) ./ (regularizer.parameters[2] - regularizer.parameters[1]))
end
end
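# Inside the box [mintheta, maxtheta] the regularizer is a scaled L1 penalty toward the best
# ensemble member, -Σᵢ |xᵢ - θ⁰ᵢ| / (maxᵢ - minᵢ); outside the box it is -Inf. In other words,
# it acts as a Laplace-like log-prior truncated to the range spanned by the ensemble.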

scale = 10 * Δy # 1/minimum(yr)
regularization_scale = 0.001/2 * scale

U = LogDensity(network, reg, scale, regularization_scale)
∇U = GradientLogDensity(U)
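# U(θ) = -scale * network(θ)[1] + regularization_scale * reg(θ) is the (unnormalized) log density
# targeted by HMC; ∇U wraps it with Enzyme and returns a (log density, gradient) tuple.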

# HMC
D = size(θr, 2)
n_samples = 10000
n_adapts = 1000

metric = DiagEuclideanMetric(D)
hamiltonian = Hamiltonian(metric, GaussianKinetic(), U, ∇U)

initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = Leapfrog(initial_ϵ)

kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))

samples, stats = sample(hamiltonian, kernel, initial_θ, n_samples, adaptor, n_adapts; progress=true)
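# samples is a vector of parameter vectors in normalized coordinates; stats holds per-sample
# diagnostics from AdvancedHMC (e.g. log_density, acceptance statistics).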

toc = time()
println("Elapsed time: $((toc - tic)/60) minutes")

# Plot
namelist = ["CᵂwΔ", "Cᵂu★", "Cʰⁱc", "Cʰⁱu", "Cʰⁱe", "CʰⁱD", "Cˢ", "Cˡᵒc", "Cˡᵒu", "Cˡᵒe", "CˡᵒD", "CRi⁰", "CRiᵟ", "Cᵘⁿc", "Cᵘⁿu", "Cᵘⁿe", "CᵘⁿD", "Cᶜc", "Cᶜu", "Cᶜe", "CᶜD", "Cᵉc", "Cˢᵖ"]
fig = Figure()
Mp = 5
for i in 1:23
ii = (i-1)÷Mp + 1
jj = (i-1)%Mp + 1
ax = Axis(fig[ii, jj]; title = namelist[i])
v1 = ([sample[i] for sample in samples] .* 2θ̃) .+ θ̄
hist!(ax, v1; bins = 50, strokewidth = 0, color = :blue, normalization = :pdf)
xlims!(ax, -0.1, (reg.parameters[2][i]* 2θ̃ + θ̄) * 1.1)
density!(ax, v1; color = (:red, 0.1), strokewidth = 3, strokecolor = :red)
end
display(fig)

# index of the sample with the highest log density (i.e. the lowest emulated objective) and of the lowest
imin = argmax([stat.log_density for stat in stats])
imax = argmin([stat.log_density for stat in stats])
network(samples[imin]) # emulated (normalized) objective at the highest-density sample
θ₀ = (initial_θ .* 2θ̃) .+ θ̄ # the best ensemble member, back in physical units
((mean(samples) .* 2θ̃) .+ θ̄) - θ₀ # posterior mean minus the best ensemble member
((samples[imin] .* 2θ̃) .+ θ̄) - θ₀ # highest-density sample minus the best ensemble member
42 changes: 42 additions & 0 deletions emulate_sample/hmc_interface.jl
@@ -0,0 +1,42 @@
struct LogDensity{N, S, M}
logp::N
regularization::S
scale::M
scale_regularization::M
end

# The negative sign treats the network output as a potential: a smaller emulated objective gives a larger log density.
# Note: the regularization term should be non-positive (it only ever lowers the log density).
function (logp::LogDensity{T})(θ) where T <: SimpleNetwork
return -logp.logp(θ)[1] * logp.scale + logp.regularization(θ) * logp.scale_regularization
end

function LogDensity(network::SimpleNetwork)
regularization(x) = 0.0
return LogDensity(network, regularization, 1.0, 1.0)
end

function LogDensity(network::SimpleNetwork, scale)
regularization(x) = 0.0
return LogDensity(network, regularization, scale, 1.0)
end

struct GradientLogDensity{N}
logp::N
dθ::Vector{Float64}
end

function GradientLogDensity(logp::LogDensity{S}) where S <: SimpleNetwork
dθ = zeros(size(logp.logp.W1, 2))
return GradientLogDensity(logp, dθ)
end

function (∇logp::GradientLogDensity)(θ)
∇logp.dθ .= 0.0
autodiff(Enzyme.Reverse, Const(∇logp.logp), Active, DuplicatedNoNeed(θ, ∇logp.dθ))
return (∇logp.logp(θ), copy(∇logp.dθ))
end
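# Calling ∇logp(θ) zeroes the stored gradient buffer, accumulates the gradient with Enzyme reverse mode,
# and returns the (log density, gradient) pair consumed by the Hamiltonian constructor.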

struct Regularizer{F}
parameters::F
end
85 changes: 85 additions & 0 deletions emulate_sample/optimization_utils.jl
@@ -0,0 +1,85 @@
function loss(network::SimpleNetwork, x, y)
ŷ = similar(y)
for i in eachindex(ŷ)
ŷ[i] = predict(network, x[i])[1]
end
return mean((y .- ŷ) .^ 2)
end

function chunk_list(list, n)
return [list[i:min(i+n-1, length(list))] for i in 1:n:length(list)]
end
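# e.g. chunk_list(shuffle(1:2:M), batchsize) splits the shuffled training indices into
# batches of at most batchsize elements (the final batch may be shorter).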

struct Adam{S, T, I}
struct_copies::S
parameters::T
t::I
end
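# struct_copies: preallocated optimizer state (gradient and first/second moment buffers);
# parameters: the hyperparameters (α, β₁, β₂, ϵ); t: the step count, kept in a one-element
# vector so that it can be mutated in update!.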

function parameters(network::SimpleNetwork)
network_parameters = []
for names in propertynames(network)
push!(network_parameters, getproperty(network, names)[:])
end
param_lengths = [length(params) for params in network_parameters]
parameter_list = zeros(sum(param_lengths))
start = 1
for i in 1:length(param_lengths)
parameter_list[start:start+param_lengths[i]-1] .= network_parameters[i]
start += param_lengths[i]
end
return parameter_list
end

function set_parameters!(network::SimpleNetwork, parameters_list)
param_lengths = Int64[]
for names in propertynames(network)
push!(param_lengths, length(getproperty(network, names)[:]))
end
start = 1
for (i, names) in enumerate(propertynames(network))
getproperty(network, names)[:] .= parameters_list[start:start+param_lengths[i]-1]
start = start + param_lengths[i]
end
return nothing
end
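# parameters(network) flattens every field (W1, b1, ...) into one vector; set_parameters!
# writes such a vector back into the network in the same order.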

function Adam(network::SimpleNetwork; α=0.001, β₁=0.9, β₂=0.999, ϵ=1e-8)
parameters_list = (; α, β₁, β₂, ϵ)
network_parameters = parameters(network)
t = [1.0]
θ = zero(network_parameters)
gₜ = zero(network_parameters)
m₀ = zero(network_parameters)
mₜ = zero(network_parameters)
v₀ = zero(network_parameters)
vₜ = zero(network_parameters)
v̂ₜ = zero(network_parameters)
m̂ₜ = zero(network_parameters)
struct_copies = (; θ, gₜ, m₀, mₜ, v₀, vₜ, v̂ₜ, m̂ₜ)
return Adam(struct_copies, parameters_list, t)
end


function update!(adam::Adam, network::SimpleNetwork, dnetwork::SimpleNetwork)
# unpack
(; α, β₁, β₂, ϵ) = adam.parameters
(; θ, gₜ, m₀, mₜ, v₀, vₜ, v̂ₜ, m̂ₜ) = adam.struct_copies
t = adam.t[1]
# get gradient
θ .= parameters(network)
gₜ .= parameters(dnetwork)
# update
@. mₜ = β₁ * m₀ + (1 - β₁) * gₜ
@. vₜ = β₂ * v₀ + (1 - β₂) * (gₜ .^2)
@. m̂ₜ = mₜ / (1 - β₁^t)
@. v̂ₜ = vₜ / (1 - β₂^t)
@. θ = θ - α * m̂ₜ / (sqrt(v̂ₜ) + ϵ)
# advance the optimizer state and write the new parameters back into the network
m₀ .= mₜ
v₀ .= vₜ
adam.t[1] += 1
set_parameters!(network, θ)
return nothing
end
67 changes: 67 additions & 0 deletions emulate_sample/simple_networks.jl
@@ -0,0 +1,67 @@
using Enzyme

abstract type SimpleNetwork end

struct OneLayerNetwork{M, V} <: SimpleNetwork
W1::M
b1::V
W2::M
b2::V
end

struct OneLayerNetworkWithLinearByPass{M,V} <: SimpleNetwork
W1::M
b1::V
W2::M
b2::V
W3::M
b3::V
end

function zero!(dnetwork::OneLayerNetworkWithLinearByPass)
dnetwork.W1 .= 0.0
dnetwork.b1 .= 0.0
dnetwork.W2 .= 0.0
dnetwork.b2 .= 0.0
dnetwork.W3 .= 0.0
dnetwork.b3 .= 0.0
return nothing
end

function zero!(dnetwork::OneLayerNetwork)
dnetwork.W1 .= 0.0
dnetwork.b1 .= 0.0
dnetwork.W2 .= 0.0
dnetwork.b2 .= 0.0
return nothing
end

function update!(network::SimpleNetwork, dnetwork::SimpleNetwork, η)
network.W1 .-= η .* dnetwork.W1
network.b1 .-= η .* dnetwork.b1
network.W2 .-= η .* dnetwork.W2
network.b2 .-= η .* dnetwork.b2
return nothing
end

swish(x) = x / (1 + exp(-x))
activation_function(x) = swish(x) # tanh(x) #

function predict(network::OneLayerNetwork, x)
return abs.(network.W2 * activation_function.(network.W1 * x .+ network.b1) .+ network.b2)
end

function predict(network::OneLayerNetworkWithLinearByPass, x)
y1 = network.W1 * x .+ network.b1
y2 = network.W2 * activation_function.(y1) .+ network.b2
y3 = network.W3 * x .+ network.b3
return abs.(y3) .+ abs.(y2)
end
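# The abs calls keep predictions non-negative, presumably because the normalized objective
# values being emulated are themselves non-negative.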

function predict(network::OneLayerNetwork, x, activation::Function)
return abs.(network.W2 * activation.(network.W1 * x .+ network.b1) .+ network.b2)
end

function (network::SimpleNetwork)(x)
return predict(network, x)
end
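# A minimal usage sketch with toy, hypothetical sizes (assuming the definitions above):
#   net = OneLayerNetwork(randn(2, 4), randn(2), randn(1, 2), randn(1))
#   net(randn(4))   # equivalent to predict(net, randn(4)); returns a 1-element vector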