
Commit fda1ba3

Remove value in favor of functor interface (#163)
* Remove value in favor of functor interface
* Update docs
1 parent d8c0654 commit fda1ba3

15 files changed: +104 -243 lines changed


docs/src/advanced/developer.md

Lines changed: 1 addition & 3 deletions
@@ -60,9 +60,7 @@ MarginLoss
 
 Each of the three abstract types listed above serves a purpose
 other than dispatch. All losses that belong to the same family
-share functionality to some degree. For example all subtypes of
-[`SupervisedLoss`](@ref) share the same implementations for the
-vectorized versions of [`value`](@ref) and [`deriv`](@ref).
+share functionality to some degree.
 
 More interestingly, the abstract types [`DistanceLoss`](@ref) and
 [`MarginLoss`](@ref), serve an additional purpose aside from

docs/src/advanced/extend.md

Lines changed: 6 additions & 6 deletions
@@ -42,10 +42,10 @@ ScaledLoss
 julia> lsloss = 1/2 * L2DistLoss()
 ScaledLoss{L2DistLoss, 0.5}(L2DistLoss())
 
-julia> value(L2DistLoss(), 4.0, 0.0)
+julia> L2DistLoss()(4.0, 0.0)
 16.0
 
-julia> value(lsloss, 4.0, 0.0)
+julia> lsloss(4.0, 0.0)
 8.0
 ```

@@ -102,16 +102,16 @@ WeightedMarginLoss
 julia> myloss = WeightedMarginLoss(HingeLoss(), 0.8)
 WeightedMarginLoss{L1HingeLoss, 0.8}(L1HingeLoss())
 
-julia> value(myloss, -4.0, 1.0) # positive class
+julia> myloss(-4.0, 1.0) # positive class
 4.0
 
-julia> value(HingeLoss(), -4.0, 1.0)
+julia> HingeLoss()(-4.0, 1.0)
 5.0
 
-julia> value(myloss, 4.0, -1.0) # negative class
+julia> myloss(4.0, -1.0) # negative class
 0.9999999999999998
 
-julia> value(HingeLoss(), 4.0, -1.0)
+julia> HingeLoss()(4.0, -1.0)
 5.0
 ```
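The values above are consistent with the weight scaling the positive class and its complement scaling the negative class. A minimal sketch of that reading (the interpretation of the weight is inferred from the numbers shown, not stated in the diff):

```julia
using LossFunctions

base   = HingeLoss()
myloss = WeightedMarginLoss(base, 0.8)

# positive class (target == 1): appears to scale by the weight 0.8
myloss(-4.0, 1.0) ≈ 0.8 * base(-4.0, 1.0)          # 0.8 * 5.0 == 4.0

# negative class (target == -1): appears to scale by 1 - 0.8
myloss(4.0, -1.0) ≈ (1 - 0.8) * base(4.0, -1.0)    # ≈ 0.2 * 5.0
```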

docs/src/introduction/gettingstarted.md

Lines changed: 4 additions & 4 deletions
@@ -60,12 +60,12 @@ From an implementation perspective, we should point out that all
 the concrete loss "functions" that this package provides are
 actually defined as immutable types, instead of native Julia
 functions. We can compute the value of some type of loss using
-the function [`value`](@ref). Let us start with an example of how
+the functor interface. Let us start with an example of how
 to compute the loss of a single observation (i.e. two numbers).
 
 ```julia-repl
-# loss   ŷ    y
-julia> value(L2DistLoss(), 0.5, 1.0)
+# loss   ŷ    y
+julia> L2DistLoss()(0.5, 1.0)
 0.25
 ```

@@ -78,7 +78,7 @@ julia> true_targets = [ 1, 0, -2];
 
 julia> pred_outputs = [0.5, 2, -1];
 
-julia> value.(L2DistLoss(), pred_outputs, true_targets)
+julia> L2DistLoss().(pred_outputs, true_targets)
 3-element Vector{Float64}:
  0.25
  4.0
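For readers following the diff, the new calling convention carries over to margin-based losses as well. A minimal sketch, assuming the exported loss types used elsewhere in this commit:

```julia
using LossFunctions

# distance-based loss: penalizes the residual ŷ - y
L2DistLoss()(0.5, 1.0)     # abs2(0.5 - 1.0) == 0.25

# margin-based loss: penalizes the agreement y * ŷ (targets are ±1)
HingeLoss()(-4.0, 1.0)     # max(0, 1 - (1.0 * -4.0)) == 5.0

# broadcasting the functor call works element-wise over arrays
L2DistLoss().([0.5, 2, -1], [1, 0, -2])
```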

docs/src/user/aggregate.md

Lines changed: 7 additions & 4 deletions
@@ -34,18 +34,21 @@ say "naive", because it will not give us an acceptable
 performance.
 
 ```jldoctest
-julia> value.(L1DistLoss(), [2,5,-2], [1.,2,3])
+julia> loss = L1DistLoss()
+L1DistLoss()
+
+julia> loss.([2,5,-2], [1.,2,3])
 3-element Vector{Float64}:
  1.0
  3.0
  5.0
 
-julia> sum(value.(L1DistLoss(), [2,5,-2], [1.,2,3])) # WARNING: Bad code
+julia> sum(loss.([2,5,-2], [1.,2,3])) # WARNING: Bad code
 9.0
 ```
 
 This works as expected, but there is a price for it. Before the
-sum can be computed, [`value`](@ref) will allocate a temporary
+sum can be computed, the solution will allocate a temporary
 array and fill it with the element-wise results. After that,
 `sum` will iterate over this temporary array and accumulate the
 values accordingly. Bottom line: we allocate temporary memory
@@ -82,7 +85,7 @@ the results, we will see that the loss of the second observation
 was effectively counted twice.
 
 ```jldoctest
-julia> result = value.(L1DistLoss(), [2,5,-2], [1.,2,3]) .* [1,2,1]
+julia> result = L1DistLoss().([2,5,-2], [1.,2,3]) .* [1,2,1]
 3-element Vector{Float64}:
  1.0
  6.0
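To make the allocation point concrete: the docs go on to describe better aggregation options, but even in plain Julia a generator avoids the temporary array. This is a generic sketch, not an API added by this commit:

```julia
using LossFunctions

loss    = L1DistLoss()
outputs = [2, 5, -2]
targets = [1., 2, 3]

# naive: materializes a temporary vector, then sums it
total_naive = sum(loss.(outputs, targets))

# allocation-free: feed `sum` a generator of element-wise losses
total = sum(loss(ŷ, y) for (ŷ, y) in zip(outputs, targets))

total ≈ total_naive    # both give 9.0
```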

docs/src/user/interface.md

Lines changed: 9 additions & 96 deletions
@@ -49,7 +49,7 @@ than one place.
 julia> loss = L2DistLoss()
 L2DistLoss()
 
-julia> value(loss, 3, 2)
+julia> loss(3, 2)
 1
 ```

@@ -66,9 +66,9 @@ yourself in the code below. As such they are zero-cost
 abstractions.
 
 ```julia-repl
-julia> v1(loss,y,t) = value(loss,y,t)
+julia> v1(loss,y,t) = loss(y,t)
 
-julia> v2(y,t) = value(L2DistLoss(),y,t)
+julia> v2(y,t) = L2DistLoss()(y,t)
 
 julia> @code_llvm v1(loss, 3, 2)
 define i64 @julia_v1_70944(i64, i64) #0 {
@@ -115,46 +115,17 @@ performance overhead, and zero memory allocations on the heap.
 
 The first thing we may want to do is compute the loss for some
 observation (singular). In fact, all losses are implemented on
-single observations under the hood. The core function to compute
-the value of a loss is `value`. We will see throughout the
-documentation that this function allows for a lot of different
-method signatures to accomplish a variety of tasks.
-
-```@docs
-value
-```
-
-It may be interesting to note, that this function also supports
-broadcasting and all the syntax benefits that come with it. Thus,
-it is quite simple to make use of preallocated memory for storing
-the element-wise results.
+single observations under the hood, and are functors.
 
 ```jldoctest bcast1
-julia> value.(L1DistLoss(), [2,5,-2], [1,2,3])
+julia> loss = L1DistLoss()
+L1DistLoss()
+
+julia> loss.([2,5,-2], [1,2,3])
 3-element Vector{Int64}:
  1
  3
  5
-
-julia> buffer = zeros(3); # preallocate a buffer
-
-julia> buffer .= value.(L1DistLoss(), [2,5,-2], [1.,2,3])
-3-element Vector{Float64}:
- 1.0
- 3.0
- 5.0
-```
-
-Furthermore, with the loop fusion changes that were introduced in
-Julia 0.6, one can also easily weight the influence of each
-observation without allocating a temporary array.
-
-```jldoctest bcast1
-julia> buffer .= value.(L1DistLoss(), [2,5,-2], [1.,2,3]) .* [2,1,0.5]
-3-element Vector{Float64}:
- 2.0
- 3.0
- 2.5
 ```
 
 ## Computing the 1st Derivatives
@@ -166,8 +137,7 @@ derivatives of the loss in one way or the other during the
 training process.
 
 To compute the derivative of some loss we expose the function
-[`deriv`](@ref). It supports the same exact method signatures as
-[`value`](@ref). It may be interesting to note explicitly, that
+[`deriv`](@ref). It may be interesting to note explicitly, that
 we always compute the derivative in respect to the predicted
 `output`, since we are interested in deducing in which direction
 the output should change.
@@ -176,39 +146,6 @@ the output should change.
 deriv
 ```
 
-Similar to [`value`](@ref), this function also supports
-broadcasting and all the syntax benefits that come with it. Thus,
-one can make use of preallocated memory for storing the
-element-wise derivatives.
-
-```jldoctest bcast2
-julia> deriv.(L2DistLoss(), [2,5,-2], [1,2,3])
-3-element Vector{Int64}:
-  2
-  6
- -10
-
-julia> buffer = zeros(3); # preallocate a buffer
-
-julia> buffer .= deriv.(L2DistLoss(), [2,5,-2], [1.,2,3])
-3-element Vector{Float64}:
-   2.0
-   6.0
- -10.0
-```
-
-Furthermore, with the loop fusion changes that were introduced in
-Julia 0.6, one can also easily weight the influence of each
-observation without allocating a temporary array.
-
-```jldoctest bcast2
-julia> buffer .= deriv.(L2DistLoss(), [2,5,-2], [1.,2,3]) .* [2,1,0.5]
-3-element Vector{Float64}:
-  4.0
-  6.0
- -5.0
-```
-
 ## Computing the 2nd Derivatives
 
 Additionally to the first derivative, we also provide the
@@ -220,30 +157,6 @@ derivative in respect to the predicted `output`.
 deriv2
 ```
 
-Just like [`deriv`](@ref) and [`value`](@ref), this function also
-supports broadcasting and all the syntax benefits that come with
-it. Thus, one can make use of preallocated memory for storing the
-element-wise derivatives.
-
-```jldoctest
-julia> deriv2.(LogitDistLoss(), [0.3, 2.3, -2], [-0.5, 1.2, 3])
-3-element Vector{Float64}:
- 0.42781939304058886
- 0.3747397590950413
- 0.013296113341580313
-
-julia> buffer = zeros(3); # preallocate a buffer
-
-julia> buffer .= deriv2.(LogitDistLoss(), [0.3, 2.3, -2], [-0.5, 1.2, 3])
-3-element Vector{Float64}:
- 0.42781939304058886
- 0.3747397590950413
- 0.013296113341580313
-```
-
-Furthermore [`deriv2`](@ref) supports all the same method
-signatures as [`deriv`](@ref) does.
-
 ## Properties of a Loss
 
 In some situations it can be quite useful to assert certain
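The removed sections above showed `value`, `deriv`, and `deriv2` being broadcast into preallocated buffers. The same pattern still applies after this commit, with the functor call taking the place of `value`; a sketch adapted from the removed examples:

```julia
using LossFunctions

loss    = L2DistLoss()
outputs = [2, 5, -2]
targets = [1., 2, 3]

buffer = zeros(3)                          # preallocate a buffer

buffer .= loss.(outputs, targets)          # element-wise loss values, in place
buffer .= deriv.(loss, outputs, targets)   # 1st derivatives w.r.t. the output
buffer .= deriv2.(loss, outputs, targets)  # 2nd derivatives w.r.t. the output
```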

src/LossFunctions.jl

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ export
 SupervisedLoss,
 MarginLoss,
 DistanceLoss,
-value, deriv, deriv2,
+deriv, deriv2,
 isdistancebased, ismarginbased,
 isminimizable, isdifferentiable,
 istwicedifferentiable,

src/losses.jl

Lines changed: 4 additions & 7 deletions
@@ -1,17 +1,14 @@
 # type alias to make code more readable
 Scalar = Union{Number,CategoricalValue}
 
-# convenient functor interface
-(loss::SupervisedLoss)(output::Scalar, target::Scalar) = value(loss, output, target)
-
 # fallback to unary evaluation
-value(loss::DistanceLoss, output::Number, target::Number) = value(loss, output - target)
+(loss::DistanceLoss)(output::Number, target::Number) = loss(output - target)
 deriv(loss::DistanceLoss, output::Number, target::Number) = deriv(loss, output - target)
 deriv2(loss::DistanceLoss, output::Number, target::Number) = deriv2(loss, output - target)
 
-value(loss::MarginLoss, output::Number, target::Number) = value(loss, target * output)
-deriv(loss::MarginLoss, output::Number, target::Number) = target * deriv(loss, target * output)
-deriv2(loss::MarginLoss, output::Number, target::Number) = deriv2(loss, target * output)
+(loss::MarginLoss)(output::Number, target::Number) = loss(target * output)
+deriv(loss::MarginLoss, output::Number, target::Number) = target * deriv(loss, target * output)
+deriv2(loss::MarginLoss, output::Number, target::Number) = deriv2(loss, target * output)
 
 # broadcasting behavior
 Broadcast.broadcastable(loss::SupervisedLoss) = Ref(loss)
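To make the fallback methods above concrete, a small usage sketch (assuming the package's exported loss types): the two-argument call reduces to the one-argument call on the residual or agreement.

```julia
using LossFunctions

dloss = L2DistLoss()   # a DistanceLoss
mloss = HingeLoss()    # a MarginLoss

# distance-based losses evaluate the residual output - target
dloss(0.5, 1.0) == dloss(0.5 - 1.0)     # true

# margin-based losses evaluate the agreement target * output
mloss(-4.0, 1.0) == mloss(1.0 * -4.0)   # true
```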

src/losses/distance.jl

Lines changed: 10 additions & 10 deletions
@@ -13,7 +13,7 @@ struct LPDistLoss{P} <: DistanceLoss end
 
 LPDistLoss(p::Number) = LPDistLoss{p}()
 
-value(loss::LPDistLoss{P}, difference::Number) where {P} = abs(difference)^P
+(loss::LPDistLoss{P})(difference::Number) where {P} = abs(difference)^P
 function deriv(loss::LPDistLoss{P}, difference::T)::promote_type(typeof(P),T) where {P,T<:Number}
     if difference == 0
         zero(difference)
@@ -73,7 +73,7 @@ L(r) = |r|
 """
 const L1DistLoss = LPDistLoss{1}
 
-value(loss::L1DistLoss, difference::Number) = abs(difference)
+(loss::L1DistLoss)(difference::Number) = abs(difference)
 deriv(loss::L1DistLoss, difference::T) where {T<:Number} = convert(T, sign(difference))
 deriv2(loss::L1DistLoss, difference::T) where {T<:Number} = zero(T)
 
@@ -118,7 +118,7 @@ L(r) = |r|^2
 """
 const L2DistLoss = LPDistLoss{2}
 
-value(loss::L2DistLoss, difference::Number) = abs2(difference)
+(loss::L2DistLoss)(difference::Number) = abs2(difference)
 deriv(loss::L2DistLoss, difference::T) where {T<:Number} = convert(T,2) * difference
 deriv2(loss::L2DistLoss, difference::T) where {T<:Number} = convert(T,2)
 
@@ -152,7 +152,7 @@ end
 PeriodicLoss(circ::T=1.0) where {T<:AbstractFloat} = PeriodicLoss{T}(circ)
 PeriodicLoss(circ) = PeriodicLoss{Float64}(Float64(circ))
 
-value(loss::PeriodicLoss, difference::T) where {T<:Number} = 1 - cos(difference*loss.k)
+(loss::PeriodicLoss)(difference::T) where {T<:Number} = 1 - cos(difference*loss.k)
 deriv(loss::PeriodicLoss, difference::T) where {T<:Number} = loss.k * sin(difference*loss.k)
 deriv2(loss::PeriodicLoss, difference::T) where {T<:Number} = abs2(loss.k) * cos(difference*loss.k)
 
@@ -207,7 +207,7 @@ end
 HuberLoss(d::T=1.0) where {T<:AbstractFloat} = HuberLoss{T}(d)
 HuberLoss(d) = HuberLoss{Float64}(Float64(d))
 
-function value(loss::HuberLoss{T1}, difference::T2) where {T1,T2<:Number}
+function (loss::HuberLoss{T1})(difference::T2) where {T1,T2<:Number}
     T = promote_type(T1,T2)
     abs_diff = abs(difference)
     if abs_diff <= loss.d
@@ -282,7 +282,7 @@ const EpsilonInsLoss = L1EpsilonInsLoss
 @inline L1EpsilonInsLoss(ε::T) where {T<:AbstractFloat} = L1EpsilonInsLoss{T}(ε)
 @inline L1EpsilonInsLoss(ε::Number) = L1EpsilonInsLoss{Float64}(Float64(ε))
 
-function value(loss::L1EpsilonInsLoss{T1}, difference::T2) where {T1,T2<:Number}
+function (loss::L1EpsilonInsLoss{T1})(difference::T2) where {T1,T2<:Number}
     T = promote_type(T1,T2)
     max(zero(T), abs(difference) - loss.ε)
 end
@@ -344,7 +344,7 @@ end
 L2EpsilonInsLoss(ε::T) where {T<:AbstractFloat} = L2EpsilonInsLoss{T}(ε)
 L2EpsilonInsLoss(ε) = L2EpsilonInsLoss{Float64}(Float64(ε))
 
-function value(loss::L2EpsilonInsLoss{T1}, difference::T2) where {T1,T2<:Number}
+function (loss::L2EpsilonInsLoss{T1})(difference::T2) where {T1,T2<:Number}
     T = promote_type(T1,T2)
     abs2(max(zero(T), abs(difference) - loss.ε))
 end
@@ -399,7 +399,7 @@ L(r) = - \ln \frac{4 e^r}{(1 + e^r)^2}
 """
 struct LogitDistLoss <: DistanceLoss end
 
-function value(loss::LogitDistLoss, difference::Number)
+function (loss::LogitDistLoss)(difference::Number)
     er = exp(difference)
     T = typeof(er)
     -log(convert(T,4)) - difference + 2log(one(T) + er)
@@ -458,7 +458,7 @@ struct QuantileLoss{T <: AbstractFloat} <: DistanceLoss
     τ::T
 end
 
-function value(loss::QuantileLoss{T1}, diff::T2) where {T1, T2 <: Number}
+function (loss::QuantileLoss{T1})(diff::T2) where {T1, T2 <: Number}
     T = promote_type(T1, T2)
     diff * (convert(T,diff > 0) - loss.τ)
 end
@@ -512,7 +512,7 @@ struct LogCoshLoss <: DistanceLoss end
 _softplus(x::T) where T<:Number = x > zero(T) ? x + log1p(exp(-x)) : log1p(exp(x))
 _log_cosh(x::T) where T<:Number = x + _softplus(-2x) - log(convert(T, 2))
 
-function value(loss::LogCoshLoss, diff::T) where {T <: Number}
+function (loss::LogCoshLoss)(diff::T) where {T <: Number}
     _log_cosh(diff)
 end

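Following the pattern above, a new distance-based loss only needs the unary functor method plus `deriv`/`deriv2`; the binary `(output, target)` call then comes from the fallback in `src/losses.jl`. A hypothetical example (`MyCubicLoss` is not part of the package):

```julia
using LossFunctions

struct MyCubicLoss <: DistanceLoss end

# unary evaluation on the residual r = output - target
(loss::MyCubicLoss)(r::Number) = abs(r)^3
LossFunctions.deriv(loss::MyCubicLoss, r::Number) = 3 * abs2(r) * sign(r)
LossFunctions.deriv2(loss::MyCubicLoss, r::Number) = 6 * abs(r)

MyCubicLoss()(2.0, 0.5)   # == abs(2.0 - 0.5)^3 via the DistanceLoss fallback
```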