Skip to content

Commit 3d6885b

Browse files
authored
Merge pull request #175 from JuliaGPU/vc/fancy_private
Fancy private
2 parents 4792d00 + ea2607a commit 3d6885b

File tree

3 files changed

+55
-24
lines changed

3 files changed

+55
-24
lines changed

src/KernelAbstractions.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,16 @@ macro private(T, dims)
154154
end
155155
end
156156

157+
"""
158+
@private mem = 1
159+
160+
Creates a private local variable `mem` for each item in the workgroup. This can be safely used
161+
across [`@synchronize`](@ref) statements.
162+
"""
163+
macro private(expr)
164+
expr
165+
end
166+
157167
"""
158168
@uniform expr
159169

src/backends/cpu.jl

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -241,23 +241,9 @@ end
241241

242242
###
243243
# CPU implementation of scratch memory
244-
# - private memory for each workitem
245-
# - memory allocated as a MArray with size `Dims + WorkgroupSize`
244+
# - memory allocated as an MArray with size `Dims`
246245
###
247-
struct ScratchArray{N, D}
248-
data::D
249-
ScratchArray{N}(data::D) where {N, D} = new{N, D}(data)
250-
end
251-
Base.eltype(a::ScratchArray) = eltype(a.data)
252246

253247
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(Scratchpad), ::Type{T}, ::Val{Dims}) where {T, Dims}
254-
return ScratchArray{length(Dims)}(MArray{__size((Dims..., __groupsize(ctx.metadata)...)), T}(undef))
255-
end
256-
257-
Base.@propagate_inbounds function Base.getindex(A::ScratchArray, I...)
258-
return A.data[I...]
259-
end
260-
261-
Base.@propagate_inbounds function Base.setindex!(A::ScratchArray, val, I...)
262-
A.data[I...] = val
248+
return MArray{__size(Dims), T}(undef)
263249
end

src/macros.jl

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ struct WorkgroupLoop
117117
indicies :: Vector{Any}
118118
stmts :: Vector{Any}
119119
allocations :: Vector{Any}
120-
private :: Vector{Any}
120+
private_allocations :: Vector{Any}
121+
private :: Set{Symbol}
121122
end
122123

123124
is_sync(expr) = @capture(expr, @synchronize() | @synchronize(a_))
@@ -138,21 +139,23 @@ end
138139

139140
# TODO proper handling of LineInfo
140141
function split(stmts,
141-
indicies = Any[], private=Any[])
142+
indicies = Any[], private = Set{Symbol}())
142143
# 1. Split the code into blocks separated by `@synchronize`
143144
# 2. Aggregate `@index` expressions
144145
# 3. Hoist allocations
145146
# 4. Hoist uniforms
146147

147148
current = Any[]
148149
allocations = Any[]
150+
private_allocations = Any[]
149151
new_stmts = Any[]
150152
for stmt in stmts
151153
has_sync = find_sync(stmt)
152154
if has_sync
153-
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, deepcopy(private))
155+
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private))
154156
push!(new_stmts, emit(loop))
155157
allocations = Any[]
158+
private_allocations = Any[]
156159
current = Any[]
157160
is_sync(stmt) && continue
158161

@@ -177,15 +180,26 @@ function split(stmts,
177180
if @capture(stmt, @uniform x_)
178181
push!(allocations, stmt)
179182
continue
183+
elseif @capture(stmt, @private lhs_ = rhs_)
184+
push!(private, lhs)
185+
push!(private_allocations, :($lhs = $rhs))
186+
continue
180187
elseif @capture(stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
181188
if @capture(rhs, @index(args__))
182189
push!(indicies, stmt)
183190
continue
184191
elseif @capture(rhs, @localmem(args__) | @uniform(args__) )
185192
push!(allocations, stmt)
186193
continue
187-
elseif @capture(rhs, @private(args__))
188-
push!(allocations, stmt)
194+
elseif @capture(rhs, @private(T_, dims_))
195+
# Implement the legacy `mem = @private T dims` as
196+
# @private mem = Scratchpad(T, Val(dims))
197+
198+
if dims isa Integer
199+
dims = (dims,)
200+
end
201+
alloc = :($Scratchpad($T, Val($dims)))
202+
push!(private_allocations, :($lhs = $alloc))
189203
push!(private, lhs)
190204
continue
191205
end
@@ -196,7 +210,7 @@ function split(stmts,
196210

197211
# everything since the last `@synchronize`
198212
if !isempty(current)
199-
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, deepcopy(private))
213+
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private))
200214
push!(new_stmts, emit(loop))
201215
end
202216
return new_stmts
@@ -212,13 +226,34 @@ function emit(loop)
212226
end
213227
stmts = Any[]
214228
append!(stmts, loop.allocations)
229+
230+
# each private allocation `lhs = rhs` turns into `lhs = ntuple(_->rhs, N)` with `N = length(__workitems_iterspace())`
231+
N = gensym(:N)
232+
push!(stmts, :($N = length($__workitems_iterspace())))
233+
234+
for stmt in loop.private_allocations
235+
if @capture(stmt, lhs_ = rhs_)
236+
push!(stmts, :($lhs = ntuple(_->$rhs, $N)))
237+
else
238+
error("@private $stmt not an assignment")
239+
end
240+
end
241+
215242
# don't emit empty loops
216243
if !(isempty(loop.stmts) || all(s->s isa LineNumberNode, loop.stmts))
217244
body = Expr(:block, loop.stmts...)
218245
body = postwalk(body) do expr
219-
if @capture(expr, A_[i__])
246+
if @capture(expr, lhs_ = rhs_)
247+
if lhs in loop.private
248+
error("Can't assign to variables marked private")
249+
end
250+
elseif @capture(expr, A_[i__])
220251
if A in loop.private
221-
return :($A[$(i...), $(idx).I...])
252+
return :($A[$__index_Local_Linear($(idx))][$(i...)])
253+
end
254+
elseif expr isa Symbol
255+
if expr in loop.private
256+
return :($expr[$__index_Local_Linear($(idx))])
222257
end
223258
end
224259
return expr

0 commit comments

Comments
 (0)