Skip to content

Commit 1a944d8

Browse files
authored
Merge pull request #153 from omlins/amdparams
Add AMD-specific launch parameters
2 parents e1194ff + fc489a7 commit 1a944d8

File tree

6 files changed

+53
-40
lines changed

6 files changed

+53
-40
lines changed

src/ParallelKernel/parallel.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,9 @@ function parallel_call_gpu(nblocks::Union{Symbol,Expr}, nthreads::Union{Symbol,E
269269
end
270270

271271
function parallel_call_gpu(ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool, package::Symbol; stream::Union{Symbol,Expr}=default_stream(package), shmem::Union{Symbol,Expr,Nothing}=nothing, launch::Bool=true, configcall::Expr=kernelcall)
272+
nthreads_x_max = determine_nthreads_x_max(package)
272273
maxsize = :(length.(ParallelStencil.ParallelKernel.promote_ranges($ranges)))
273-
nthreads = :( ParallelStencil.ParallelKernel.compute_nthreads($maxsize) )
274+
nthreads = :( ParallelStencil.ParallelKernel.compute_nthreads($maxsize; nthreads_x_max=$nthreads_x_max) )
274275
nblocks = :( ParallelStencil.ParallelKernel.compute_nblocks($maxsize, $nthreads) )
275276
parallel_call_gpu(ranges, nblocks, nthreads, kernelcall, backend_kwargs_expr, async, package; stream=stream, shmem=shmem, launch=launch)
276277
end
@@ -522,9 +523,9 @@ function compute_ranges(maxsize)
522523
return (1:maxsize[1], 1:maxsize[2], 1:maxsize[3])
523524
end
524525

525-
function compute_nthreads(maxsize; nthreads_max=NTHREADS_MAX, flatdim=0) # This is a heuristic, which results in (32,8,1) threads, except if maxsize[1] < 32 or maxsize[2] < 8.
526+
function compute_nthreads(maxsize; nthreads_x_max=NTHREADS_X_MAX, nthreads_max=NTHREADS_MAX, flatdim=0) # This is a heuristic, which results in (32,8,1) threads, except if maxsize[1] < 32 or maxsize[2] < 8.
526527
maxsize = promote_maxsize(maxsize)
527-
nthreads_x = min(32, (flatdim==1) ? 1 : maxsize[1])
528+
nthreads_x = min(nthreads_x_max, (flatdim==1) ? 1 : maxsize[1])
528529
nthreads_y = min(ceil(Int,nthreads_max/nthreads_x), (flatdim==2) ? 1 : maxsize[2])
529530
nthreads_z = min(ceil(Int,nthreads_max/(nthreads_x*nthreads_y)), (flatdim==3) ? 1 : maxsize[3])
530531
return (nthreads_x, nthreads_y , nthreads_z)
@@ -536,6 +537,8 @@ function compute_nblocks(maxsize, nthreads)
536537
return ceil.(Int, maxsize./nthreads)
537538
end
538539

540+
determine_nthreads_x_max(package::Symbol) = (package == PKG_AMDGPU) ? NTHREADS_X_MAX_AMDGPU : NTHREADS_X_MAX
541+
539542

540543
## FUNCTIONS TO CREATE KERNEL LAUNCH AND SYNCHRONIZATION CALLS
541544

src/ParallelKernel/shared.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ const INT_CUDA = Int64 # NOTE: unsigned integers are not yet
2020
const INT_AMDGPU = Int64 # NOTE: ...
2121
const INT_POLYESTER = Int64 # NOTE: ...
2222
const INT_THREADS = Int64 # NOTE: ...
23+
const NTHREADS_X_MAX = 32
24+
const NTHREADS_X_MAX_AMDGPU = 64
2325
const NTHREADS_MAX = 256
2426
const INDICES = (gensym_world("ix", @__MODULE__), gensym_world("iy", @__MODULE__), gensym_world("iz", @__MODULE__))
2527
const RANGES_VARNAME = gensym_world("ranges", @__MODULE__)

src/parallel.jl

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ function parallel(source::LineNumberNode, caller::Module, args::Union{Symbol,Exp
162162
if (length(posargs) > 1) @ArgumentError("maximum one positional argument (ranges) is allowed in a @parallel memopt=true call.") end
163163
parallel_call_memopt(caller, posargs..., kernelarg, backend_kwargs_expr, async; kwargs...)
164164
else
165-
ParallelKernel.parallel(caller, posargs..., backend_kwargs_expr..., configcall_kwarg_expr, kernelarg; package=package)
165+
ParallelKernel.parallel(caller, posargs..., backend_kwargs_expr..., configcall_kwarg_expr, kernelarg; package=package, async=async)
166166
end
167167
end
168168
end
@@ -321,6 +321,9 @@ end
321321

322322
function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall)
323323
if haskey(backend_kwargs_expr, :shmem) @KeywordArgumentError("@parallel <kernelcall>: keyword `shmem` is not allowed when memopt=true is set.") end
324+
package = get_package(caller)
325+
nthreads_x_max = ParallelKernel.determine_nthreads_x_max(package)
326+
nthreads_max_memopt = determine_nthreads_max_memopt(package)
324327
configcall_kwarg_expr = :(configcall=$configcall)
325328
metadata_call = create_metadata_call(configcall)
326329
metadata_module = metadata_call
@@ -331,7 +334,7 @@ function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernel
331334
loopsize = :($(metadata_module).loopsize)
332335
loopsizes = :(($loopdim==3) ? (1, 1, $loopsize) : ($loopdim==2) ? (1, $loopsize, 1) : ($loopsize, 1, 1))
333336
maxsize = :(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges($ranges)), $loopsizes))
334-
nthreads = :( ParallelStencil.compute_nthreads_memopt($maxsize, $loopdim, $stencilranges) )
337+
nthreads = :( ParallelStencil.compute_nthreads_memopt($nthreads_x_max, $nthreads_max_memopt, $maxsize, $loopdim, $stencilranges) )
335338
nblocks = :( ParallelStencil.ParallelKernel.compute_nblocks($maxsize, $nthreads) )
336339
numbertype = get_numbertype(caller) # not :(eltype($(optvars)[1])) # TODO: see how to obtain number type properly for each array: the type of the call call arguments corresponding to the optimization variables should be checked
337340
dim1 = :(($loopdim==3) ? 1 : ($loopdim==2) ? 1 : 2) # TODO: to be determined if that is what is desired for loopdim 1 and 2.
@@ -344,11 +347,14 @@ function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernel
344347
end
345348

346349
function parallel_call_memopt(caller::Module, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall)
347-
metadata_call = create_metadata_call(configcall)
348-
metadata_module = metadata_call
349-
loopdim = :($(metadata_module).loopdim)
350-
is_parallel_kernel = :($(metadata_module).is_parallel_kernel)
351-
ranges = :( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...)))
350+
package = get_package(caller)
351+
nthreads_x_max = ParallelKernel.determine_nthreads_x_max(package)
352+
nthreads_max_memopt = determine_nthreads_max_memopt(package)
353+
metadata_call = create_metadata_call(configcall)
354+
metadata_module = metadata_call
355+
loopdim = :($(metadata_module).loopdim)
356+
is_parallel_kernel = :($(metadata_module).is_parallel_kernel)
357+
ranges = :( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($nthreads_x_max, $nthreads_max_memopt, $loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...)))
352358
parallel_call_memopt(caller, ranges, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall)
353359
end
354360

@@ -362,15 +368,16 @@ end
362368

363369
## FUNCTIONS TO DETERMINE OPTIMIZATION PARAMETERS
364370

371+
determine_nthreads_max_memopt(package::Symbol) = (package == PKG_AMDGPU) ? NTHREADS_MAX_MEMOPT_AMDGPU : NTHREADS_MAX_MEMOPT_CUDA
365372
determine_loopdim(indices::Union{Symbol,Expr}) = isa(indices,Expr) && (length(indices.args)==3) ? 3 : LOOPDIM_NONE # TODO: currently only loopdim=3 is supported.
366-
compute_loopsize() = LOOPSIZE
373+
compute_loopsize() = LOOPSIZE
367374

368375

369376
## FUNCTIONS TO COMPUTE NTHREADS, NBLOCKS
370377

371-
function compute_nthreads_memopt(maxsize, loopdim, stencilranges) # This is a heuristic, which results typcially in (32,4,1) threads for a 3-D case.
378+
function compute_nthreads_memopt(nthreads_x_max, nthreads_max_memopt, maxsize, loopdim, stencilranges) # This is a heuristic, which results typcially in (32,4,1) threads for a 3-D case.
372379
maxsize = promote_maxsize(maxsize)
373-
nthreads = ParallelKernel.compute_nthreads(maxsize; nthreads_max=NTHREADS_MAX_LOOPOPT, flatdim=loopdim)
380+
nthreads = ParallelKernel.compute_nthreads(maxsize; nthreads_x_max=nthreads_x_max, nthreads_max=nthreads_max_memopt, flatdim=loopdim)
374381
for stencilranges_A in values(stencilranges)
375382
haloextensions = ((length(stencilranges_A[1])-1)*(loopdim!=1), (length(stencilranges_A[2])-1)*(loopdim!=2), (length(stencilranges_A[3])-1)*(loopdim!=3))
376383
if (2*prod(nthreads) < prod(nthreads .+ haloextensions)) @ArgumentError("@parallel <kernelcall>: the automatic determination of nthreads is not possible for this case. Please specify `nthreads` and `nblocks`.") end # NOTE: this is a simple heuristic to compute compare the number of threads to the total number of cells including halo.
@@ -380,10 +387,10 @@ function compute_nthreads_memopt(maxsize, loopdim, stencilranges) # This is a he
380387
return nthreads
381388
end
382389

383-
function get_ranges_memopt(loopdim, args...)
390+
function get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, loopdim, args...)
384391
ranges = ParallelKernel.get_ranges(args...)
385392
maxsize = length.(ranges)
386-
nthreads = ParallelKernel.compute_nthreads(maxsize; nthreads_max=NTHREADS_MAX_LOOPOPT, flatdim=loopdim)
393+
nthreads = ParallelKernel.compute_nthreads(maxsize; nthreads_x_max=nthreads_x_max, nthreads_max=nthreads_max_memopt, flatdim=loopdim)
387394
# TODO: the following code reduces performance from ~482 GB/s to ~478 GB/s
388395
rests = maxsize .% nthreads
389396
ranges_adjustment = ( (rests[1] != 0) ? (nthreads[1] - rests[1]) : 0,

src/shared.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,23 @@ Return an expression that evaluates to `true` if the indices generated by @paral
1515
This macro is not intended for explicit manual usage. Calls to it are automatically added by @parallel where required.
1616
"""
1717

18-
const SUPPORTED_NDIMS = [1, 2, 3]
19-
const NDIMS_NONE = 0
20-
const ERRMSG_KERNEL_UNSUPPORTED = "unsupported kernel statements in @parallel kernel definition: @parallel is only applicable to kernels that contain exclusively array assignments using macros from FiniteDifferences{1|2|3}D or from another compatible computation submodule. @parallel_indices supports any kind of statements in the kernels."
21-
const ERRMSG_CHECK_NDIMS = "ndims must be evaluatable at parse time (e.g. literal or constant) and has to be one of the following Integers: $(join(SUPPORTED_NDIMS,", "))"
22-
const ERRMSG_CHECK_MEMOPT = "memopt must be evaluatable at parse time (e.g. literal or constant) and has to be of type Bool."
23-
const PSNumber = PKNumber
24-
const LOOPSIZE = 16
25-
const LOOPDIM_NONE = 0
26-
const NTHREADS_MAX_LOOPOPT = 128
27-
const USE_SHMEMHALO_DEFAULT = true
28-
const USE_SHMEMHALO_1D_DEFAULT = true
29-
const USE_FULLRANGE_DEFAULT = (false, false, true)
30-
const FULLRANGE_THRESHOLD = 1
31-
const NOEXPR = :(begin end)
32-
const MOD_METADATA = :__metadata__ # gensym_world("__metadata__", @__MODULE__) # # TODO: name mangling should be used here later, or if there is any sense to leave it like that then at check whether it's available must be done before creating it
33-
const META_FUNCTION_PREFIX = string(gensym_world("META", @__MODULE__))
18+
const SUPPORTED_NDIMS = [1, 2, 3]
19+
const NDIMS_NONE = 0
20+
const ERRMSG_KERNEL_UNSUPPORTED = "unsupported kernel statements in @parallel kernel definition: @parallel is only applicable to kernels that contain exclusively array assignments using macros from FiniteDifferences{1|2|3}D or from another compatible computation submodule. @parallel_indices supports any kind of statements in the kernels."
21+
const ERRMSG_CHECK_NDIMS = "ndims must be evaluatable at parse time (e.g. literal or constant) and has to be one of the following Integers: $(join(SUPPORTED_NDIMS,", "))"
22+
const ERRMSG_CHECK_MEMOPT = "memopt must be evaluatable at parse time (e.g. literal or constant) and has to be of type Bool."
23+
const PSNumber = PKNumber
24+
const LOOPSIZE = 16
25+
const LOOPDIM_NONE = 0
26+
const NTHREADS_MAX_MEMOPT_CUDA = 128
27+
const NTHREADS_MAX_MEMOPT_AMDGPU = 256
28+
const USE_SHMEMHALO_DEFAULT = true
29+
const USE_SHMEMHALO_1D_DEFAULT = true
30+
const USE_FULLRANGE_DEFAULT = (false, false, true)
31+
const FULLRANGE_THRESHOLD = 1
32+
const NOEXPR = :(begin end)
33+
const MOD_METADATA = :__metadata__ # gensym_world("__metadata__", @__MODULE__) # # TODO: name mangling should be used here later, or if there is any sense to leave it like that then at check whether it's available must be done before creating it
34+
const META_FUNCTION_PREFIX = string(gensym_world("META", @__MODULE__))
3435

3536

3637
## FUNCTIONS TO DEAL WITH KERNEL DEFINITIONS

0 commit comments

Comments
 (0)