|
409 | 409 | function _gather_groups(ds, cols, ::Val{T}; mapformats = false, stable = true, threads = true) where T
|
410 | 410 | colidx = index(ds)[cols]
|
411 | 411 | _max_level = nrow(ds)
|
| 412 | + |
| 413 | + |
| 414 | + if nrow(ds) > 2^23 && !stable && 5<length(colidx)<16 # the result is stable anyway |
| 415 | + if !mapformats || all(==(identity), getformat.(Ref(ds), colidx)) |
| 416 | + return _gather_groups_hugeds_multicols(ds, cols, Val(T); threads = threads) |
| 417 | + end |
| 418 | + end |
| 419 | + |
| 420 | + |
412 | 421 | prev_max_group = UInt(1)
|
413 | 422 | prev_groups = ones(T, nrow(ds))
|
414 | 423 | groups = T[]
|
@@ -510,6 +519,64 @@ function _find_groups_with_more_than_one_observation_barrier!(res, groups, seen_
|
510 | 519 | nothing
|
511 | 520 | end
|
512 | 521 |
|
| 522 | +### Special path for huge ds and multiple cols - trade off between compilation and performance |
| 523 | +# table columns are passed as a tuple of vectors to ensure type specialization - From DataFrames.jl |
# Compare rows `r1` and `r2` of the same table, column by column, using `isequal`
# (so `missing` matches `missing` and `NaN` matches `NaN`).
# The table is a tuple of column vectors; the one-column method ends the recursion.
function isequal_row(cols::Tuple{AbstractVector}, r1::Int, r2::Int)
    col = cols[1]
    return isequal(col[r1], col[r2])
end
function isequal_row(cols::Tuple{Vararg{AbstractVector}}, r1::Int, r2::Int)
    col = cols[1]
    # stop at the first column that differs, otherwise recurse on the remaining columns
    isequal(col[r1], col[r2]) || return false
    return isequal_row(Base.tail(cols), r1, r2)
end
| 528 | + |
# Compare row `r1` of table `cols1` with row `r2` of table `cols2`, column by column,
# using `isequal`. Both tables are tuples of column vectors walked in lockstep;
# the one-column method ends the recursion.
function isequal_row(cols1::Tuple{AbstractVector}, r1::Int,
                     cols2::Tuple{AbstractVector}, r2::Int)
    return isequal(cols1[1][r1], cols2[1][r2])
end
function isequal_row(cols1::Tuple{Vararg{AbstractVector}}, r1::Int,
                     cols2::Tuple{Vararg{AbstractVector}}, r2::Int)
    # stop at the first column pair that differs, otherwise recurse on the tails
    isequal(cols1[1][r1], cols2[1][r2]) || return false
    return isequal_row(Base.tail(cols1), r1, Base.tail(cols2), r2)
end
| 535 | + |
| 536 | + |
# Return `DataAPI.refarray(x)` when `DataAPI.refpool(x)` reports a reference pool,
# otherwise return `x` unchanged.
# `=== nothing` is used instead of `== nothing`: it is an identity check that cannot be
# overloaded and that the compiler can resolve from the type alone.
_grabrefs(x) = DataAPI.refpool(x) === nothing ? x : DataAPI.refarray(x)
# Group the rows of `ds` by `cols` through the hash-table path.
# Row hashes are computed with `byrow(ds, hash, cols)`, the selected columns are collected
# (through `_grabrefs`) into a tuple, and both are handed to
# `create_dict_hugeds_multicols`, whose result is returned as is.
function _gather_groups_hugeds_multicols(ds, cols, ::Val{T}; threads = true) where T
    selected = index(ds)[cols]
    row_hashes = byrow(ds, hash, cols, threads = threads)
    # a tuple (not a vector) of columns, so the row comparison can specialize on column types
    keycols = ntuple(length(selected)) do k
        _grabrefs(_columns(ds)[selected[k]])
    end
    return create_dict_hugeds_multicols(keycols, row_hashes, Val(T))
end
| 544 | + |
# Assign a group number to every row using an open-addressing hash table with linear
# probing.
#
# Arguments
#   colvals  - tuple of column vectors; rows are compared with `isequal_row`
#   rhashes  - one precomputed hash per row
#   Val(T)   - element type used for both the slot table and the group vector
#
# Returns `(groups, gslots, ngroups)`:
#   groups   - `Vector{T}`, group number of each row, numbered 1, 2, ... in order of first
#              appearance
#   gslots   - `Vector{T}` slot table; a non-zero entry is the index of the first row of
#              the group stored in that slot, zero marks an empty slot
#   ngroups  - number of distinct groups found (`Int`)
function create_dict_hugeds_multicols(colvals, rhashes, ::Val{T}) where T
    nrows = length(rhashes)
    # table size: the smallest power of two with at least 1.25 * nrows slots (minimum 16),
    # so that `hash & (sz - 1)` can be used in place of a modulo
    sz = nextpow(2, max(1 + ((5 * nrows) >> 2), 16))
    @assert 4 * sz >= 5 * nrows
    szm1 = sz - 1
    gslots = zeros(T, sz)
    groups = Vector{T}(undef, nrows)
    ngroups = 0
    @inbounds for i in eachindex(rhashes)
        h = rhashes[i]
        # home slot of the row (1-based)
        slotix = h & szm1 + 1
        probe = 0
        while true
            g_row = gslots[slotix]
            if g_row == 0
                # unoccupied slot: the current row starts a new group
                gslots[slotix] = i
                ngroups += 1
                groups[i] = ngroups
                break
            elseif h == rhashes[g_row] && isequal_row(colvals, i, Int(g_row))
                # same hash and same values: the row joins the group of `g_row`
                groups[i] = groups[g_row]
                break
            end
            # miss: move to the next slot, wrapping around at the end of the table
            slotix = slotix & szm1 + 1
            probe += 1
            @assert probe < sz
        end
    end
    return groups, gslots, ngroups
end
| 578 | + |
| 579 | + |
513 | 580 | function _gather_groups_old_version(ds, cols, ::Val{T}; mapformats = false) where T
|
514 | 581 | colidx = index(ds)[cols]
|
515 | 582 | _max_level = nrow(ds)
|
|
0 commit comments