Skip to content

Commit 25dbbbe

Browse files
allow providing fft plans
1 parent 66cf9d9 commit 25dbbbe

File tree

9 files changed

+199
-48
lines changed

9 files changed

+199
-48
lines changed

.github/workflows/UnitTest.yml

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ on:
77
branches:
88
- master
99
pull_request:
10-
schedule:
11-
- cron: '20 00 1 * *'
1210

1311
jobs:
1412
test:
@@ -20,23 +18,18 @@ jobs:
2018
os: [ubuntu-latest, windows-latest, macOS-latest]
2119

2220
steps:
23-
- uses: actions/checkout@v1.0.0
21+
- uses: actions/checkout@v4
2422
- name: "Set up Julia"
2523
uses: julia-actions/setup-julia@v1
2624
with:
2725
version: ${{ matrix.julia-version }}
2826

29-
- name: Cache artifacts
30-
uses: actions/cache@v1
31-
env:
32-
cache-name: cache-artifacts
33-
with:
34-
path: ~/.julia/artifacts
35-
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
36-
restore-keys: |
37-
${{ runner.os }}-test-${{ env.cache-name }}-
38-
${{ runner.os }}-test-
39-
${{ runner.os }}-
27+
- uses: julia-actions/cache@v1
28+
29+
- run: |
30+
import Pkg
31+
Pkg.add(Pkg.PackageSpec(url="https://github.com/HolyLab/RFFT.jl", rev="ib/add_copy"))
32+
shell: julia --project --color=yes {0}
4033
4134
- name: "Unit Test"
4235
uses: julia-actions/julia-runtest@v1

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ author = ["Tim Holy <[email protected]>", "Jan Weidner <[email protected]>"]
44
version = "0.7.8"
55

66
[deps]
7+
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
78
CatIndices = "aafaddc9-749c-510e-ac4f-586e18779b91"
89
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
910
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -15,6 +16,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1516
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
1617
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1718
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
19+
RFFT = "3bd9afcd-55df-531a-9b34-dc642dce7b95"
1820
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1921
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2022
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -31,6 +33,7 @@ ImageCore = "0.10"
3133
OffsetArrays = "1.9"
3234
PrecompileTools = "1"
3335
Reexport = "1.1"
36+
RFFT = "0.1.1"
3437
StaticArrays = "0.10, 0.11, 0.12, 1.0"
3538
Statistics = "1"
3639
TiledIteration = "0.2, 0.3, 0.4, 0.5"

demo.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
using ImageFiltering, FFTW, LinearAlgebra, Profile, Random
2+
# using ProfileView
3+
using ComputationalResources
4+
5+
FFTW.set_num_threads(parse(Int, get(ENV, "FFTW_NUM_THREADS", "1")))
6+
BLAS.set_num_threads(parse(Int, get(ENV, "BLAS_NUM_THREADS", string(Threads.nthreads() ÷ 2))))
7+
8+
function benchmark(mats)
9+
kernel = ImageFiltering.factorkernel(Kernel.LoG(1))
10+
Threads.@threads for mat in mats
11+
frame_filtered = deepcopy(mat[:, :, 1])
12+
r_cached = CPU1(ImageFiltering.planned_fft(frame_filtered, kernel))
13+
for i in axes(mat, 3)
14+
frame = @view mat[:, :, i]
15+
imfilter!(r_cached, frame_filtered, frame, kernel)
16+
end
17+
return
18+
end
19+
end
20+
21+
function test(mats)
22+
kernel = ImageFiltering.factorkernel(Kernel.LoG(1))
23+
for mat in mats
24+
f1 = deepcopy(mat[:, :, 1])
25+
r_cached = CPU1(ImageFiltering.planned_fft(f1, kernel))
26+
f2 = deepcopy(mat[:, :, 1])
27+
r_noncached = CPU1(Algorithm.FFT())
28+
for i in axes(mat, 3)
29+
frame = @view mat[:, :, i]
30+
@info "imfilter! noncached"
31+
imfilter!(r_noncached, f2, frame, kernel)
32+
@info "imfilter! cached"
33+
imfilter!(r_cached, f1, frame, kernel)
34+
@show f1[1:4] f2[1:4]
35+
f1 f2 || error("f1 !≈ f2")
36+
end
37+
return
38+
end
39+
end
40+
41+
function profile()
42+
Random.seed!(1)
43+
nmats = 10
44+
mats = [rand(Float32, rand(80:100), rand(80:100), rand(2000:3000)) for _ in 1:nmats]
45+
GC.gc(true)
46+
47+
# benchmark(mats)
48+
49+
# for _ in 1:3
50+
# @time "warm run of benchmark(mats)" benchmark(mats)
51+
# end
52+
53+
test(mats)
54+
55+
# Profile.clear()
56+
# @profile benchmark(mats)
57+
58+
# Profile.print(IOContext(stdout, :displaysize => (24, 200)); C=true, combine=true, mincount=100)
59+
# # ProfileView.view()
60+
# GC.gc(true)
61+
end
62+
63+
profile()
64+
65+
using ImageFiltering
66+
using ImageFiltering.RFFT
67+
68+
function mwe()
69+
a = rand(Float64, 10, 10)
70+
out1 = rfft(a)
71+
72+
buf = RFFT.RCpair{Float64}(undef, size(a))
73+
rfft_plan = RFFT.plan_rfft!(buf)
74+
copy!(buf, a)
75+
out2 = complex(rfft_plan(buf))
76+
77+
return out1 out2
78+
end
79+
mwe()

src/ImageFiltering.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
module ImageFiltering
22

33
using FFTW
4+
using RFFT
45
using ImageCore, FFTViews, OffsetArrays, StaticArrays, ComputationalResources, TiledIteration
56
# Where possible we avoid a direct dependency to reduce the number of [compat] bounds
67
# using FixedPointNumbers: Normed, N0f8 # reexported by ImageCore
78
using ImageCore.MappedArrays
89
using Statistics, LinearAlgebra
910
using Base: Indices, tail, fill_to_length, @pure, depwarn, @propagate_inbounds
11+
import Base: copy!
1012
using OffsetArrays: IdentityUnitRange # using the one in OffsetArrays makes this work with multiple Julia versions
1113
using SparseArrays # only needed to fix an ambiguity in borderarray
1214
using Reexport
@@ -30,7 +32,8 @@ export Kernel, KernelFactors,
3032
imgradients, padarray, centered, kernelfactors, reflect,
3133
freqkernel, spacekernel,
3234
findlocalminima, findlocalmaxima,
33-
blob_LoG, BlobLoG
35+
blob_LoG, BlobLoG,
36+
planned_fft
3437

3538
FixedColorant{T<:Normed} = Colorant{T}
3639
StaticOffsetArray{T,N,A<:StaticArray} = OffsetArray{T,N,A}
@@ -50,10 +53,16 @@ function Base.transpose(A::StaticOffsetArray{T,2}) where T
5053
end
5154

5255
module Algorithm
56+
import FFTW
5357
# deliberately don't export these, but it's expected that they
5458
# will be used as Algorithm.FFT(), etc.
5559
abstract type Alg end
56-
"Filter using the Fast Fourier Transform" struct FFT <: Alg end
60+
"Filter using the Fast Fourier Transform" struct FFT <: Alg
61+
plan1::Union{Function,Nothing}
62+
plan2::Union{Function,Nothing}
63+
plan3::Union{Function,Nothing}
64+
end
65+
FFT() = FFT(nothing, nothing, nothing)
5766
"Filter using a direct algorithm" struct FIR <: Alg end
5867
"Cache-efficient filtering using tiles" struct FIRTiled{N} <: Alg
5968
tilesize::Dims{N}

src/imfilter.jl

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ function _imfilter_fft!(r::AbstractCPU{FFT},
826826
for I in CartesianIndices(axes(kern))
827827
krn[I] = kern[I]
828828
end
829-
Af = filtfft(A, krn)
829+
Af = filtfft(A, krn, r.settings.plan1, r.settings.plan2, r.settings.plan3)
830830
if map(first, axes(out)) == map(first, axes(Af))
831831
R = CartesianIndices(axes(out))
832832
copyto!(out, R, Af, R)
@@ -837,13 +837,61 @@ function _imfilter_fft!(r::AbstractCPU{FFT},
837837
src = view(FFTView(Af), axes(dest)...)
838838
copyto!(dest, src)
839839
end
840-
out
840+
return out
841+
end
842+
843+
function buffered_planned_rfft(a::AbstractArray{T}) where {T}
844+
buf = RFFT.RCpair{T}(undef, size(a))
845+
plan = RFFT.plan_rfft!(buf; flags=FFTW.MEASURE)
846+
return function (arr::AbstractArray{T}) where {T}
847+
copy!(buf, OffsetArrays.no_offset_view(arr))
848+
return plan(buf)
849+
end
850+
end
851+
function buffered_planned_irfft(a::AbstractArray{T}) where {T}
852+
buf = RFFT.RCpair{T}(undef, size(a))
853+
plan = RFFT.plan_irfft!(buf; flags=FFTW.MEASURE)
854+
return function (arr::AbstractArray{T}) where {T}
855+
copy!(buf, OffsetArrays.no_offset_view(arr))
856+
return plan(buf)
857+
end
841858
end
842859

860+
function planned_fft(A::AbstractArray{T,N},
861+
kernel::ProcessedKernel,
862+
border::BorderSpecAny=Pad(:replicate)
863+
) where {T<:AbstractFloat,N}
864+
bord = border(kernel, A, Algorithm.FFT())
865+
_A = padarray(T, A, bord)
866+
bfp1 = buffered_planned_rfft(_A)
867+
kern = samedims(_A, kernelconv(kernel...))
868+
krn = FFTView(zeros(eltype(kern), map(length, axes(_A))))
869+
bfp2 = buffered_planned_rfft(krn)
870+
bfp3 = buffered_planned_irfft(_A)
871+
return Algorithm.FFT(bfp1, bfp2, bfp3)
872+
end
873+
planned_fft(A::AbstractArray, kernel, border::AbstractString) = planned_fft(A, kernel, borderinstance(border))
874+
planned_fft(A::AbstractArray, kernel::Union{ArrayLike,Laplacian}, border::BorderSpecAny) = planned_fft(A, factorkernel(kernel), border)
875+
876+
function filtfft(A, krn, planned_rfft1::Function, planned_rfft2::Function, planned_irfft::Function)
877+
B = complex(planned_rfft1(A))
878+
B .*= conj!(complex(planned_rfft2(krn)))
879+
return real(planned_irfft(complex(B)))
880+
end
881+
# TODO: this does not work. See TODO below
882+
function filtfft(A::AbstractArray{C}, krn, planned_rfft1::Function, planned_rfft2::Function, planned_irfft::Function) where {C<:Colorant}
883+
Av, dims = channelview_dims(A)
884+
kernrs = kreshape(C, krn)
885+
B = complex(planned_rfft1(Av, dims)) # TODO: dims is not supported by planned_rfft1
886+
B .*= conj!(complex(planned_rfft2(kernrs)))
887+
Avf = real(planned_irfft(complex(B)))
888+
return colorview(base_colorant_type(C){eltype(Avf)}, Avf)
889+
end
890+
filtfft(A, krn, ::Nothing, ::Nothing, ::Nothing) = filtfft(A, krn)
843891
function filtfft(A, krn)
844892
B = rfft(A)
845893
B .*= conj!(rfft(krn))
846-
irfft(B, length(axes(A, 1)))
894+
return irfft(B, length(axes(A, 1)))
847895
end
848896
function filtfft(A::AbstractArray{C}, krn) where {C<:Colorant}
849897
Av, dims = channelview_dims(A)

test/2d.jl

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@ using ImageFiltering: borderinstance
3636
end
3737
end
3838

39+
function supported_algs(img, kernel, border)
40+
if eltype(img) isa AbstractFloat
41+
(Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT(), planned_fft(img, kernel, border))
42+
else
43+
# TODO: extend planned_fft to support other types
44+
(Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
45+
end
46+
end
47+
3948
@testset "FIR/FFT" begin
4049
f32type(img) = f32type(eltype(img))
4150
f32type(::Type{C}) where {C<:Colorant} = base_colorant_type(C){Float32}
@@ -50,6 +59,7 @@ end
5059
# Dense inseparable kernel
5160
kern = [0.1 0.2; 0.4 0.5]
5261
kernel = OffsetArray(kern, -1:0, 1:2)
62+
border = Inner()
5363
for img in (imgf, imgi, imgg, imgc)
5464
targetimg = zeros(typeof(img[1]*kern[1]), size(img))
5565
targetimg[3:4,2:3] = rot180(kern) .* img[3,4]
@@ -66,7 +76,7 @@ end
6676
@test @inferred(imfilter(f32type(img), img, kernel, border)) float32.(targetimg)
6777
fill!(ret, zero(eltype(ret)))
6878
@test @inferred(imfilter!(ret, img, kernel, border)) targetimg
69-
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
79+
for alg in supported_algs(img, kernel, border)
7080
@test @inferred(imfilter(img, kernel, border, alg)) targetimg
7181
@test @inferred(imfilter(img, (kernel,), border, alg)) targetimg
7282
@test @inferred(imfilter(f32type(img), img, kernel, border, alg)) float32.(targetimg)
@@ -76,12 +86,12 @@ end
7686
@test_throws MethodError imfilter!(CPU1(Algorithm.FIR()), ret, img, kernel, border, Algorithm.FFT())
7787
end
7888
targetimg_inner = OffsetArray(targetimg[2:end, 1:end-2], 2:5, 1:5)
79-
@test @inferred(imfilter(img, kernel, Inner())) targetimg_inner
80-
@test @inferred(imfilter(f32type(img), img, kernel, Inner())) float32.(targetimg_inner)
81-
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
82-
@test @inferred(imfilter(img, kernel, Inner(), alg)) targetimg_inner
83-
@test @inferred(imfilter(f32type(img), img, kernel, Inner(), alg)) float32.(targetimg_inner)
84-
@test @inferred(imfilter(CPU1(alg), img, kernel, Inner())) targetimg_inner
89+
@test @inferred(imfilter(img, kernel, border)) targetimg_inner
90+
@test @inferred(imfilter(f32type(img), img, kernel, border)) float32.(targetimg_inner)
91+
for alg in supported_algs(img, kernel, border)
92+
@test @inferred(imfilter(img, kernel, border, alg)) targetimg_inner
93+
@test @inferred(imfilter(f32type(img), img, kernel, border, alg)) float32.(targetimg_inner)
94+
@test @inferred(imfilter(CPU1(alg), img, kernel, border)) targetimg_inner
8595
end
8696
end
8797
# Factored kernel
@@ -96,7 +106,7 @@ end
96106
for border in ("replicate", "circular", "symmetric", "reflect", Fill(zero(eltype(img))))
97107
@test @inferred(imfilter(img, kernel, border)) targetimg
98108
@test @inferred(imfilter(f32type(img), img, kernel, border)) float32.(targetimg)
99-
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
109+
for alg in supported_algs(img, kernel, border)
100110
@test @inferred(imfilter(img, kernel, border, alg)) targetimg
101111
@test @inferred(imfilter(f32type(img), img, kernel, border, alg)) float32.(targetimg)
102112
end
@@ -106,7 +116,7 @@ end
106116
targetimg_inner = OffsetArray(targetimg[2:end, 1:end-2], 2:5, 1:5)
107117
@test @inferred(imfilter(img, kernel, Inner())) targetimg_inner
108118
@test @inferred(imfilter(f32type(img), img, kernel, Inner())) float32.(targetimg_inner)
109-
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
119+
for alg in supported_algs(img, kernel, border)
110120
@test @inferred(imfilter(img, kernel, Inner(), alg)) targetimg_inner
111121
@test @inferred(imfilter(f32type(img), img, kernel, Inner(), alg)) float32.(targetimg_inner)
112122
end
@@ -122,7 +132,7 @@ end
122132
for border in ("replicate", "circular", "symmetric", "reflect", Fill(zero(eltype(img))))
123133
@test @inferred(imfilter(img, kernel, border)) targetimg
124134
@test @inferred(imfilter(f32type(img), img, kernel, border)) float32.(targetimg)
125-
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
135+
for alg in supported_algs(img, kernel, border)
126136
if alg == Algorithm.FFT() && eltype(img) == Int
127137
@test @inferred(imfilter(Float64, img, kernel, border, alg)) targetimg
128138
else
@@ -134,7 +144,7 @@ end
134144
targetimg_inner = OffsetArray(targetimg[2:end-1, 2:end-1], 2:4, 2:6)
135145
@test @inferred(imfilter(img, kernel, Inner())) targetimg_inner
136146
@test @inferred(imfilter(f32type(img), img, kernel, Inner())) float32.(targetimg_inner)
137-
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
147+
for alg in supported_algs(img, kernel, border)
138148
if alg == Algorithm.FFT() && eltype(img) == Int
139149
@test @inferred(imfilter(Float64, img, kernel, Inner(), alg)) targetimg_inner
140150
else
@@ -184,7 +194,7 @@ end
184194
targetimg = target1(img, kern, border)
185195
@test @inferred(imfilter(img, kernel, border)) targetimg
186196
@test @inferred(imfilter(f32type(img), img, kernel, border)) float32.(targetimg)
187-
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
197+
for alg in supported_algs(img, kernel, border)
188198
@test @inferred(imfilter(img, kernel, border, alg)) targetimg
189199
@test @inferred(imfilter(f32type(img), img, kernel, border, alg)) float32.(targetimg)
190200
end
@@ -195,7 +205,7 @@ end
195205
targetimg = zerona!(copy(targetimg0))
196206
@test @inferred(zerona!(imfilter(img, kernel, border))) targetimg
197207
@test @inferred(zerona!(imfilter(f32type(img), img, kernel, border))) float32.(targetimg)
198-
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
208+
for alg in supported_algs(img, kernel, border)
199209
@test @inferred(zerona!(imfilter(img, kernel, border, alg), nanflag)) targetimg
200210
@test @inferred(zerona!(imfilter(f32type(img), img, kernel, border, alg), nanflag)) float32.(targetimg)
201211
end
@@ -208,7 +218,7 @@ end
208218
targetimg = target1(img, kern, border)
209219
@test @inferred(imfilter(img, kernel, border)) targetimg
210220
@test @inferred(imfilter(f32type(img), img, kernel, border)) float32.(targetimg)
211-
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
221+
for alg in supported_algs(img, kernel, border)
212222
@test @inferred(imfilter(img, kernel, border, alg)) targetimg
213223
@test @inferred(imfilter(f32type(img), img, kernel, border, alg)) float32.(targetimg)
214224
end
@@ -219,7 +229,7 @@ end
219229
targetimg = zerona!(copy(targetimg0))
220230
@test @inferred(zerona!(imfilter(img, kernel, border))) targetimg
221231
@test @inferred(zerona!(imfilter(f32type(img), img, kernel, border))) float32.(targetimg)
222-
for alg in (Algorithm.FIR(), Algorithm.FIRTiled(), Algorithm.FFT())
232+
for alg in supported_algs(img, kernel, border)
223233
@test @inferred(zerona!(imfilter(img, kernel, border, alg), nanflag)) targetimg
224234
@test @inferred(zerona!(imfilter(f32type(img), img, kernel, border, alg), nanflag)) float32.(targetimg)
225235
end

0 commit comments

Comments
 (0)