@@ -117,7 +117,8 @@ struct WorkgroupLoop
117
117
indicies :: Vector{Any}
118
118
stmts :: Vector{Any}
119
119
allocations :: Vector{Any}
120
- private :: Vector{Any}
120
+ private_allocations :: Vector{Any}
121
+ private :: Set{Symbol}
121
122
end
122
123
123
124
"""
    is_sync(expr)

Match `expr` against a bare `@synchronize()` call or a one-argument
`@synchronize(a_)` call and return the result of that `@capture` match.
"""
function is_sync(expr)
    return @capture(expr, @synchronize() | @synchronize(a_))
end
@@ -138,21 +139,23 @@ end
138
139
139
140
# TODO proper handling of LineInfo
140
141
function split (stmts,
141
- indicies = Any[], private= Any[] )
142
+ indicies = Any[], private = Set {Symbol} () )
142
143
# 1. Split the code into blocks separated by `@synchronize`
143
144
# 2. Aggregate `@index` expressions
144
145
# 3. Hoist allocations
145
146
# 4. Hoist uniforms
146
147
147
148
current = Any[]
148
149
allocations = Any[]
150
+ private_allocations = Any[]
149
151
new_stmts = Any[]
150
152
for stmt in stmts
151
153
has_sync = find_sync (stmt)
152
154
if has_sync
153
- loop = WorkgroupLoop (deepcopy (indicies), current, allocations, deepcopy (private))
155
+ loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private))
154
156
push! (new_stmts, emit (loop))
155
157
allocations = Any[]
158
+ private_allocations = Any[]
156
159
current = Any[]
157
160
is_sync (stmt) && continue
158
161
@@ -177,15 +180,26 @@ function split(stmts,
177
180
if @capture (stmt, @uniform x_)
178
181
push! (allocations, stmt)
179
182
continue
183
+ elseif @capture (stmt, @private lhs_ = rhs_)
184
+ push! (private, lhs)
185
+ push! (private_allocations, :($ lhs = $ rhs))
186
+ continue
180
187
elseif @capture (stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
181
188
if @capture (rhs, @index (args__))
182
189
push! (indicies, stmt)
183
190
continue
184
191
elseif @capture (rhs, @localmem (args__) | @uniform (args__) )
185
192
push! (allocations, stmt)
186
193
continue
187
- elseif @capture (rhs, @private (args__))
188
- push! (allocations, stmt)
194
+ elseif @capture (rhs, @private (T_, dims_))
195
+ # Implement the legacy `mem = @private T dims` as
196
+ # @private mem = Scratchpad(T, Val(dims))
197
+
198
+ if dims isa Integer
199
+ dims = (dims,)
200
+ end
201
+ alloc = :($ Scratchpad ($ T, Val ($ dims)))
202
+ push! (private_allocations, :($ lhs = $ alloc))
189
203
push! (private, lhs)
190
204
continue
191
205
end
@@ -196,7 +210,7 @@ function split(stmts,
196
210
197
211
# everything since the last `@synchronize`
198
212
if ! isempty (current)
199
- loop = WorkgroupLoop (deepcopy (indicies), current, allocations, deepcopy (private))
213
+ loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private))
200
214
push! (new_stmts, emit (loop))
201
215
end
202
216
return new_stmts
@@ -212,13 +226,34 @@ function emit(loop)
212
226
end
213
227
stmts = Any[]
214
228
append! (stmts, loop. allocations)
229
+
230
+ # each private allocation `lhs = rhs` turns into `lhs = ntuple(_ -> rhs, N)`, where `N` holds `length(__workitems_iterspace())`
231
+ N = gensym (:N )
232
+ push! (stmts, :($ N = length ($ __workitems_iterspace ())))
233
+
234
+ for stmt in loop. private_allocations
235
+ if @capture (stmt, lhs_ = rhs_)
236
+ push! (stmts, :($ lhs = ntuple (_-> $ rhs, $ N)))
237
+ else
238
+ error (" @private $stmt not an assignment" )
239
+ end
240
+ end
241
+
215
242
# don't emit empty loops
216
243
if ! (isempty (loop. stmts) || all (s-> s isa LineNumberNode, loop. stmts))
217
244
body = Expr (:block , loop. stmts... )
218
245
body = postwalk (body) do expr
219
- if @capture (expr, A_[i__])
246
+ if @capture (expr, lhs_ = rhs_)
247
+ if lhs in loop. private
248
+ error (" Can't assign to variables marked private" )
249
+ end
250
+ elseif @capture (expr, A_[i__])
220
251
if A in loop. private
221
- return :($ A[$ (i... ), $ (idx). I... ])
252
+ return :($ A[$ __index_Local_Linear ($ (idx))][$ (i... )])
253
+ end
254
+ elseif expr isa Symbol
255
+ if expr in loop. private
256
+ return :($ expr[$ __index_Local_Linear ($ (idx))])
222
257
end
223
258
end
224
259
return expr
0 commit comments