Skip to content

Commit 5ce43dd

Browse files
committed
add Shard Levels
1 parent 26e419b commit 5ce43dd

18 files changed

+966
-336
lines changed

ext/SparseArraysExt.jl

Lines changed: 52 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ end
8484
ptr
8585
idx
8686
val
87-
qos_fill
88-
qos_stop
87+
qos_used
88+
qos_alloc
8989
prev_pos
9090
end
9191

@@ -127,14 +127,14 @@ function Finch.virtualize(ctx, ex, ::Type{<:SparseMatrixCSC{Tv,Ti}}, tag=:tns) w
127127
$val = $tag.nzval
128128
end,
129129
)
130-
qos_fill = freshen(ctx, tag, :_qos_fill)
131-
qos_stop = freshen(ctx, tag, :_qos_stop)
130+
qos_used = freshen(ctx, tag, :_qos_used)
131+
qos_alloc = freshen(ctx, tag, :_qos_alloc)
132132
prev_pos = freshen(ctx, tag, :_prev_pos)
133133
shape = [
134134
VirtualExtent(literal(1), value(m, Ti)), VirtualExtent(literal(1), value(n, Ti))
135135
]
136136
VirtualSparseMatrixCSC(
137-
tag, Tv, Ti, shape, ptr, idx, val, qos_fill, qos_stop, prev_pos
137+
tag, Tv, Ti, shape, ptr, idx, val, qos_used, qos_alloc, prev_pos
138138
)
139139
end
140140

@@ -149,8 +149,8 @@ function distribute(
149149
distribute_buffer(ctx, arr.ptr, arch, style),
150150
distribute_buffer(ctx, arr.idx, arch, style),
151151
distribute_buffer(ctx, arr.val, arch, style),
152-
arr.qos_fill,
153-
arr.qos_stop,
152+
arr.qos_used,
153+
arr.qos_alloc,
154154
arr.prev_pos,
155155
)
156156
end
@@ -166,8 +166,8 @@ function Finch.declare!(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, init
166166
push_preamble!(
167167
ctx,
168168
quote
169-
$(arr.qos_fill) = $(Tp(0))
170-
$(arr.qos_stop) = $(Tp(0))
169+
$(arr.qos_used) = $(Tp(0))
170+
$(arr.qos_alloc) = $(Tp(0))
171171
resize!($(arr.ptr), $pos_stop + 1)
172172
fill_range!($(arr.ptr), $(Tp(0)), 1, $pos_stop + 1)
173173
$(arr.ptr)[1] = $(Tp(1))
@@ -187,17 +187,17 @@ end
187187
function Finch.freeze!(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC)
188188
p = freshen(ctx, :p)
189189
pos_stop = ctx(getstop(virtual_size(ctx, arr)[2]))
190-
qos_stop = freshen(ctx, :qos_stop)
190+
qos_alloc = freshen(ctx, :qos_alloc)
191191
push_preamble!(
192192
ctx,
193193
quote
194194
resize!($(arr.ptr), $pos_stop + 1)
195195
for $p in 1:($pos_stop)
196196
$(arr.ptr)[$p + 1] += $(arr.ptr)[$p]
197197
end
198-
$qos_stop = $(arr.ptr)[$pos_stop + 1] - 1
199-
resize!($(arr.idx), $qos_stop)
200-
resize!($(arr.val), $qos_stop)
198+
$qos_alloc = $(arr.ptr)[$pos_stop + 1] - 1
199+
resize!($(arr.idx), $qos_alloc)
200+
resize!($(arr.val), $qos_alloc)
201201
end,
202202
)
203203
return arr
@@ -206,19 +206,19 @@ end
206206
function Finch.thaw!(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC)
207207
p = freshen(ctx, :p)
208208
pos_stop = ctx(getstop(virtual_size(ctx, arr)[2]))
209-
qos_stop = freshen(ctx, :qos_stop)
209+
qos_alloc = freshen(ctx, :qos_alloc)
210210
push_preamble!(
211211
ctx,
212212
quote
213-
$(arr.qos_fill) = $(arr.ptr)[$pos_stop + 1] - 1
214-
$(arr.qos_stop) = $(arr.qos_fill)
215-
$qos_stop = $(arr.qos_fill)
213+
$(arr.qos_used) = $(arr.ptr)[$pos_stop + 1] - 1
214+
$(arr.qos_alloc) = $(arr.qos_used)
215+
$qos_alloc = $(arr.qos_used)
216216
$(
217217
if issafe(get_mode_flag(ctx))
218218
quote
219219
$(arr.prev_pos) =
220220
Finch.scansearch(
221-
$(arr.ptr), $(arr.qos_stop) + 1, 1, $pos_stop
221+
$(arr.ptr), $(arr.qos_alloc) + 1, 1, $pos_stop
222222
) - 1
223223
end
224224
end
@@ -350,12 +350,12 @@ function Finch.unfurl(
350350
j = tns.j
351351
Tp = arr.Ti
352352
qos = freshen(ctx, tag, :_qos)
353-
qos_fill = arr.qos_fill
354-
qos_stop = arr.qos_stop
353+
qos_used = arr.qos_used
354+
qos_alloc = arr.qos_alloc
355355
dirty = freshen(ctx, tag, :dirty)
356356
Thunk(;
357357
preamble = quote
358-
$qos = $qos_fill + 1
358+
$qos = $qos_used + 1
359359
$(if issafe(get_mode_flag(ctx))
360360
quote
361361
$(arr.prev_pos) < $(ctx(j)) || throw(FinchProtocolError("SparseMatrixCSCs cannot be updated multiple times"))
@@ -365,10 +365,10 @@ function Finch.unfurl(
365365
body = (ctx) -> Lookup(;
366366
body=(ctx, idx) -> Thunk(;
367367
preamble = quote
368-
if $qos > $qos_stop
369-
$qos_stop = max($qos_stop << 1, 1)
370-
Finch.resize_if_smaller!($(arr.idx), $qos_stop)
371-
Finch.resize_if_smaller!($(arr.val), $qos_stop)
368+
if $qos > $qos_alloc
369+
$qos_alloc = max($qos_alloc << 1, 1)
370+
Finch.resize_if_smaller!($(arr.idx), $qos_alloc)
371+
Finch.resize_if_smaller!($(arr.val), $qos_alloc)
372372
end
373373
$dirty = false
374374
end,
@@ -387,8 +387,8 @@ function Finch.unfurl(
387387
)
388388
),
389389
epilogue = quote
390-
$(arr.ptr)[$(ctx(j)) + 1] += $qos - $qos_fill - 1
391-
$qos_fill = $qos - 1
390+
$(arr.ptr)[$(ctx(j)) + 1] += $qos - $qos_used - 1
391+
$qos_used = $qos - 1
392392
end,
393393
)
394394
end
@@ -429,8 +429,8 @@ end
429429
shape
430430
idx
431431
val
432-
qos_fill
433-
qos_stop
432+
qos_used
433+
qos_alloc
434434
end
435435

436436
function Finch.virtual_size(ctx::AbstractCompiler, arr::VirtualSparseVector)
@@ -460,9 +460,9 @@ function Finch.virtualize(ctx, ex, ::Type{<:SparseVector{Tv,Ti}}, tag=:tns) wher
460460
$val = $tag.nzval
461461
end,
462462
)
463-
qos_fill = freshen(ctx, tag, :_qos_fill)
464-
qos_stop = freshen(ctx, tag, :_qos_stop)
465-
VirtualSparseVector(tag, Tv, Ti, shape, idx, val, qos_fill, qos_stop)
463+
qos_used = freshen(ctx, tag, :_qos_used)
464+
qos_alloc = freshen(ctx, tag, :_qos_alloc)
465+
VirtualSparseVector(tag, Tv, Ti, shape, idx, val, qos_used, qos_alloc)
466466
end
467467

468468
function distribute(
@@ -475,8 +475,8 @@ function distribute(
475475
arr.shape,
476476
distribute_buffer(ctx, arr.idx, arch, style),
477477
distribute_buffer(ctx, arr.val, arch, style),
478-
arr.qos_fill,
479-
arr.qos_stop,
478+
arr.qos_used,
479+
arr.qos_alloc,
480480
)
481481
end
482482

@@ -490,36 +490,36 @@ function Finch.declare!(ctx::AbstractCompiler, arr::VirtualSparseVector, init)
490490
push_preamble!(
491491
ctx,
492492
quote
493-
$(arr.qos_fill) = $(Tp(0))
494-
$(arr.qos_stop) = $(Tp(0))
493+
$(arr.qos_used) = $(Tp(0))
494+
$(arr.qos_alloc) = $(Tp(0))
495495
end,
496496
)
497497
return arr
498498
end
499499

500500
function Finch.freeze!(ctx::AbstractCompiler, arr::VirtualSparseVector)
501501
p = freshen(ctx, :p)
502-
qos_stop = freshen(ctx, :qos_stop)
502+
qos_alloc = freshen(ctx, :qos_alloc)
503503
push_preamble!(
504504
ctx,
505505
quote
506-
$qos_stop = $(ctx(arr.qos_fill))
507-
resize!($(arr.idx), $qos_stop)
508-
resize!($(arr.val), $qos_stop)
506+
$qos_alloc = $(ctx(arr.qos_used))
507+
resize!($(arr.idx), $qos_alloc)
508+
resize!($(arr.val), $qos_alloc)
509509
end,
510510
)
511511
return arr
512512
end
513513

514514
function Finch.thaw!(ctx::AbstractCompiler, arr::VirtualSparseVector)
515515
p = freshen(ctx, :p)
516-
qos_stop = freshen(ctx, :qos_stop)
516+
qos_alloc = freshen(ctx, :qos_alloc)
517517
push_preamble!(
518518
ctx,
519519
quote
520-
$(arr.qos_fill) = length($(arr.idx))
521-
$(arr.qos_stop) = $(arr.qos_fill)
522-
$qos_stop = $(arr.qos_fill)
520+
$(arr.qos_used) = length($(arr.idx))
521+
$(arr.qos_alloc) = $(arr.qos_used)
522+
$qos_alloc = $(arr.qos_used)
523523
end,
524524
)
525525
return arr
@@ -593,23 +593,23 @@ function Finch.unfurl(
593593
tag = arr.tag
594594
Tp = arr.Ti
595595
qos = freshen(ctx, tag, :_qos)
596-
qos_fill = arr.qos_fill
597-
qos_stop = arr.qos_stop
596+
qos_used = arr.qos_used
597+
qos_alloc = arr.qos_alloc
598598
dirty = freshen(ctx, tag, :dirty)
599599

600600
Unfurled(;
601601
arr=arr,
602602
body=Thunk(;
603603
preamble = quote
604-
$qos = $qos_fill + 1
604+
$qos = $qos_used + 1
605605
end,
606606
body = (ctx) -> Lookup(;
607607
body=(ctx, idx) -> Thunk(;
608608
preamble = quote
609-
if $qos > $qos_stop
610-
$qos_stop = max($qos_stop << 1, 1)
611-
Finch.resize_if_smaller!($(arr.idx), $qos_stop)
612-
Finch.resize_if_smaller!($(arr.val), $qos_stop)
609+
if $qos > $qos_alloc
610+
$qos_alloc = max($qos_alloc << 1, 1)
611+
Finch.resize_if_smaller!($(arr.idx), $qos_alloc)
612+
Finch.resize_if_smaller!($(arr.val), $qos_alloc)
613613
end
614614
$dirty = false
615615
end,
@@ -623,7 +623,7 @@ function Finch.unfurl(
623623
)
624624
),
625625
epilogue = quote
626-
$qos_fill = $qos - 1
626+
$qos_used = $qos - 1
627627
end,
628628
),
629629
)

src/Finch.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ export Dense, DenseLevel
4242
export Element, ElementLevel
4343
export AtomicElement, AtomicElementLevel
4444
export Separate, SeparateLevel
45+
export Shard, ShardLevel
4546
export Mutex, MutexLevel
4647
export Pattern, PatternLevel
4748
export Scalar, SparseScalar, ShortCircuitScalar, SparseShortCircuitScalar
@@ -142,6 +143,7 @@ include("tensors/levels/dense_rle_levels.jl")
142143
include("tensors/levels/element_levels.jl")
143144
include("tensors/levels/atomic_element_levels.jl")
144145
include("tensors/levels/separate_levels.jl")
146+
include("tensors/levels/shard_levels.jl")
145147
include("tensors/levels/mutex_levels.jl")
146148
include("tensors/levels/pattern_levels.jl")
147149
include("tensors/masks.jl")

src/architecture.jl

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,14 @@ abstract type AbstractVirtualTask end
4141
Return the number of tasks on the device dev.
4242
"""
4343
function get_num_tasks end
44+
4445
"""
4546
get_task_num(task::AbstractTask)
4647
4748
Return the task number of `task`.
4849
"""
4950
function get_task_num end
51+
5052
"""
5153
get_device(task::AbstractTask)
5254
@@ -61,6 +63,25 @@ Return the task which spawned `task`.
6163
"""
6264
function get_parent_task end
6365

66+
get_num_tasks(ctx::AbstractCompiler) = get_num_tasks(get_task(ctx))
67+
get_num_tasks(task::AbstractTask) = get_num_tasks(get_device(task))
68+
get_task_num(ctx::AbstractCompiler) = get_task_num(get_task(ctx))
69+
get_device(ctx::AbstractCompiler) = get_device(get_task(ctx))
70+
get_parent_task(ctx::AbstractCompiler) = get_parent_task(get_task(ctx))
71+
72+
"""
    is_on_device(ctx::AbstractCompiler, dev)

Return `true` if the current task of `ctx`, or any task reached from it by
repeatedly following `get_parent_task`, has a device equal (`==`) to `dev`.
Return `false` once the chain of parent tasks ends in `nothing`.
"""
function is_on_device(ctx::AbstractCompiler, dev)
    task = get_task(ctx)
    # `!==` is an identity test against the `nothing` sentinel; unlike `!=`
    # it cannot be influenced by `==` methods defined on task types.
    while task !== nothing
        get_device(task) == dev && return true
        task = get_parent_task(task)
    end
    return false
end
84+
6485
"""
6586
aquire_lock!(dev::AbstractDevice, val)
6687
@@ -92,20 +113,35 @@ function make_lock end
92113
"""
93114
Serial()
94115
95-
A device that represents a serial CPU execution.
116+
A device that represents serial CPU execution.
96117
"""
97-
struct Serial <: AbstractTask end
118+
struct Serial <: AbstractDevice end
98119
const serial = Serial()
99-
get_device(::Serial) = CPU(1)
100-
get_parent_task(::Serial) = nothing
101-
get_task_num(::Serial) = 1
120+
get_num_tasks(::Serial) = 1
102121
struct VirtualSerial <: AbstractVirtualTask end
103122
virtualize(ctx, ex, ::Type{Serial}) = VirtualSerial()
104123
lower(ctx::AbstractCompiler, task::VirtualSerial, ::DefaultStyle) = :(Serial())
105124
FinchNotation.finch_leaf(device::VirtualSerial) = virtual(device)
106-
get_device(::VirtualSerial) = VirtualCPU(nothing, 1)
107-
get_parent_task(::VirtualSerial) = nothing
108-
get_task_num(::VirtualSerial) = literal(1)
125+
get_num_tasks(::VirtualSerial) = literal(1)
126+
Base.:(==)(::Serial, ::Serial) = true
127+
Base.:(==)(::VirtualSerial, ::VirtualSerial) = true
128+
129+
"""
    SerialTask()

A task that represents serial CPU execution. Its device is `Serial()`, it has
no parent task, and its task number is 1.
"""
struct SerialTask <: AbstractTask end
135+
get_device(::SerialTask) = Serial()
136+
get_parent_task(::SerialTask) = nothing
137+
get_task_num(::SerialTask) = 1
138+
struct VirtualSerialTask <: AbstractVirtualTask end
139+
virtualize(ctx, ex, ::Type{SerialTask}) = VirtualSerialTask()
140+
lower(ctx::AbstractCompiler, task::VirtualSerialTask, ::DefaultStyle) = :(SerialTask())
141+
FinchNotation.finch_leaf(device::VirtualSerialTask) = virtual(device)
142+
get_device(::VirtualSerialTask) = VirtualSerial()
143+
get_parent_task(::VirtualSerialTask) = nothing
144+
get_task_num(::VirtualSerialTask) = literal(1)
109145

110146
struct SerialMemory end
111147
struct VirtualSerialMemory end
@@ -148,6 +184,8 @@ function lower(ctx::AbstractCompiler, device::VirtualCPU, ::DefaultStyle)
148184
something(device.ex, :(CPU($(ctx(device.n)))))
149185
end
150186
get_num_tasks(::VirtualCPU) = literal(1)
187+
Base.:(==)(::CPU, ::CPU) = true
188+
Base.:(==)(::VirtualCPU, ::VirtualCPU) = true #This is not strictly true. A better approach would name devices, and give them parents so that we can be sure to parallelize through the processor hierarchy.
151189

152190
FinchNotation.finch_leaf(device::VirtualCPU) = virtual(device)
153191

@@ -212,7 +250,7 @@ function transfer(device::CPULocalMemory, arr::AbstractArray)
212250
CPULocalArray{A}(mem.device, [copy(arr) for _ in 1:(mem.device.n)])
213251
end
214252
function transfer(task::CPUThread, arr::CPULocalArray)
215-
if get_device(task) === arr.device
253+
if get_device(task) == arr.device
216254
temp = arr.data[task.tid]
217255
return temp
218256
else
@@ -223,6 +261,7 @@ function transfer(dst::AbstractArray, arr::AbstractArray)
223261
return arr
224262
end
225263

264+
226265
"""
227266
transfer(device, arr)
228267
@@ -484,8 +523,8 @@ for T in [
484523
end
485524
end
486525

487-
function virtual_parallel_region(f, ctx, ::Serial)
488-
contain(f, ctx)
526+
function virtual_parallel_region(f, ctx, ::VirtualSerial)
527+
contain(f, ctx; task=VirtualSerialTask())
489528
end
490529

491530
function virtual_parallel_region(f, ctx, device::VirtualCPU)

0 commit comments

Comments
 (0)