Merged
31 changes: 20 additions & 11 deletions src/layers/basic.jl
@@ -172,31 +172,40 @@ function Base.show(io::IO, l::Dense)
 end

 """
-    Diagonal(α, β)
-    Diagonal(size::Integer...)
+    Diagonal(size::Integer...; bias=true, init=ones32)
+    Diagonal(scale::AbstractArray, [bias])

 Create an element-wise linear layer, which performs

-    y = α .* x .+ β
+    y = scale .* x .+ bias

-The learnable arrays are initialised `α = ones(Float32, size)` and
-`β = zeros(Float32, size)`.
+with no activation function.
+
+The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,
+with `init=ones32` by default. You may specify the function `init`,
+turn off trainable bias with `bias=false`, or provide the array(s) explicitly.

 Used by [`LayerNorm`](@ref).
 """
-struct Diagonal{T}
-  α::T
-  β::T
+struct Diagonal{A<:AbstractArray, B}
+  scale::A
+  bias::B
+  function Diagonal(W::M, bias = true) where M<:AbstractArray
+    b = create_bias(W, bias, size(W)...)
+    new{M, typeof(b)}(W, b)
+  end
 end

-Diagonal(sz::Integer...) = Diagonal(ones32(sz...), zeros32(sz...))
+Diagonal(sz::Integer...; bias = true, init = ones32) = Diagonal(init(sz...), bias)

 @functor Diagonal

-(a::Diagonal)(x) = a.α .* x .+ a.β
+(a::Diagonal)(x) = a.scale .* x .+ a.bias

 function Base.show(io::IO, l::Diagonal)
-  print(io, "Diagonal(", join(size(l.α), ", "), ")")
+  print(io, "Diagonal(", join(size(l.scale), ", "))
+  l.bias == false && print(io, "; bias=false")
+  print(io, ")")
 end

 """
7 changes: 4 additions & 3 deletions test/layers/basic.jl
@@ -91,16 +91,17 @@ import Flux: activations
@test length(Flux.Diagonal(10)(randn(10))) == 10
@test length(Flux.Diagonal(10)(1)) == 10
@test length(Flux.Diagonal(10)(randn(1))) == 10
@test length(Flux.Diagonal(10; bias = false)(randn(10))) == 10
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))

@test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
@test Flux.Diagonal(2)([1,2]) == [1,2]
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
@test Flux.Diagonal(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]
Member:

Why do these tests need bias=false?

Member (Author):

They don't really, I kept those just to check that bias=false doesn't trip anything.


@test Flux.Diagonal(2)(rand(2,3,4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2,3)(rand(2,3,4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2,3,4)(rand(2,3,4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2,3)(rand(2,1,4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2, 3, 4; bias = false)(rand(2,3,4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2, 3; bias = false)(rand(2,1,4)) |> size == (2, 3, 4)
end
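
Following up on the review exchange above: with `bias = false` the forward pass of a freshly initialised layer is unchanged (the scale starts at ones), but the layer carries no trainable offset. A short sketch of that difference, assuming `Flux.params` skips the literal `false` bias just like any other non-array field:

    using Flux

    d  = Flux.Diagonal(2)                 # trainable scale and bias
    d0 = Flux.Diagonal(2; bias = false)   # trainable scale only

    length(Flux.params(d))  == 2
    length(Flux.params(d0)) == 1          # `false` is not collected as a parameter

    d0([1 2; 3 4]) == [1 2; 3 4]          # fresh scale is all ones, so the input passes through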

@testset "Maxout" begin