Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 98192ab

Browse files
committed
reimplement derivative of r2r in an efficient way and discard @tullio
1 parent f4a8342 commit 98192ab

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed
Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
export ChebyshevTransform
22

3-
struct ChebyshevTransform{N, S}<:AbstractTransform
3+
struct ChebyshevTransform{N, S} <: AbstractTransform
44
modes::NTuple{N, S} # N == ndims(x)
55
end
66

@@ -11,7 +11,7 @@ function transform(t::ChebyshevTransform{N}, 𝐱::AbstractArray) where {N}
1111
end
1212

1313
function truncate_modes(t::ChebyshevTransform, 𝐱̂::AbstractArray)
14-
return view(𝐱̂, map(d->1:d, t.modes)..., :, :) # [t.modes..., in_chs, batch]
14+
return view(𝐱̂, map(d -> 1:d, t.modes)..., :, :) # [t.modes..., in_chs, batch]
1515
end
1616

1717
function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray) where {N}
@@ -21,31 +21,30 @@ end
2121

2222
function ChainRulesCore.rrule(::typeof(FFTW.r2r), x::AbstractArray, kind, dims)
2323
y = FFTW.r2r(x, kind, dims)
24-
(M,) = size(x)[dims]
25-
r2r_pullback(Δ) = (NoTangent(), ∇r2r(unthunk(Δ), kind, dims, M), NoTangent(), NoTangent())
24+
r2r_pullback(Δ) = (NoTangent(), ∇r2r(unthunk(Δ), kind, dims), NoTangent(), NoTangent())
2625
return y, r2r_pullback
2726
end
2827

29-
function ∇r2r(Δ::AbstractArray, kind, dims, M)
30-
# derivative of r2r turns out to be r2r + a rank 4 correction
28+
function ∇r2r(Δ::AbstractArray{T}, kind, dims) where {T}
29+
# derivative of r2r turns out to be r2r
3130
Δx = FFTW.r2r(Δ, kind, dims)
32-
33-
# a1 = fill!(similar(A, M), one(T))
31+
32+
# rank 4 correction: needs @bischtob to elaborate the reason using this.
33+
# (M,) = size(Δ)[dims]
34+
# a1 = fill!(similar(Δ, M), one(T))
3435
# CUDA.@allowscalar a1[1] = a1[end] = zero(T)
3536

36-
# a2 = fill!(similar(A, M), one(T))
37+
# a2 = fill!(similar(Δ, M), one(T))
3738
# a2[1:2:end] .= -one(T)
3839
# CUDA.@allowscalar a2[1] = a2[end] = zero(T)
3940

40-
# e1 = fill!(similar(A, M), zero(T))
41+
# e1 = fill!(similar(Δ, M), zero(T))
4142
# CUDA.@allowscalar e1[1] = one(T)
4243

43-
# eN = fill!(similar(A, M), zero(T))
44+
# eN = fill!(similar(Δ, M), zero(T))
4445
# CUDA.@allowscalar eN[end] = one(T)
4546

46-
# @tullio Δx[s, i, b] +=
47-
# a1[i] * e1[k] * Δ[s, k, b] - a2[i] * eN[k] * Δ[s, k, b]
48-
# @tullio Δx[s, i, b] +=
49-
# eN[i] * a2[k] * Δ[s, k, b] - e1[i] * a1[k] * Δ[s, k, b]
47+
# Δx .+= @. a1' * sum(e1' .* Δ, dims=2) - a2' * sum(eN' .* Δ, dims=2)
48+
# Δx .+= @. eN' * sum(a2' .* Δ, dims=2) - e1' * sum(a1' .* Δ, dims=2)
5049
return Δx
5150
end

0 commit comments

Comments
 (0)