Skip to content

Commit e3f7e13

Browse files
authored
Merge pull request #161 from omlins/hide
Do minor hide communication improvement
2 parents 26eacdc + 1b22a7f commit e3f7e13

File tree

4 files changed

+91
-31
lines changed

4 files changed

+91
-31
lines changed

src/ParallelKernel/hide_communication.jl

Lines changed: 81 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -147,17 +147,35 @@ function hide_communication_gpu(ranges_outer::Union{Symbol,Expr}, ranges_inner::
147147
push!(compcalls_outer, :(@parallel_async $ranges_outer[i] stream=ParallelStencil.ParallelKernel.@get_priority_stream(i) $(kwargs...) $compkernelcall)) #NOTE: it cannot directly go to ParallelStencil.ParallelKernel.@parallel_async as else it cannot honour ParallelStencil args as memopt (fixing it to ParallelStencil is also not possible as it assumes, else the ParalellKernel hide_communication unit tests fail).
148148
push!(compcalls_inner, :(@parallel_async $ranges_inner[i] stream=ParallelStencil.ParallelKernel.@get_stream(i) $(kwargs...) $compkernelcall)) #NOTE: ...
149149
end
150-
bc_and_commcalls = process_bc_and_commcalls(bc_and_commcalls)
151-
quote
152-
for i in 1:length($ranges_outer)
153-
$(compcalls_outer...)
150+
bc_and_commcalls = flatten(process_bc_and_commcalls(bc_and_commcalls))
151+
if comm_is_splitted(bc_and_commcalls)
152+
bc_and_commcalls_z, bc_and_commcalls_xy = split_bc_and_commcalls(bc_and_commcalls)
153+
quote
154+
for i in 1:length($ranges_outer)
155+
$(compcalls_outer...)
156+
end
157+
for i in 2:3 ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_priority_stream(i)); end # NOTE: synchronize the streams of the z-boundary computations (assumed to be stream 2 and 3 - to be in agreement with get_ranges_outer)
158+
$bc_and_commcalls_z
159+
for i in 1:length($ranges_inner)
160+
$(compcalls_inner...)
161+
end
162+
for i in 1:1 ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_priority_stream(i)); end
163+
for i in 4:length($ranges_outer) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_priority_stream(i)); end
164+
$bc_and_commcalls_xy
165+
for i in 1:length($ranges_inner) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_stream(i)); end
154166
end
155-
for i in 1:length($ranges_inner)
156-
$(compcalls_inner...)
167+
else
168+
quote
169+
for i in 1:length($ranges_outer)
170+
$(compcalls_outer...)
171+
end
172+
for i in 1:length($ranges_inner)
173+
$(compcalls_inner...)
174+
end
175+
for i in 1:length($ranges_outer) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_priority_stream(i)); end
176+
$bc_and_commcalls
177+
for i in 1:length($ranges_inner) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_stream(i)); end
157178
end
158-
for i in 1:length($ranges_outer) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_priority_stream(i)); end
159-
$bc_and_commcalls
160-
for i in 1:length($ranges_inner) ParallelStencil.ParallelKernel.@synchronize(ParallelStencil.ParallelKernel.@get_stream(i)); end
161179
end
162180
end
163181

@@ -169,6 +187,10 @@ end
169187
function hide_communication_gpu(boundary_width::Union{Integer,Symbol,Expr}, block::Expr; computation_calls::Integer=1)
170188
if (computation_calls < 1) @KeywordArgumentError("Invalid keyword argument in @hide_communication: computation_calls must be >= 1.") end
171189
compcalls, bc_and_commcalls = extract_calls(block, computation_calls)
190+
191+
USE_EXPERIMENTAL = false
192+
if (USE_EXPERIMENTAL) bc_and_commcalls = flatten(split_commcalls(bc_and_commcalls)) end
193+
172194
compranges = []
173195
for i in 1:length(compcalls)
174196
parallel_args = extract_args(compcalls[i], Symbol("@parallel"))
@@ -226,6 +248,41 @@ function process_bc_and_commcalls(block::Expr)
226248
end
227249
end
228250

251+
function split_commcalls(block::Expr)
252+
return postwalk(block) do x
253+
if !(!@capture(x, f_(args__; kwargs__)) && @capture(x, f_(args__)) && f == :update_halo!) return x; end
254+
return :(update_halo!($(args...); dims=(3,)); update_halo!($(args...); dims=(1,2)))
255+
end
256+
end
257+
258+
function comm_is_splitted(block::Expr)
259+
if !is_block(block) return false; end
260+
statements = block.args
261+
has_comm_z = false
262+
has_comm_xy = false
263+
for statement in statements
264+
if @capture(statement, f_(args__; kwarg_))
265+
if !has_comm_z && @capture(kwarg, dims=(3,)) has_comm_z = true
266+
elseif has_comm_z && @capture(kwarg, dims=(1,2)) has_comm_xy = true
267+
end
268+
end
269+
end
270+
return has_comm_z && has_comm_xy
271+
end
272+
273+
function split_bc_and_commcalls(block::Expr)
274+
if !is_block(block) @ModuleInternalError("expression is not a block; a block with at least two statements for communication is expected (obtained: $block)") end
275+
statements = block.args
276+
comm_z_pos = -1
277+
for i in length(statements):-1:1
278+
if (@capture(statements[i], f_(args__; kwarg_)) && @capture(kwarg, dims=(3,))) comm_z_pos = i; break; end
279+
end
280+
if (comm_z_pos < 1) @ModuleInternalError("no communication statement with dims=(3,) found in the block.") end
281+
bc_and_commcalls_z = quote $(statements[1:comm_z_pos]...) end
282+
bc_and_commcalls_xy = quote $(statements[comm_z_pos+1:end]...) end
283+
return bc_and_commcalls_z, bc_and_commcalls_xy
284+
end
285+
229286

230287
## FUNCTIONS TO GET INNER AND OUTER RANGES AND TO PROMOTE BOUNDARY_WIDTH TO 3D
231288

@@ -247,22 +304,25 @@ function get_ranges_outer(boundary_width, ranges::RANGES_TYPE...)
247304
ms = length.(ranges)
248305
bw = boundary_width
249306
if ms[3] > 1 # 3D
250-
ranges_outer = ((1:ms[1], 1:ms[2], 1:bw[3]),
251-
(1:ms[1], 1:ms[2], ms[3]-bw[3]+1:ms[3]),
252-
(1:ms[1], 1:bw[2], bw[3]+1:ms[3]-bw[3]),
253-
(1:ms[1], ms[2]-bw[2]+1:ms[2], bw[3]+1:ms[3]-bw[3]),
254-
(1:bw[1], bw[2]+1:ms[2]-bw[2], bw[3]+1:ms[3]-bw[3]),
255-
(ms[1]-bw[1]+1:ms[1], bw[2]+1:ms[2]-bw[2], bw[3]+1:ms[3]-bw[3]),
307+
ranges_outer = (
308+
(1:bw[1], bw[2]+1:ms[2]-bw[2], bw[3]+1:ms[3]-bw[3]), # 5
309+
(1:ms[1], 1:ms[2], 1:bw[3]), # 1
310+
(1:ms[1], 1:ms[2], ms[3]-bw[3]+1:ms[3]), # 2
311+
(ms[1]-bw[1]+1:ms[1], bw[2]+1:ms[2]-bw[2], bw[3]+1:ms[3]-bw[3]), # 6
312+
(1:ms[1], 1:bw[2], bw[3]+1:ms[3]-bw[3]), # 3
313+
(1:ms[1], ms[2]-bw[2]+1:ms[2], bw[3]+1:ms[3]-bw[3]), # 4
256314
)
257315
elseif ms[2] > 1 # 2D
258-
ranges_outer = ((1:ms[1], 1:bw[2], 1:1),
259-
(1:ms[1], ms[2]-bw[2]+1:ms[2], 1:1),
260-
(1:bw[1], bw[2]+1:ms[2]-bw[2], 1:1),
261-
(ms[1]-bw[1]+1:ms[1], bw[2]+1:ms[2]-bw[2], 1:1),
316+
ranges_outer = (
317+
(ms[1]-bw[1]+1:ms[1], bw[2]+1:ms[2]-bw[2], 1:1), # 4
318+
(1:ms[1], 1:bw[2], 1:1), # 1
319+
(1:ms[1], ms[2]-bw[2]+1:ms[2], 1:1), # 2
320+
(1:bw[1], bw[2]+1:ms[2]-bw[2], 1:1), # 3
262321
)
263322
elseif ms[1] > 1 # 1D
264-
ranges_outer = ((1:bw[1], 1:1, 1:1),
265-
(ms[1]-bw[1]+1:ms[1], 1:1, 1:1),
323+
ranges_outer = (
324+
(ms[1]-bw[1]+1:ms[1], 1:1, 1:1), # 2
325+
(1:bw[1], 1:1, 1:1), # 1
266326
)
267327
else
268328
@ModuleInternalError("invalid argument 'ranges'.")

src/ParallelKernel/parallel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,8 @@ end
302302

303303
## @SYNCHRONIZE FUNCTIONS
304304

305-
synchronize_cuda(args::Union{Symbol,Expr}...) = :(CUDA.synchronize($(args...)))
306-
synchronize_amdgpu(args::Union{Symbol,Expr}...) = :(AMDGPU.synchronize($(args...)))
305+
synchronize_cuda(args::Union{Symbol,Expr}...) = :(CUDA.synchronize($(args...); blocking=true))
306+
synchronize_amdgpu(args::Union{Symbol,Expr}...) = :(AMDGPU.synchronize($(args...); blocking=true))
307307
synchronize_threads(args::Union{Symbol,Expr}...) = :(begin end)
308308
synchronize_polyester(args::Union{Symbol,Expr}...) = :(begin end)
309309

test/ParallelKernel/test_parallel.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import Enzyme
3434
@static if $package == $PKG_CUDA
3535
call = @prettystring(1, @parallel f(A))
3636
@test occursin("CUDA.@cuda blocks = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32) stream = CUDA.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call)
37-
@test occursin("CUDA.synchronize(CUDA.stream())", call)
37+
@test occursin("CUDA.synchronize(CUDA.stream(); blocking = true)", call)
3838
call = @prettystring(1, @parallel ranges f(A))
3939
@test occursin("CUDA.@cuda blocks = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32) stream = CUDA.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call)
4040
call = @prettystring(1, @parallel nblocks nthreads f(A))
@@ -46,7 +46,7 @@ import Enzyme
4646
elseif $package == $PKG_AMDGPU
4747
call = @prettystring(1, @parallel f(A))
4848
@test occursin("AMDGPU.@roc gridsize = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 64)) groupsize = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 64) stream = AMDGPU.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call)
49-
@test occursin("AMDGPU.synchronize(AMDGPU.stream())", call)
49+
@test occursin("AMDGPU.synchronize(AMDGPU.stream(); blocking = true)", call)
5050
call = @prettystring(1, @parallel ranges f(A))
5151
@test occursin("AMDGPU.@roc gridsize = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 64)) groupsize = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 64) stream = AMDGPU.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call)
5252
call = @prettystring(1, @parallel nblocks nthreads f(A))
@@ -401,11 +401,11 @@ import Enzyme
401401
end;
402402
@testset "@synchronize" begin
403403
@static if $package == $PKG_CUDA
404-
@test @prettystring(1, @synchronize()) == "CUDA.synchronize()"
405-
@test @prettystring(1, @synchronize(mystream)) == "CUDA.synchronize(mystream)"
404+
@test @prettystring(1, @synchronize()) == "CUDA.synchronize(; blocking = true)"
405+
@test @prettystring(1, @synchronize(mystream)) == "CUDA.synchronize(mystream; blocking = true)"
406406
elseif $package == $PKG_AMDGPU
407-
@test @prettystring(1, @synchronize()) == "AMDGPU.synchronize()"
408-
@test @prettystring(1, @synchronize(mystream)) == "AMDGPU.synchronize(mystream)"
407+
@test @prettystring(1, @synchronize()) == "AMDGPU.synchronize(; blocking = true)"
408+
@test @prettystring(1, @synchronize(mystream)) == "AMDGPU.synchronize(mystream; blocking = true)"
409409
end;
410410
end;
411411
@reset_parallel_kernel()

test/test_parallel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import ParallelStencil.@gorgeousexpand
2929
@static if $package == $PKG_CUDA
3030
call = @prettystring(1, @parallel f(A))
3131
@test occursin("CUDA.@cuda blocks = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32) stream = CUDA.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call)
32-
@test occursin("CUDA.synchronize(CUDA.stream())", call)
32+
@test occursin("CUDA.synchronize(CUDA.stream(); blocking = true)", call)
3333
call = @prettystring(1, @parallel ranges f(A))
3434
@test occursin("CUDA.@cuda blocks = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32) stream = CUDA.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call)
3535
call = @prettystring(1, @parallel nblocks nthreads f(A))
@@ -45,7 +45,7 @@ import ParallelStencil.@gorgeousexpand
4545
elseif $package == $PKG_AMDGPU
4646
call = @prettystring(1, @parallel f(A))
4747
@test occursin("AMDGPU.@roc gridsize = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 64)) groupsize = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 64) stream = AMDGPU.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call)
48-
@test occursin("AMDGPU.synchronize(AMDGPU.stream())", call)
48+
@test occursin("AMDGPU.synchronize(AMDGPU.stream(); blocking = true)", call)
4949
call = @prettystring(1, @parallel ranges f(A))
5050
@test occursin("AMDGPU.@roc gridsize = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 64)) groupsize = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 64) stream = AMDGPU.stream() f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call)
5151
call = @prettystring(1, @parallel nblocks nthreads f(A))

0 commit comments

Comments
 (0)