This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit f4a8342

gradient works for ChebyshevTransform, but temporarily turn off the rank-4 corrections
1 parent c988399 commit f4a8342

File tree: 2 files changed (+36 −8 lines)


src/Transform/chebyshev_transform.jl

Lines changed: 34 additions & 6 deletions

@@ -7,17 +7,45 @@ end
 Base.ndims(::ChebyshevTransform{N}) where {N} = N
 
 function transform(t::ChebyshevTransform{N}, 𝐱::AbstractArray) where {N}
-    return FFTW.r2r(𝐱, FFTW.REDFT00, 1:N) # [size(x)..., in_chs, batch]
+    return FFTW.r2r(𝐱, FFTW.REDFT10, 1:N) # [size(x)..., in_chs, batch]
 end
 
 function truncate_modes(t::ChebyshevTransform, 𝐱̂::AbstractArray)
     return view(𝐱̂, map(d->1:d, t.modes)..., :, :) # [t.modes..., in_chs, batch]
 end
 
 function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray) where {N}
-    return FFTW.r2r(
-        𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1))),
-        FFTW.REDFT00,
-        1:N,
-    ) # [size(x)..., in_chs, batch]
+    normalized_𝐱̂ = 𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1)))
+    return FFTW.r2r(normalized_𝐱̂, FFTW.REDFT01, 1:N) # [size(x)..., in_chs, batch]
+end
+
+function ChainRulesCore.rrule(::typeof(FFTW.r2r), x::AbstractArray, kind, dims)
+    y = FFTW.r2r(x, kind, dims)
+    (M,) = size(x)[dims]
+    r2r_pullback(Δ) = (NoTangent(), ∇r2r(unthunk(Δ), kind, dims, M), NoTangent(), NoTangent())
+    return y, r2r_pullback
+end
+
+function ∇r2r(Δ::AbstractArray, kind, dims, M)
+    # derivative of r2r turns out to be r2r + a rank-4 correction
+    Δx = FFTW.r2r(Δ, kind, dims)
+
+    # a1 = fill!(similar(A, M), one(T))
+    # CUDA.@allowscalar a1[1] = a1[end] = zero(T)
+
+    # a2 = fill!(similar(A, M), one(T))
+    # a2[1:2:end] .= -one(T)
+    # CUDA.@allowscalar a2[1] = a2[end] = zero(T)
+
+    # e1 = fill!(similar(A, M), zero(T))
+    # CUDA.@allowscalar e1[1] = one(T)
+
+    # eN = fill!(similar(A, M), zero(T))
+    # CUDA.@allowscalar eN[end] = one(T)
+
+    # @tullio Δx[s, i, b] +=
+    #     a1[i] * e1[k] * Δ[s, k, b] - a2[i] * eN[k] * Δ[s, k, b]
+    # @tullio Δx[s, i, b] +=
+    #     eN[i] * a2[k] * Δ[s, k, b] - e1[i] * a1[k] * Δ[s, k, b]
+    return Δx
 end
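
As background on the REDFT00 → REDFT10/REDFT01 change above (my note, not part of the commit): FFTW's r2r kinds are unnormalized, and a REDFT10 (DCT-II) followed by a REDFT01 (DCT-III) returns the input scaled by 2n along each transformed dimension, which is the factor an inverse has to divide out. A minimal roundtrip sketch:

using FFTW

x = rand(8, 8)
x̂ = FFTW.r2r(x, FFTW.REDFT10, 1:2)        # forward DCT-II along both dimensions
x_back = FFTW.r2r(x̂, FFTW.REDFT01, 1:2)   # unnormalized DCT-III "inverse"
x_back ./= prod(2 .* size(x))              # divide out the 2n factor per dimension
@assert x_back ≈ x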

test/Transform/chebyshev_transform.jl

Lines changed: 2 additions & 2 deletions

@@ -10,6 +10,6 @@
     @test size(truncate_modes(t, transform(t, 𝐱))) == (3, 4, 5, ch, batch)
     @test size(inverse(t, truncate_modes(t, transform(t, 𝐱)))) == (3, 4, 5, ch, batch)
 
-    @test_broken g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)))), 𝐱)
-    @test_broken size(g[1]) == (30, 40, 50, ch, batch)
+    g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)))), 𝐱)
+    @test size(g[1]) == (30, 40, 50, ch, batch)
 end
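
For reference, an end-to-end sketch of what the un-broken test exercises. The ChebyshevTransform construction sits above this hunk, so the constructor call, the channel/batch sizes, and the use of Zygote for gradient below are my assumptions, not part of the diff:

using Zygote                                 # assumed provider of `gradient` in the test suite

t = ChebyshevTransform((3, 4, 5))            # assumed constructor keeping modes (3, 4, 5)
𝐱 = rand(Float32, 30, 40, 50, 6, 7)          # [spatial dims..., channels, batch] with ch = 6, batch = 7
y = inverse(t, truncate_modes(t, transform(t, 𝐱)))   # size (3, 4, 5, 6, 7), as in the test
g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)))), 𝐱)
size(g[1]) == size(𝐱)                        # the gradient matches the input shape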
