Skip to content

Commit 0b3e74f

Browse files
authored
Merge pull request #173 from MilesCranmer/fix-sum-speeds
Improve aggregation speeds by using `eachindex` instead of `iterate`
2 parents 7318c58 + a8f7f46 commit 0b3e74f

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/losses.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ include("losses/weighted.jl")
3535
Return sum of `loss` values over the iterables `outputs` and `targets`.
3636
"""
3737
function sum(loss::SupervisedLoss, outputs, targets)
38-
sum(loss(ŷ, y) for (ŷ, y) in zip(outputs, targets))
38+
sum(i -> loss(outputs[i], targets[i]), eachindex(outputs, targets))
3939
end
4040

4141
"""
@@ -46,7 +46,7 @@ The `weights` determine the importance of each observation. The option
4646
`normalize` divides the result by the sum of the weights.
4747
"""
4848
function sum(loss::SupervisedLoss, outputs, targets, weights; normalize=true)
49-
s = sum(w * loss(ŷ, y) for (ŷ, y, w) in zip(outputs, targets, weights))
49+
s = sum(i -> weights[i] * loss(outputs[i], targets[i]), eachindex(outputs, targets, weights))
5050
n = normalize ? sum(weights) : one(first(weights))
5151
s / n
5252
end
@@ -57,7 +57,7 @@ end
5757
Return mean of `loss` values over the iterables `outputs` and `targets`.
5858
"""
5959
function mean(loss::SupervisedLoss, outputs, targets)
60-
mean(loss(ŷ, y) for (ŷ, y) in zip(outputs, targets))
60+
mean(i -> loss(outputs[i], targets[i]), eachindex(outputs, targets))
6161
end
6262

6363
"""
@@ -68,7 +68,7 @@ The `weights` determine the importance of each observation. The option
6868
`normalize` divides the result by the sum of the weights.
6969
"""
7070
function mean(loss::SupervisedLoss, outputs, targets, weights; normalize=true)
71-
m = mean(w * loss(ŷ, y) for (ŷ, y, w) in zip(outputs, targets, weights))
71+
m = mean(i -> weights[i] * loss(outputs[i], targets[i]), eachindex(outputs, targets, weights))
7272
n = normalize ? sum(weights) : one(first(weights))
7373
m / n
7474
end

0 commit comments

Comments
 (0)