Skip to content

Commit 820bcbe

Browse files
committed
improve performance
1 parent 19e13b9 commit 820bcbe

File tree

5 files changed

+109
-6
lines changed

5 files changed

+109
-6
lines changed

src/byrow/byrow.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,15 @@ byrow(ds::AbstractDataset, ::typeof(stdze), cols::MultiColumnIndex = names(ds, U
140140

141141
byrow(ds::AbstractDataset, ::typeof(stdze!), cols::MultiColumnIndex = names(ds, Union{Missing, Number}); threads = true) = row_stdze!(ds, cols, threads = threads)
142142

143-
byrow(ds::AbstractDataset, ::typeof(hash), cols::MultiColumnIndex = :; by = identity, threads = nrow(ds) > __NCORES*10) = row_hash(ds, by, cols, threads = threads)
144-
byrow(ds::AbstractDataset, ::typeof(hash), col::ColumnIndex; by = identity, threads = nrow(ds) > __NCORES*10) = byrow(ds, hash, [col]; by = by, threads = threads)
143+
function byrow(ds::AbstractDataset, ::typeof(hash), cols::MultiColumnIndex = :; by = identity, mapformats = false, threads = nrow(ds) > __NCORES*10)
144+
colsidx = multiple_getindex(index(ds), cols)
145+
if mapformats
146+
by = map(y->expand_Base_Fix(by, getformat(ds, y)), colsidx)
147+
end
148+
row_hash(ds, by, cols, threads = threads)
149+
end
150+
151+
byrow(ds::AbstractDataset, ::typeof(hash), col::ColumnIndex; by = identity, mapformats = false, threads = nrow(ds) > __NCORES*10) = byrow(ds, hash, [col]; by = by, mapformats = mapformats, threads = threads)
145152

146153
byrow(ds::AbstractDataset, ::typeof(join), col::MultiColumnIndex; threads = nrow(ds) > __NCORES*10, delim = "", last = "") = row_join(ds, col, threads = threads, delim = delim, last = last)
147154

src/byrow/row_functions.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,19 +1096,33 @@ Base.@propagate_inbounds function _op_for_hash!(x, y, f, lo, hi)
10961096
x
10971097
end
10981098

1099-
function row_hash(ds::AbstractDataset, f::Function, cols = :; threads = true)
1099+
function row_hash(ds::AbstractDataset, f::Union{AbstractVector{<:Function}, Function}, cols = :; threads = true)
11001100
colsidx = multiple_getindex(index(ds), cols)
11011101
init0 = zeros(UInt64, nrow(ds))
11021102

1103+
multi_f = false
1104+
if f isa AbstractVector
1105+
@assert length(f) == length(colsidx) "number of provided functions must match the number of selected columns"
1106+
multi_f = true
1107+
end
1108+
11031109
if threads
11041110
cz = div(length(init0), __NCORES)
11051111
Threads.@threads for i in 1:__NCORES
11061112
lo = (i-1)*cz+1
11071113
i == __NCORES ? hi = length(init0) : hi = i*cz
1108-
mapreduce(identity, (x,y) -> _op_for_hash!(x, y, f, lo, hi), view(_columns(ds),colsidx), init = init0)
1114+
if multi_f
1115+
mapreduce_index(f, (x, y, func) -> _op_for_hash!(x, y, func, lo, hi), view(_columns(ds),colsidx), init0)
1116+
else
1117+
mapreduce(identity, (x,y) -> _op_for_hash!(x, y, f, lo, hi), view(_columns(ds),colsidx), init = init0)
1118+
end
11091119
end
11101120
else
1111-
mapreduce(identity, (x,y) -> _op_for_hash!(x, y, f, 1, length(x)), view(_columns(ds),colsidx), init = init0)
1121+
if multi_f
1122+
mapreduce_index(f, (x, y, func) -> _op_for_hash!(x, y, func, 1, length(x)), view(_columns(ds),colsidx), init0)
1123+
else
1124+
mapreduce(identity, (x,y) -> _op_for_hash!(x, y, f, 1, length(x)), view(_columns(ds),colsidx), init = init0)
1125+
end
11121126
end
11131127
init0
11141128
end

src/other/utils.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,15 @@ end
409409
function _gather_groups(ds, cols, ::Val{T}; mapformats = false, stable = true, threads = true) where T
410410
colidx = index(ds)[cols]
411411
_max_level = nrow(ds)
412+
413+
414+
if nrow(ds) > 2^23 && !stable && 5<length(colidx)<16 # the result is stable anyway
415+
if !mapformats || all(==(identity), getformat.(Ref(ds), colidx))
416+
return _gather_groups_hugeds_multicols(ds, cols, Val(T); threads = threads)
417+
end
418+
end
419+
420+
412421
prev_max_group = UInt(1)
413422
prev_groups = ones(T, nrow(ds))
414423
groups = T[]
@@ -510,6 +519,64 @@ function _find_groups_with_more_than_one_observation_barrier!(res, groups, seen_
510519
nothing
511520
end
512521

522+
### Special path for huge ds and multiple cols - trade off between compilation and performance
523+
# table columns are passed as a tuple of vectors to ensure type specialization - From DataFrames.jl
524+
isequal_row(cols::Tuple{AbstractVector}, r1::Int, r2::Int) =
525+
isequal(cols[1][r1], cols[1][r2])
526+
isequal_row(cols::Tuple{Vararg{AbstractVector}}, r1::Int, r2::Int) =
527+
isequal(cols[1][r1], cols[1][r2]) && isequal_row(Base.tail(cols), r1, r2)
528+
529+
isequal_row(cols1::Tuple{AbstractVector}, r1::Int, cols2::Tuple{AbstractVector}, r2::Int) =
530+
isequal(cols1[1][r1], cols2[1][r2])
531+
isequal_row(cols1::Tuple{Vararg{AbstractVector}}, r1::Int,
532+
cols2::Tuple{Vararg{AbstractVector}}, r2::Int) =
533+
isequal(cols1[1][r1], cols2[1][r2]) &&
534+
isequal_row(Base.tail(cols1), r1, Base.tail(cols2), r2)
535+
536+
537+
_grabrefs(x) = DataAPI.refpool(x) == nothing ? x : DataAPI.refarray(x)
538+
function _gather_groups_hugeds_multicols(ds, cols, ::Val{T}; threads = true) where T
539+
colidx = index(ds)[cols]
540+
rhashes = byrow(ds, hash, cols, threads = threads)
541+
colsvals = ntuple(i->_grabrefs(_columns(ds)[colidx[i]]), length(colidx))
542+
create_dict_hugeds_multicols(colsvals, rhashes, Val(T))
543+
end
544+
545+
function create_dict_hugeds_multicols(colvals, rhashes, ::Val{T}) where T
546+
sz = max(1 + ((5 * length(rhashes)) >> 2), 16)
547+
sz = 1 << (8 * sizeof(sz) - leading_zeros(sz - 1))
548+
@assert 4 * sz >= 5 * length(rhashes)
549+
szm1 = sz-1
550+
gslots = zeros(T, sz)
551+
groups = Vector{T}(undef, length(rhashes))
552+
ngroups = 0
553+
@inbounds for i in eachindex(rhashes)
554+
# find the slot and group index for a row
555+
slotix = rhashes[i] & szm1 + 1
556+
gix = -1
557+
probe = 0
558+
while true
559+
g_row = gslots[slotix]
560+
if g_row == 0 # unoccupied slot, current row starts a new group
561+
gslots[slotix] = i
562+
gix = ngroups += 1
563+
break
564+
elseif rhashes[i] == rhashes[g_row] # occupied slot, check if miss or hit
565+
if isequal_row(colvals, i, Int(g_row)) # hit
566+
gix = groups[g_row]
567+
break
568+
end
569+
end
570+
slotix = slotix & szm1 + 1 # check the next slot
571+
probe += 1
572+
@assert probe < sz
573+
end
574+
groups[i] = gix
575+
end
576+
return groups, gslots, ngroups
577+
end
578+
579+
513580
function _gather_groups_old_version(ds, cols, ::Val{T}; mapformats = false) where T
514581
colidx = index(ds)[cols]
515582
_max_level = nrow(ds)

src/precompile/warmup.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ function warmup()
3434
combine(gatherby(ds,1), Ref([1,2,3,7,8]) .=> [median, sort])
3535
combine(gatherby(ds,1), Ref([1,2,3,7,8]) .=> [sum, mean, length, maximum, minimum, var, std])
3636
combine(gatherby(ds,1), r"x1$" .=> [sum, mean, length, maximum, minimum, var, std])
37+
IMD._gather_groups_hugeds_multicols(ds, 1:6, Val(Int32), threads = true)
3738

3839
ds2 = ds[1:2, [1,3,7]]
3940
innerjoin(ds, ds2, on = [:x1, :x3, :x7])

src/sort/gatherby.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,27 @@ function gatherby(ds::AbstractDataset, cols::MultiColumnIndex; mapformats::Bool
104104
return GatherBy(ds, colsidx, 1:nrow(ds), nrow(ds), mapformats, b[1], 1:nrow(ds))
105105
else
106106
a = _gather_groups(ds, colsidx, Val(T), mapformats = mapformats, stable = stable, threads = threads)
107-
return GatherBy(ds, colsidx, a[1], a[3], mapformats, nothing, nothing)
107+
return GatherBy(ds, colsidx, a[1], a[3], mapformats, nothing, nothing)
108108
end
109109
end
110110
end
111111
gatherby(ds::AbstractDataset, col::ColumnIndex; mapformats = true, stable = true, isgathered = false, eachrow = false, threads = true) = gatherby(ds, [col], mapformats = mapformats, stable = stable, isgathered = isgathered, eachrow = eachrow, threads = threads)
112112

113113

114+
__SPFRMT(x) = x & 1023
115+
__SPFRMT(::Missing) = missing # not needed
116+
117+
# currently not been used in gatherby
118+
# use sort and format trick for fast gatherby - hm stands for high memory footprint
119+
function hm_gatherby(ds::AbstractDataset, cols::MultiColumnIndex; mapformats = false, threads = true)
120+
modify!(ds, cols=>byrow(hash; threads = threads, mapformats = mapformats)=>:___tmp___cols8934, :___tmp___cols8934=>identity=>:___tmp___cols8934_2)
121+
setformat!(ds, :___tmp___cols8934_2=>__SPFRMT)
122+
gds = groupby(ds, [:___tmp___cols8934_2, :___tmp___cols8934], stable = false, threads = threads)
123+
grpcols, ranges, last_valid_index = _find_starts_of_groups(view(ds, gds.perm, cols), cols, nrow(ds) < typemax(Int32) ? Val(Int32) : Val(Int64); mapformats = mapformats, threads = threads)
124+
select!(ds, Not([:___tmp___cols8934, :___tmp___cols8934_2]))
125+
GatherBy(ds, grpcols, nothing, last_valid_index, mapformats, gds.perm, ranges)
126+
end
127+
114128
function _fill_mapreduce_col!(x, f, op, y, loc)
115129
@inbounds for i in 1:length(y)
116130
x[loc[i]] = op(x[loc[i]], f(y[i]))

0 commit comments

Comments
 (0)