1
1
export ChebyshevTransform
2
2
3
- struct ChebyshevTransform{N, S}<: AbstractTransform
3
+ struct ChebyshevTransform{N, S} <: AbstractTransform
4
4
modes:: NTuple{N, S} # N == ndims(x)
5
5
end
6
6
@@ -11,7 +11,7 @@ function transform(t::ChebyshevTransform{N}, 𝐱::AbstractArray) where {N}
11
11
end
12
12
13
13
function truncate_modes (t:: ChebyshevTransform , 𝐱̂:: AbstractArray )
14
- return view (𝐱̂, map (d-> 1 : d, t. modes)... , :, :) # [t.modes..., in_chs, batch]
14
+ return view (𝐱̂, map (d -> 1 : d, t. modes)... , :, :) # [t.modes..., in_chs, batch]
15
15
end
16
16
17
17
function inverse (t:: ChebyshevTransform{N} , 𝐱̂:: AbstractArray ) where {N}
21
21
22
22
function ChainRulesCore. rrule (:: typeof (FFTW. r2r), x:: AbstractArray , kind, dims)
23
23
y = FFTW. r2r (x, kind, dims)
24
- (M,) = size (x)[dims]
25
- r2r_pullback (Δ) = (NoTangent (), ∇r2r (unthunk (Δ), kind, dims, M), NoTangent (), NoTangent ())
24
+ r2r_pullback (Δ) = (NoTangent (), ∇r2r (unthunk (Δ), kind, dims), NoTangent (), NoTangent ())
26
25
return y, r2r_pullback
27
26
end
28
27
29
- function ∇r2r (Δ:: AbstractArray , kind, dims, M)
30
- # derivative of r2r turns out to be r2r + a rank 4 correction
28
+ function ∇r2r (Δ:: AbstractArray{T} , kind, dims) where {T}
29
+ # derivative of r2r turns out to be r2r
31
30
Δx = FFTW. r2r (Δ, kind, dims)
32
-
33
- # a1 = fill!(similar(A, M), one(T))
31
+
32
+ # rank 4 correction: needs @bischtob to elaborate the reason using this.
33
+ # (M,) = size(Δ)[dims]
34
+ # a1 = fill!(similar(Δ, M), one(T))
34
35
# CUDA.@allowscalar a1[1] = a1[end] = zero(T)
35
36
36
- # a2 = fill!(similar(A , M), one(T))
37
+ # a2 = fill!(similar(Δ , M), one(T))
37
38
# a2[1:2:end] .= -one(T)
38
39
# CUDA.@allowscalar a2[1] = a2[end] = zero(T)
39
40
40
- # e1 = fill!(similar(A , M), zero(T))
41
+ # e1 = fill!(similar(Δ , M), zero(T))
41
42
# CUDA.@allowscalar e1[1] = one(T)
42
43
43
- # eN = fill!(similar(A , M), zero(T))
44
+ # eN = fill!(similar(Δ , M), zero(T))
44
45
# CUDA.@allowscalar eN[end] = one(T)
45
46
46
- # @tullio Δx[s, i, b] +=
47
- # a1[i] * e1[k] * Δ[s, k, b] - a2[i] * eN[k] * Δ[s, k, b]
48
- # @tullio Δx[s, i, b] +=
49
- # eN[i] * a2[k] * Δ[s, k, b] - e1[i] * a1[k] * Δ[s, k, b]
47
+ # Δx .+= @. a1' * sum(e1' .* Δ, dims=2) - a2' * sum(eN' .* Δ, dims=2)
48
+ # Δx .+= @. eN' * sum(a2' .* Δ, dims=2) - e1' * sum(a1' .* Δ, dims=2)
50
49
return Δx
51
50
end
0 commit comments