@@ -35,7 +35,7 @@ include("losses/weighted.jl")
35
35
Return sum of `loss` values over the iterables `outputs` and `targets`.
36
36
"""
37
37
function sum (loss:: SupervisedLoss , outputs, targets)
38
- sum (loss (ŷ, y) for (ŷ, y) in zip (outputs, targets))
38
+ sum (i -> loss (outputs[i], targets[i]), eachindex (outputs, targets))
39
39
end
40
40
41
41
"""
@@ -46,7 +46,7 @@ The `weights` determine the importance of each observation. The option
46
46
`normalize` divides the result by the sum of the weights.
47
47
"""
48
48
function sum (loss:: SupervisedLoss , outputs, targets, weights; normalize= true )
49
- s = sum (w * loss (ŷ, y) for (ŷ, y, w) in zip (outputs, targets, weights))
49
+ s = sum (i -> weights[i] * loss (outputs[i], targets[i]), eachindex (outputs, targets, weights))
50
50
n = normalize ? sum (weights) : one (first (weights))
51
51
s / n
52
52
end
57
57
Return mean of `loss` values over the iterables `outputs` and `targets`.
58
58
"""
59
59
function mean (loss:: SupervisedLoss , outputs, targets)
60
- mean (loss (ŷ, y) for (ŷ, y) in zip (outputs, targets))
60
+ mean (i -> loss (outputs[i], targets[i]), eachindex (outputs, targets))
61
61
end
62
62
63
63
"""
@@ -68,7 +68,7 @@ The `weights` determine the importance of each observation. The option
68
68
`normalize` divides the result by the sum of the weights.
69
69
"""
70
70
function mean (loss:: SupervisedLoss , outputs, targets, weights; normalize= true )
71
- m = mean (w * loss (ŷ, y) for (ŷ, y, w) in zip (outputs, targets, weights))
71
+ m = mean (i -> weights[i] * loss (outputs[i], targets[i]), eachindex (outputs, targets, weights))
72
72
n = normalize ? sum (weights) : one (first (weights))
73
73
m / n
74
74
end
0 commit comments