Skip to content

Commit 1fff2f9

Browse files
authored
Move CategoricalArrays.jl into extension (#168)
* Move CategoricalArrays.jl into extension * Add [extras] section for backwards compatibility
1 parent cfcdc6a commit 1fff2f9

File tree

5 files changed

+39
-9
lines changed

5 files changed

+39
-9
lines changed

Project.toml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
name = "LossFunctions"
22
uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
3-
version = "0.10.0"
3+
version = "0.10.1"
44

55
[deps]
6-
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
76
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
7+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
88
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
99

10+
[weakdeps]
11+
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
12+
13+
[extensions]
14+
LossFunctionsCategoricalArraysExt = "CategoricalArrays"
15+
1016
[compat]
1117
CategoricalArrays = "0.10"
18+
Requires = "1"
1219
julia = "1.6"
20+
21+
[extras]
22+
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
module LossFunctionsCategoricalArraysExt
2+
3+
if isdefined(Base, :get_extension)
4+
import LossFunctions: MisclassLoss, deriv, deriv2
5+
import CategoricalArrays: CategoricalValue
6+
else
7+
import ..LossFunctions: MisclassLoss, deriv, deriv2
8+
import ..CategoricalArrays: CategoricalValue
9+
end
10+
11+
# type alias to make code more readable
12+
const Scalar = Union{Number,CategoricalValue}
13+
14+
(loss::MisclassLoss)(output::Scalar, target::Scalar) = loss(target == output)
15+
deriv(loss::MisclassLoss, output::Scalar, target::Scalar) = deriv(loss, target == output)
16+
deriv2(loss::MisclassLoss, output::Scalar, target::Scalar) = deriv2(loss, target == output)
17+
18+
end

src/LossFunctions.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
module LossFunctions
22

33
using Markdown
4-
using CategoricalArrays: CategoricalValue
54

65
import Base: sum
76
import Statistics: mean
7+
import Requires: @init, @require
88

99
# trait functions
1010
include("traits.jl")
@@ -15,6 +15,11 @@ include("losses.jl")
1515
# IO methods
1616
include("io.jl")
1717

18+
# Extensions
19+
if !isdefined(Base, :get_extension)
20+
@init @require CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" include("../ext/LossFunctionsCategoricalArraysExt.jl")
21+
end
22+
1823
export
1924
# trait functions
2025
Loss,

src/losses.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
# type alias to make code more readable
2-
Scalar = Union{Number,CategoricalValue}
3-
41
# fallback to unary evaluation
52
(loss::DistanceLoss)(output::Number, target::Number) = loss(output - target)
63
deriv(loss::DistanceLoss, output::Number, target::Number) = deriv(loss, output - target)

src/losses/other.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ MisclassLoss() = MisclassLoss{Float64}()
1717
deriv(::MisclassLoss{R}, agreement::Bool) where {R} = zero(R)
1818
deriv2(::MisclassLoss{R}, agreement::Bool) where {R} = zero(R)
1919

20-
(loss::MisclassLoss)(output::Scalar, target::Scalar) = loss(target == output)
21-
deriv(loss::MisclassLoss, output::Scalar, target::Scalar) = deriv(loss, target == output)
22-
deriv2(loss::MisclassLoss, output::Scalar, target::Scalar) = deriv2(loss, target == output)
20+
(loss::MisclassLoss)(output::Number, target::Number) = loss(target == output)
21+
deriv(loss::MisclassLoss, output::Number, target::Number) = deriv(loss, target == output)
22+
deriv2(loss::MisclassLoss, output::Number, target::Number) = deriv2(loss, target == output)
2323

2424
isminimizable(::MisclassLoss) = false
2525
isdifferentiable(::MisclassLoss) = false

0 commit comments

Comments
 (0)