From 88fdd6972367693e9d35aa7c11736add4192e81a Mon Sep 17 00:00:00 2001 From: Marco Bonici <58727599+marcobonici@users.noreply.github.com> Date: Wed, 10 May 2023 10:43:17 +0200 Subject: [PATCH 1/5] First commit to add weighted mean square --- src/loss.jl | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/src/loss.jl b/src/loss.jl index ad85a81..930a944 100644 --- a/src/loss.jl +++ b/src/loss.jl @@ -107,6 +107,62 @@ function (sl::SquaredLoss{<:AbstractArray{<:Number}})( T(0.5) * s, p, pu end +""" +WeightedSquaredLoss(target) + +Calculates half of mean weighted squared loss of the target. +""" +struct WeightedSquaredLoss{Y, W<:AbstractVector{Y}} <: AbstractLoss{Y} + y::Y + weights::W +end +(::WeightedSquaredLoss)(y, w) = WeightedSquaredLoss(y, w) +WeightedSquaredLoss() = WeightedSquaredLoss(nothing) +target(wsl::WeightedSquaredLoss) = getfield(wsl, :y)#maybe need to return both :y and :weights? + +Base.getindex(wsl::WeightedSquaredLoss, r) = WeightedSquaredLoss(view_slice_last(target(wsl), r)) + +weighted_squared_loss(chn::SimpleChain, y, w) = add_loss(chn, WeightedSquaredLoss(y, w)) + +Base.show(io::IO, ::WeightedSquaredLoss) = print(io, "WeightedSquaredLoss") + +@inline loss_multiplier(::AbstractLoss, N, ::Type{T}) where {T} = inv(T(N)) +@inline loss_multiplier(::WeightedSquaredLoss, N, ::Type{T}) where {T} = T(2) / T(N) + +function chain_valgrad!( + _, + arg::AbstractArray{T,D}, + layers::Tuple{WeightedSquaredLoss}, + p::Ptr, + pu::Ptr{UInt8} + ) where {T,D} + y = getfield(getfield(layers, 1), :y) + w = getfield(getfield(layers, 1), :weights) + # invN = T(inv(static_size(arg, D))) + s = zero(T) + @turbo for i ∈ eachindex(arg) + δ = arg[i] - y[i] + arg[i] = δ + s += δ * δ * w[i] + end + T(0.5) * s, arg, pu + end + function (sl::WeightedSquaredLoss{<:AbstractArray{<:Number}})( + arg::AbstractArray{T,N}, + p, + pu + ) where {T,N} + y = getfield(sl, :y) + w = getfield(sl, :weights) + s = zero(T) + @turbo for i ∈ eachindex(arg) + δ = arg[i] - y[i] + s += δ * δ * w[i] + end + # NOTE: we're not dividing by static_size(arg,N) + T(0.5) * s, p, pu + end + """ AbsoluteLoss From 924c995a489fbe684b953026af82f45904b3dc1d Mon Sep 17 00:00:00 2001 From: Marco Bonici <58727599+marcobonici@users.noreply.github.com> Date: Wed, 10 May 2023 17:30:46 +0200 Subject: [PATCH 2/5] Update src/loss.jl according to ChrisElrod suggestion Co-authored-by: Chris Elrod --- src/loss.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/loss.jl b/src/loss.jl index 930a944..3a4daa9 100644 --- a/src/loss.jl +++ b/src/loss.jl @@ -142,8 +142,9 @@ function chain_valgrad!( s = zero(T) @turbo for i ∈ eachindex(arg) δ = arg[i] - y[i] - arg[i] = δ - s += δ * δ * w[i] + δw = δ*w[i] + arg[i] = δw + s += δ * δw end T(0.5) * s, arg, pu end From bc84718452e955a1c92859dfc1708f535039b54a Mon Sep 17 00:00:00 2001 From: Marco Bonici <58727599+marcobonici@users.noreply.github.com> Date: Thu, 11 May 2023 17:02:20 +0200 Subject: [PATCH 3/5] Modified target method; updated tuple constructor WSL --- src/loss.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/loss.jl b/src/loss.jl index 3a4daa9..ca12405 100644 --- a/src/loss.jl +++ b/src/loss.jl @@ -118,7 +118,8 @@ struct WeightedSquaredLoss{Y, W<:AbstractVector{Y}} <: AbstractLoss{Y} end (::WeightedSquaredLoss)(y, w) = WeightedSquaredLoss(y, w) WeightedSquaredLoss() = WeightedSquaredLoss(nothing) -target(wsl::WeightedSquaredLoss) = getfield(wsl, :y)#maybe need to return both :y and :weights? +WeightedSquaredLoss(x::Tuple) = WeightedSquaredLoss(x...) +target(wsl::WeightedSquaredLoss) = getfield(wsl, :y), getfield(wsl, :w) Base.getindex(wsl::WeightedSquaredLoss, r) = WeightedSquaredLoss(view_slice_last(target(wsl), r)) From dfd1a25a227b7931b046ce8e5cc58e9d7ad2ddc4 Mon Sep 17 00:00:00 2001 From: Marco Bonici <58727599+marcobonici@users.noreply.github.com> Date: Thu, 11 May 2023 17:59:57 +0200 Subject: [PATCH 4/5] Adding view_slice_last method --- src/loss.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/loss.jl b/src/loss.jl index ca12405..1d7b743 100644 --- a/src/loss.jl +++ b/src/loss.jl @@ -120,6 +120,9 @@ end WeightedSquaredLoss() = WeightedSquaredLoss(nothing) WeightedSquaredLoss(x::Tuple) = WeightedSquaredLoss(x...) target(wsl::WeightedSquaredLoss) = getfield(wsl, :y), getfield(wsl, :w) +function view_slice_last(target(wsl::WeightedSquaredLoss), r) + return Tuple(view_slice_last(f, r) for f in target(wsl)) +end Base.getindex(wsl::WeightedSquaredLoss, r) = WeightedSquaredLoss(view_slice_last(target(wsl), r)) From 02ebd4fd35cb418f8f4bf19dfb84339d90572c08 Mon Sep 17 00:00:00 2001 From: Marco Bonici Date: Fri, 12 May 2023 17:50:42 +0200 Subject: [PATCH 5/5] Modified constructors wsl; loss eval ok, not train --- src/SimpleChains.jl | 1 + src/loss.jl | 9 ++++----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/SimpleChains.jl b/src/SimpleChains.jl index 6c4f3a9..9d1a2c0 100644 --- a/src/SimpleChains.jl +++ b/src/SimpleChains.jl @@ -62,6 +62,7 @@ export SimpleChain, Flatten, AbsoluteLoss, SquaredLoss, + WeightedSquaredLoss, LogitCrossEntropyLoss, relu, static, diff --git a/src/loss.jl b/src/loss.jl index 1d7b743..bc0501f 100644 --- a/src/loss.jl +++ b/src/loss.jl @@ -112,16 +112,16 @@ WeightedSquaredLoss(target) Calculates half of mean weighted squared loss of the target. """ -struct WeightedSquaredLoss{Y, W<:AbstractVector{Y}} <: AbstractLoss{Y} +struct WeightedSquaredLoss{Y, W} <: AbstractLoss{Y} y::Y weights::W end (::WeightedSquaredLoss)(y, w) = WeightedSquaredLoss(y, w) WeightedSquaredLoss() = WeightedSquaredLoss(nothing) WeightedSquaredLoss(x::Tuple) = WeightedSquaredLoss(x...) -target(wsl::WeightedSquaredLoss) = getfield(wsl, :y), getfield(wsl, :w) -function view_slice_last(target(wsl::WeightedSquaredLoss), r) - return Tuple(view_slice_last(f, r) for f in target(wsl)) +target(wsl::WeightedSquaredLoss) = getfield(wsl, :y), getfield(wsl, :weights) +function view_slice_last(x::Tuple, r) + return Tuple(view_slice_last(f, r) for f in x::Tuple) end Base.getindex(wsl::WeightedSquaredLoss, r) = WeightedSquaredLoss(view_slice_last(target(wsl), r)) @@ -130,7 +130,6 @@ weighted_squared_loss(chn::SimpleChain, y, w) = add_loss(chn, WeightedSquaredLos Base.show(io::IO, ::WeightedSquaredLoss) = print(io, "WeightedSquaredLoss") -@inline loss_multiplier(::AbstractLoss, N, ::Type{T}) where {T} = inv(T(N)) @inline loss_multiplier(::WeightedSquaredLoss, N, ::Type{T}) where {T} = T(2) / T(N) function chain_valgrad!(