|
"""
    _fill_index_compare!(x, r)

Store each index of `r` at its own position in `x`, i.e. `x[k] = k` for every
`k` in `r`. Positions of `x` outside `r` are left untouched. Returns `nothing`.
"""
function _fill_index_compare!(x, r)
    @simd for pos in r
        x[pos] = pos
    end
    return nothing
end
"""
    _compare(dsl, dsr, ::Val{T}; onleft, onright, cols_left, cols_right, kwargs...) where T

Compare column `cols_left[j]` of `dsl` against column `cols_right[j]` of `dsr`
for every `j`, and return a data set holding one result column per pair, named
`"<left name>=><right name>"`.

Rows are paired in one of two ways:

- `onleft === nothing`: rows are paired by position. The result has
  `max(nrow(dsl), nrow(dsr))` rows; the shorter side is padded with `missing`
  observation ids (the id columns are created via `_missings(T, n)`).
- otherwise: rows are paired by `outerjoin` of `dsl[!, onleft]` and
  `dsr[!, onright]` with `obs_id = true`; the key columns of the join are kept
  in the result ahead of the comparison columns.

A comparison entry is `missing` whenever either side has no matching
observation, otherwise `eq(fl(left_value), fr(right_value))`.

# Keywords
- `onleft`, `onright`: key columns for the join, or `onleft = nothing` for
  positional pairing.
- `cols_left`, `cols_right`: integer column indices of the columns to compare.
- `mapformats`: a single value, or a two-element vector `[left, right]`; when an
  entry is `true` the column format (from `getformat`) is applied to that
  side's values before calling `eq`.
- `eq`: the two-argument comparison function (default `isequal`).
- `obs_id_name`: prefix for the `<name>_left` / `<name>_right` id columns.
- `drop_obs_id`: when `true`, the id columns are removed from the result.
- `check`, `on_mapformats`, `stable`, `alg`, `accelerate`, `method`,
  `multiple_match`, `multiple_match_name`: forwarded to `outerjoin`
  (`on_mapformats` is passed as its `mapformats`).
- `threads`: forwarded to `outerjoin` and to the comparison kernel.

# Throws
- `ArgumentError` if `mapformats` is a vector whose length is not two.
"""
function _compare(dsl, dsr, ::Val{T};
                  onleft, onright, cols_left, cols_right,
                  check = true, mapformats = false, on_mapformats = [true, true],
                  stable = false, alg = HeapSort, accelerate = false, method = :sort,
                  threads = true, eq = isequal, obs_id_name = :obs_id,
                  multiple_match = false, multiple_match_name = :multiple,
                  drop_obs_id = true) where T
    names_left = names(dsl)[cols_left]
    names_right = names(dsr)[cols_right]
    if !(mapformats isa AbstractVector)
        mapformats = repeat([mapformats], 2)
    else
        length(mapformats) !== 2 && throw(ArgumentError("`mapformats` must be a Bool or a vector of Bool with size two"))
    end

    # Names of the observation-id columns; built once and reused below.
    left_id_name = Symbol(obs_id_name, "_left")
    right_id_name = Symbol(obs_id_name, "_right")

    if onleft === nothing
        # Positional pairing: row i of dsl is compared with row i of dsr; the
        # shorter data set is padded with missing observation ids.
        n_dsl = nrow(dsl)
        n_dsr = nrow(dsr)
        total_length = max(n_dsl, n_dsr)
        obs_id_left = _missings(T, total_length)
        obs_id_right = _missings(T, total_length)
        _fill_index_compare!(obs_id_left, 1:n_dsl)
        _fill_index_compare!(obs_id_right, 1:n_dsr)
        res = Dataset(x1=obs_id_left, x2=obs_id_right, copycols = false)
        rename!(res, :x1=>left_id_name, :x2=>right_id_name)
    else
        # Key-based pairing: the outer join records, for each output row, which
        # row of each input it came from (or missing when unmatched).
        res = outerjoin(dsl[!, onleft], dsr[!, onright], on = onleft .=> onright, check = check, mapformats = on_mapformats, stable = stable, alg = alg, accelerate = accelerate, method = method, threads = threads, obs_id = true, obs_id_name = obs_id_name, multiple_match = multiple_match, multiple_match_name = multiple_match_name)
        total_length = nrow(res)
        obs_cols = index(res)[[left_id_name, right_id_name]]
        obs_id_left = _columns(res)[obs_cols[1]]
        obs_id_right = _columns(res)[obs_cols[2]]
    end

    for j in eachindex(cols_left)
        left_col = cols_left[j]
        right_col = cols_right[j]
        # Apply the column format only on the sides where it was requested.
        fl = mapformats[1] ? getformat(dsl, left_col) : identity
        fr = mapformats[2] ? getformat(dsr, right_col) : identity
        _res = allocatecol(Bool, total_length)
        _compare_barrier_function!(_res, _columns(dsl)[left_col], _columns(dsr)[right_col], fl, fr, eq, obs_id_left, obs_id_right, threads)

        # Append the result column directly to the underlying storage and
        # register its name in the index.
        push!(_columns(res), _res)
        push!(index(res), Symbol(names_left[j] * "=>" * names_right[j]))
    end
    if drop_obs_id
        select!(res, Not([left_id_name, right_id_name]))
    end
    return res
end
| 53 | + |
| 54 | + |
"""
    _compare_barrier_function!(_res, xl, xr, fl, fr, eq_fun, obs_id_left, obs_id_right, threads)

Fill `_res` in place and return it. For each position `i` of `_res`:

- if `obs_id_left[i]` or `obs_id_right[i]` is `missing`, `_res[i]` is `missing`;
- otherwise `_res[i] = eq_fun(fl(xl[obs_id_left[i]]), fr(xr[obs_id_right[i]]))`.

`threads` is handed to the `@_threadsfor` macro that drives the loop.
"""
function _compare_barrier_function!(_res, xl, xr, fl, fr, eq_fun, obs_id_left, obs_id_right, threads)
    @_threadsfor threads for i in 1:length(_res)
        left_row = obs_id_left[i]
        if ismissing(left_row)
            _res[i] = missing
        else
            # The right id is only looked up once the left one is known to exist.
            right_row = obs_id_right[i]
            if ismissing(right_row)
                _res[i] = missing
            else
                _res[i] = eq_fun(fl(xl[left_row]), fr(xr[right_row]))
            end
        end
    end
    return _res
end
0 commit comments