diff --git a/src/SimpleChains.jl b/src/SimpleChains.jl
index dd01b84..a5124e7 100644
--- a/src/SimpleChains.jl
+++ b/src/SimpleChains.jl
@@ -62,6 +62,7 @@ export SimpleChain,
   Flatten,
   AbsoluteLoss,
   SquaredLoss,
+  WeightedSquaredLoss,
   LogitCrossEntropyLoss,
   relu,
   static,
diff --git a/src/loss.jl b/src/loss.jl
index ad85a81..bc0501f 100644
--- a/src/loss.jl
+++ b/src/loss.jl
@@ -107,6 +107,81 @@ function (sl::SquaredLoss{<:AbstractArray{<:Number}})(
   T(0.5) * s, p, pu
 end
 
+"""
+    WeightedSquaredLoss(target, weights)
+
+Calculates half of the weighted sum of squared errors between the model output
+and `target`, i.e. `0.5 * sum(weights .* (output .- target) .^ 2)`. The methods
+defined here do not divide the sum by the number of observations.
+"""
+struct WeightedSquaredLoss{Y,W} <: AbstractLoss{Y}
+  y::Y
+  weights::W
+end
+# Calling an existing loss with new data builds a loss around that data.
+(::WeightedSquaredLoss)(y, w) = WeightedSquaredLoss(y, w)
+# Both fields must be given: there is no one-argument constructor accepting
+# `nothing`, so `WeightedSquaredLoss(nothing)` would throw a `MethodError`.
+WeightedSquaredLoss() = WeightedSquaredLoss(nothing, nothing)
+WeightedSquaredLoss(x::Tuple) = WeightedSquaredLoss(x...)
+# The "target" of a weighted loss is the `(y, weights)` pair.
+target(wsl::WeightedSquaredLoss) = getfield(wsl, :y), getfield(wsl, :weights)
+
+# Apply `view_slice_last` to every element of a tuple; `map` keeps the result a
+# tuple whose type is inferrable.
+function view_slice_last(x::Tuple, r)
+  return map(f -> view_slice_last(f, r), x)
+end
+
+function Base.getindex(wsl::WeightedSquaredLoss, r)
+  return WeightedSquaredLoss(view_slice_last(target(wsl), r))
+end
+
+weighted_squared_loss(chn::SimpleChain, y, w) = add_loss(chn, WeightedSquaredLoss(y, w))
+
+Base.show(io::IO, ::WeightedSquaredLoss) = print(io, "WeightedSquaredLoss")
+
+@inline loss_multiplier(::WeightedSquaredLoss, N, ::Type{T}) where {T} = T(2) / T(N)
+
+# Returns the loss value and overwrites `arg` in place with `w[i] * (arg[i] - y[i])`,
+# the derivative of the loss with respect to `arg[i]`.
+function chain_valgrad!(
+  _,
+  arg::AbstractArray{T,D},
+  layers::Tuple{WeightedSquaredLoss},
+  p::Ptr,
+  pu::Ptr{UInt8}
+) where {T,D}
+  loss = getfield(layers, 1)
+  y = getfield(loss, :y)
+  w = getfield(loss, :weights)
+  s = zero(T)
+  @turbo for i ∈ eachindex(arg)
+    δ = arg[i] - y[i]
+    δw = δ * w[i]
+    arg[i] = δw
+    s += δ * δw
+  end
+  return T(0.5) * s, arg, pu
+end
+
+# Forward pass only: returns the loss value and leaves `arg` untouched.
+function (sl::WeightedSquaredLoss{<:AbstractArray{<:Number}})(
+  arg::AbstractArray{T,N},
+  p,
+  pu
+) where {T,N}
+  y = getfield(sl, :y)
+  w = getfield(sl, :weights)
+  s = zero(T)
+  @turbo for i ∈ eachindex(arg)
+    δ = arg[i] - y[i]
+    s += δ * δ * w[i]
+  end
+  # NOTE: we're not dividing by static_size(arg,N)
+  return T(0.5) * s, p, pu
+end
+
 """
     AbsoluteLoss
 