Skip to content

Commit 0cccd08

Browse files
committed
all joins now support method = :sort or :hash
1 parent 20990b0 commit 0cccd08

File tree

8 files changed

+1222
-139
lines changed

8 files changed

+1222
-139
lines changed

src/join/closejoin.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ end
308308

309309

310310
# border = :nearest | :missing | :none
311-
function _join_closejoin(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, makeunique = false, border = :nearest, mapformats = [true, true], stable = false, alg = HeapSort, accelerate = false, direction = :backward, inplace = false, tol = nothing, allow_exact_match = true, op = nothing) where T
311+
function _join_closejoin(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, makeunique = false, border = :nearest, mapformats = [true, true], stable = false, alg = HeapSort, accelerate = false, direction = :backward, inplace = false, tol = nothing, allow_exact_match = true, op = nothing, method = :sort) where T
312312
isempty(dsl) && return copy(dsl)
313313
if !allow_exact_match
314314
#aem is the function to check allow_exact_match
@@ -332,14 +332,24 @@ function _join_closejoin(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, m
332332
if DataAPI.refpool(_columns(dsr)[oncols_right[end]]) !== nothing
333333
nsfpaj = false
334334
end
335-
ranges = Vector{UnitRange{T}}(undef, nrow(dsl))
336-
idx, uniquemode = _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate && length(oncols_right) > 1; nsfpaj = nsfpaj)
335+
if length(oncols_left) > 1 && method == :hash
336+
ranges, a, idx, minval, reps, sz, right_cols_2= _find_ranges_for_join_using_hash(dsl, dsr, onleft[1:end-1], onright[1:end-1], mapformats, true, Val(T))
337+
filter!(!=(0), reps)
338+
pushfirst!(reps, 1)
339+
cumsum!(reps, reps)
340+
pop!(reps)
341+
grng = GIVENRANGE(idx, reps, Int[], length(reps))
342+
starts, idx, last_valid_range = _sort_for_join_after_hash(dsr, oncols_right[end], stable, alg, mapformats, nsfpaj, grng)
343+
else
344+
ranges = Vector{UnitRange{T}}(undef, nrow(dsl))
345+
idx, uniquemode = _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate && length(oncols_right) > 1; nsfpaj = nsfpaj)
337346

338-
for j in 1:(length(oncols_left) - 1)
339-
_change_refpool_find_range_for_join!(ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1], mapformats[2], j; nsfpaj = nsfpaj)
347+
for j in 1:(length(oncols_left) - 1)
348+
_change_refpool_find_range_for_join!(ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1], mapformats[2], j; nsfpaj = nsfpaj)
349+
end
340350
end
341351

342-
# if border = :none , we should :nearest direction
352+
# if border = :none , we should use :nearest direction
343353
_change_refpool_find_range_for_close!(ranges, dsl, dsr, idx, oncols_left, oncols_right, border == :none ? :nearest : direction, mapformats[1], mapformats[2], length(oncols_left); nsfpaj = nsfpaj)
344354
total_length = nrow(dsl)
345355

src/join/join.jl

Lines changed: 115 additions & 68 deletions
Large diffs are not rendered by default.

src/join/join_dict.jl

Lines changed: 167 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,42 @@ function _fill_ranges_for_dict_join!(ranges, dict, maxprob, _fl, _fr, x_l, x_r,
155155
end
156156

157157

158+
# Compute, for every row of `dsl`, a `UnitRange{T}` of matches among the rows of `dsr`
# without sorting: the key columns of both data sets are paired with `Cat2Vec`,
# grouped in one pass by `_gather_groups`, and the resulting integer group ids are
# used as keys of a table built by `_create_dictionary_for_join_int`.
#
# `onleft` / `onright` are the key column indices of `dsl` / `dsr`;
# `mapformats[1]` / `mapformats[2]` decide whether the column formats of the
# left / right data set are applied (otherwise `identity` is used).
#
# Returns `(ranges, a, gslots, minval, reps, sz, right_cols)` where `a` is the raw
# result of `_gather_groups` and `right_cols` are the columns of `dsr` outside `onright`.
# Throws `ArgumentError` when `makeunique` is false and one of those `right_cols`
# shares its name with a column of `dsl`.
function _find_ranges_for_join_using_hash(dsl, dsr, onleft, onright, mapformats, makeunique, ::Val{T}) where T
    right_cols = setdiff(1:length(index(dsr)), onright)
    if !makeunique
        clashes = intersect(_names(dsl), _names(dsr)[right_cols])
        if !isempty(clashes)
            throw(ArgumentError("duplicate column names, pass `makeunique = true` to make them unique using a suffix automatically." ))
        end
    end

    # one `Cat2Vec` per key pair: left column and right column with their formats
    paired = Any[]
    for j in 1:length(onleft)
        lfmt = mapformats[1] ? getformat(dsl, onleft[j]) : identity
        rfmt = mapformats[2] ? getformat(dsr, onright[j]) : identity
        push!(paired, Cat2Vec(_columns(dsl)[onleft[j]], _columns(dsr)[onright[j]], lfmt, rfmt))
    end
    stacked = Dataset(paired, :auto, copycols = false)
    idtype = nrow(stacked) < typemax(Int32) ? Val(Int32) : Val(Int64)
    a = _gather_groups(stacked, :, idtype, stable = false, mapformats = false)

    # a[1] holds one id per stacked row: the first nrow(dsl) entries are used for
    # `dsl`, the remaining ones for `dsr`; a[3] is forwarded as the key-range length
    nleft = nrow(dsl)
    ids = a[1]
    keyspan = a[3]
    right_ids = view(ids, nleft+1:length(ids))

    reps = _find_counts_for_join(right_ids, keyspan)
    gslots, minval, sz = _create_dictionary_for_join_int(identity, right_ids, reps, 1, keyspan, Val(T))

    ranges = Vector{UnitRange{T}}(undef, nleft)
    # offsets[k] + 1 : offsets[k + 1] is the slice belonging to key k
    offsets = Vector{T}(undef, length(reps) + 1)
    offsets[1] = 0
    cumsum!(view(offsets, 2:length(offsets)), reps)
    _find_range_for_join!(ranges, view(ids, 1:nleft), gslots, reps, offsets, 1, sz)
    return ranges, a, gslots, minval, reps, sz, right_cols
end
193+
158194
function _join_left_dict(dsl, dsr, ranges, onleft, onright, right_cols, ::Val{T}; makeunique = makeunique, mapformats = mapformats, check = check ) where T
159195
_fl = _date_valueidentity
160196
_fr = _date_valueidentity
@@ -231,7 +267,9 @@ function _join_left!_dict(dsl, dsr, ranges, onleft, onright, right_cols, ::Val{T
231267
end
232268

233269
_fill_ranges_for_dict_join!(ranges, dict, maxprob, _fl, _fr, _columns(dsl)[onleft[1]], _columns(dsr)[onright[1]], sz, type)
234-
270+
if !all(x->length(x) <= 1, ranges)
271+
throw(ArgumentError("`leftjoin!` can only be used when each observation in left data set matches at most one observation from right data set"))
272+
end
235273
new_ends = map(x -> max(1, length(x)), ranges)
236274
cumsum!(new_ends, new_ends)
237275
total_length = new_ends[end]
@@ -241,10 +279,10 @@ function _join_left!_dict(dsl, dsr, ranges, onleft, onright, right_cols, ::Val{T
241279
end
242280

243281
for j in 1:length(right_cols)
244-
_res = allocatecol(_columns(dsr)[right_cols[j]], total_length, addmissing = false)
282+
_res = allocatecol(_columns(dsr)[right_cols[j]], total_length)
245283
if DataAPI.refpool(_res) !== nothing
246-
# fill_val = DataAPI.invrefpool(_res)[missing]
247-
_fill_right_cols_table_left!(_res.refs, DataAPI.refarray(_columns(dsr)[right_cols[j]]), ranges, new_ends, total_length, missing)
284+
fill_val = DataAPI.invrefpool(_res)[missing]
285+
_fill_right_cols_table_left!(_res.refs, DataAPI.refarray(_columns(dsr)[right_cols[j]]), ranges, new_ends, total_length, fill_val)
248286
else
249287
_fill_right_cols_table_left!(_res, _columns(dsr)[right_cols[j]], ranges, new_ends, total_length, missing)
250288
end
@@ -376,16 +414,6 @@ function _join_outer_dict(dsl, dsr, ranges, onleft, onright, oncols_left, oncols
376414

377415
end
378416

379-
# For each element `x` of `ldata`, report whether `_fl(x)` occurs among the values
# `_fr(y)` for `y` in `rdata`. Returns a `Vector{Bool}` with one entry per element
# of `ldata`.
function _in_use_Set(ldata, rdata, _fl, _fr)
    lookup = Set(Base.Generator(_fr, rdata))
    found = Vector{Bool}(undef, length(ldata))
    # each iteration writes only its own slot of `found`; `lookup` is only read
    Threads.@threads for k in 1:length(found)
        found[k] = in(_fl(ldata[k]), lookup)
    end
    return found
end
387-
388-
389417
function _update!_dict(dsl, dsr, ranges, onleft, onright, right_cols, ::Val{T}; allowmissing = true, mode = :all, mapformats = [true, true], stable = false, alg = HeapSort) where T
390418
_fl = _date_valueidentity
391419
_fr = _date_valueidentity
@@ -424,3 +452,128 @@ function _update!_dict(dsl, dsr, ranges, onleft, onright, right_cols, ::Val{T};
424452
_modified(_attributes(dsl))
425453
true, dsl
426454
end
455+
456+
457+
# a new idea for joining without sorting
458+
# Hash-based membership check: returns one `Bool` per row of `dsl`, computed by
# `_in_use_Set_int` on the group ids that `_gather_groups` assigns to the paired key
# columns of `dsl` and `dsr`. An empty `dsl` yields `Bool[]` immediately.
#
# `onleft` / `onright` are the key column indices; `mapformats[1]` / `mapformats[2]`
# decide whether the formats of the left / right columns are applied. The type
# parameter `T` is not used in the body.
function _in_hash(dsl::AbstractDataset, dsr::AbstractDataset, ::Val{T}; onleft, onright, mapformats = [true, true]) where T
    if isempty(dsl)
        return Bool[]
    end

    # NOTE(review): the original comment here read "use Set when there is only one
    # column in `on`"; no such special case exists in this function.
    paired = Any[]
    for j in 1:length(onleft)
        lfmt = mapformats[1] ? getformat(dsl, onleft[j]) : identity
        rfmt = mapformats[2] ? getformat(dsr, onright[j]) : identity
        push!(paired, Cat2Vec(_columns(dsl)[onleft[j]], _columns(dsr)[onright[j]], lfmt, rfmt))
    end
    stacked = Dataset(paired, :auto, copycols = false)
    idtype = nrow(stacked) < typemax(Int32) ? Val(Int32) : Val(Int64)
    a = _gather_groups(stacked, :, idtype, stable = false, mapformats = false)

    # the first nrow(dsl) ids are passed as the left data, the rest as the right data
    ids = a[1]
    nleft = nrow(dsl)
    left_ids = view(ids, 1:nleft)
    right_ids = view(ids, nleft+1:length(ids))
    return _in_use_Set_int(left_ids, right_ids, 1, a[3])
end
482+
483+
# Build a bit table recording which integer keys occur in `v`.
#
# `f` is applied to every element of `v`; a non-missing result `k` sets slot
# `k - minval + 1`, and a `missing` result sets the extra last slot.
# The table therefore has `rangelen + 1` slots.
# Results of `f` must lie in `minval:(minval + rangelen - 1)` or be `missing`;
# the loop runs under `@inbounds`, so this is not checked.
#
# Returns `(gslots, minval, sz)` with `gslots` the bit table and `sz = rangelen + 1`.
function _create_Set_for_join_int(f, v, minval, rangelen)
    offset = 1 - minval
    sz = rangelen + 1
    gslots = falses(sz)
    @inbounds for i in eachindex(v)
        key = f(v[i])
        # slot `sz` is reserved for missing
        slotix = ismissing(key) ? sz : key + offset
        gslots[slotix] = true
    end
    return gslots, minval, sz
end
502+
503+
# Look up `fv` in a bit table laid out as by `_create_Set_for_join_int`.
#
# `gslots[k - minval + 1]` answers for the integer key `k`, and `gslots[sz]` answers
# for `missing`. Keys that fall outside slots `1:sz-1` are reported as absent.
# `f` is accepted for signature compatibility but is not applied: `fv` is used as is.
#
# Always returns a `Bool` (the out-of-range path used to return the integer `0`).
function _query_Set_for_join_int(f, fv, gslots, minval, sz)
    if ismissing(fv)
        # missing lives in the last slot; an empty table contains nothing
        return sz >= 1 && gslots[sz]
    end
    slotix = fv + (1 - minval)
    slotix in 1:sz-1 || return false
    return gslots[slotix]
end
518+
519+
# Fill `res[i]` with the result of `_query_Set_for_join_int` for `ldata[i]`, using a
# threaded loop over the indices of `res`. `res` is modified in place; nothing is
# returned.
function _in_use_Set_int_barrier!(res, ldata, gslots, minval, sz)
    Threads.@threads for row in 1:length(res)
        hit = _query_Set_for_join_int(identity, ldata[row], gslots, minval, sz)
        res[row] = hit
    end
    return nothing
end
524+
525+
# Membership test on integer keys: builds a bit table from `rdata` with
# `_create_Set_for_join_int` and queries it for every element of `ldata`.
# `minval` and `rangelen` are forwarded to the table constructor.
# Returns a `Vector{Bool}` with one entry per element of `ldata`.
function _in_use_Set_int(ldata, rdata, minval, rangelen)
    table = _create_Set_for_join_int(identity, rdata, minval, rangelen)
    found = Vector{Bool}(undef, length(ldata))
    # table is the tuple (gslots, minval, sz)
    _in_use_Set_int_barrier!(found, ldata, table[1], table[2], table[3])
    return found
end
531+
532+
# Group the positions of `v` by key (a counting-sort style pass).
#
# f        : function applied to each element of v
# v        : vector of Int with minimum `minval` and range length `rangelen`;
#            it must not contain missing
# reps     : reps[k] is how many times the k-th key (minval + k - 1) appears in v
# Val(T)   : element type of the returned position vector
#
# Returns `(gslots, minval, rangelen)`. `gslots` has `sum(reps)` entries; the
# positions `i` whose key is the k-th key are stored, in increasing `i`, in the slice
# `sum(reps[1:k-1]) + 1 : sum(reps[1:k])`.
function _create_dictionary_for_join_int(f, v, reps, minval, rangelen, ::Val{T}) where T
    offset = 1 - minval
    # cursor[k] is the number of slots already used up to the current write
    # position of key k; it starts at the exclusive cumulative sum of reps
    cursor = Vector{T}(undef, length(reps) + 1)
    cursor[1] = 0
    cumsum!(view(cursor, 2:length(cursor)), reps)
    gslots = zeros(T, cursor[end])
    @inbounds for i in eachindex(v)
        slotix = f(v[i]) + offset
        cursor[slotix] += 1
        gslots[cursor[slotix]] = i
    end
    return gslots, minval, rangelen
end
552+
# Return the range of slots that belong to key `v`, or the empty range `1:0` when
# `v` maps outside `1:sz` or its count in `reps` is zero.
# `v` must not be missing. `f` and `gslots` are not used in the body.
# `where` holds the exclusive cumulative counts and can be defined as:
#     where = Vector{T}(undef, length(reps)+1)
#     cumsum!(view(where, 2:length(where)), reps)
#     where[1] = 0
function _query_dictionary_for_join_int(f, v, gslots,reps, where, minval, sz)
    slot = v + (1 - minval)
    slot in 1:sz || return 1:0
    reps[slot] == 0 && return 1:0
    lo = where[slot] + 1
    hi = where[slot + 1]
    return lo:hi
end
567+
568+
569+
570+
571+
# Set-based membership test: entry `i` of the result is `true` when `_fl(ldata[i])`
# is contained in the `Set` of values `_fr(y)` for `y` in `rdata`.
# Returns a `Vector{Bool}` of length `length(ldata)`.
function _in_use_Set(ldata, rdata, _fl, _fr)
    right_keys = Set(Base.Generator(_fr, rdata))
    nleft = length(ldata)
    hits = Vector{Bool}(undef, nleft)
    # iterations are independent: each one writes a single slot of `hits`
    Threads.@threads for pos in 1:length(hits)
        key = _fl(ldata[pos])
        hits[pos] = key in right_keys
    end
    return hits
end

0 commit comments

Comments
 (0)