Skip to content

Commit dd64ace

Browse files
committed
[breaking] compare is more powerful + extra features for joins + fix #53
* compare supports comparing two data sets with different number of rows. * compare supports comparing two data sets using key columns to match rows * all joins have `obs_id`/`obs_id_name` option * left/inner/outer have `multiple_match`/`multiple_match_name` option * `multiple_match` indicates rows in left data set which are repeated due to multiple match in the right data set.
1 parent 15ba771 commit dd64ace

File tree

12 files changed

+693
-145
lines changed

12 files changed

+693
-145
lines changed

docs/src/man/gallery.md

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,15 @@ This gallery contains some random questions about data manipulation that we foun
77
* [Tally across columns with variable condition](https://stackoverflow.com/questions/70501316/tally-across-columns-with-variable-condition-in-r) : I am trying to tally across columns of a data frame with values that exceed a corresponding limit variable.
88

99
```julia
10-
julia> ds
10+
julia> ds = Dataset([[1.66077, -1.05298, -0.499206, 2.47123, 2.45914, 1.14014],
11+
[0.75, 0.75, 0.75, 0.75, 0.75, 0.75],
12+
[0.709184, -2.53609, 0.0130659, -0.587867, 0.55786, 1.60398],
13+
[0.333, 0.333, 0.333, 0.333, 0.333, 0.333],
14+
[1.47438, 2.01485, 2.49006, 1.80345, 0.569928, 1.58403],
15+
[1, 1, 1, 1, 1, 1],
16+
[2.02678, 1.51587, 1.70535, 2.51628, 1.909, 0.794765],
17+
[1.25, 1.25, 1.25, 1.25, 1.25, 1.25]],
18+
["a", "a_lim", "b", "b_lim", "c", "c_lim", "d", "d_lim"])
1119
6×8 Dataset
1220
Row │ a a_lim b b_lim c c_lim d d_lim
1321
│ identity identity identity identity identity identity identity identity
@@ -22,7 +30,7 @@ julia> ds
2230

2331
julia> using Chain
2432
julia> @chain ds begin
25-
compare(_[!, r"lim"], _[!, Not(r"lim")], on = 1:4 .=> 1:4), eq = isless)
33+
compare(_[!, r"lim"], _[!, Not(r"lim")], cols = 1:4 .=> 1:4, eq = isless)
2634
byrow(count)
2735
end
2836
6-element Vector{Int32}:

docs/src/man/joins.md

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,3 +531,70 @@ julia> update(main, transaction, on = [:group, :id],
531531
6 │ G2 1 2.1 missing
532532
7 │ G2 2 0.0 2
533533
```
534+
535+
## `compare`
536+
537+
The `compare` function compares two data sets. When the columns which needed to be compared are specified via the `cols` keyword argument, `compare` compares the corresponding values in each row by calling `eq` on the actual or formatted values. By default, `compare` compares two values via the `isequal` function, however, users may pass any function (that returns `true`/`false`) via the `eq` keyword arguments. When the number of rows of two data sets are not matched, `compare` fills the output data set with `missing`. Users can pass key columns to perform comparing matched pairs of observations. The key columns can be passed via the `on` keyword argument. The `compare` function uses `outerjoin` to find the corresponding matches, this also means, the `compare` function can accept the arguments of `outerjoin`.
538+
539+
> To pass the `mapformats` keyword argument to `outerjoin` in `compare`, use the `on_mapformats` keyword argument, since the `mapformats` keyword argument in `compare` refers to how observations should be compared; based on actual values or formatted values.
540+
541+
By default, the output data set contains observations id when users pass the `on` keyword argument. When an observation exists in only one of the passed data sets, the observation id will be missing for the other one.
542+
543+
### Examples
544+
545+
```jldoctest
546+
julia> old = Dataset(Insurance_Id=[1,2,3,5],Business_Id=[10,20,30,50],
547+
Amount=[100,200,300,missing],
548+
Account_Id=["x1","x10","x5","x5"])
549+
4×4 Dataset
550+
Row │ Insurance_Id Business_Id Amount Account_Id
551+
│ identity identity identity identity
552+
│ Int64? Int64? Int64? String?
553+
─────┼─────────────────────────────────────────────────
554+
1 │ 1 10 100 x1
555+
2 │ 2 20 200 x10
556+
3 │ 3 30 300 x5
557+
4 │ 5 50 missing x5
558+
559+
julia> new = Dataset(Ins_Id=[1,3,2,4,3,2],
560+
B_Id=[10,40,30,40,30,20],
561+
AMT=[100,200,missing,-500,350,700],
562+
Ac_Id=["x1","x1","x10","x10","x7","x5"])
563+
6×4 Dataset
564+
Row │ Ins_Id B_Id AMT Ac_Id
565+
│ identity identity identity identity
566+
│ Int64? Int64? Int64? String?
567+
─────┼────────────────────────────────────────
568+
1 │ 1 10 100 x1
569+
2 │ 3 40 200 x1
570+
3 │ 2 30 missing x10
571+
4 │ 4 40 -500 x10
572+
5 │ 3 30 350 x7
573+
6 │ 2 20 700 x5
574+
575+
julia> eq_fun(x::Number, y::Number) = abs(x - y) <= 50
576+
eq_fun (generic function with 3 methods)
577+
578+
julia> eq_fun(x::AbstractString, y::AbstractString) = isequal(x,y)
579+
eq_fun (generic function with 2 methods)
580+
581+
julia> eq_fun(x,y) = missing
582+
eq_fun (generic function with 3 methods)
583+
584+
julia> compare(old, new,
585+
on = [1=>1,2=>2],
586+
cols = [:Amount=>:AMT, :Account_Id=>:Ac_Id],
587+
eq = eq_fun)
588+
7×6 Dataset
589+
Row │ Insurance_Id Business_Id obs_id_left obs_id_right Amount=>AMT Account_Id=>Ac_Id
590+
│ identity identity identity identity identity identity
591+
│ Int64? Int64? Int32? Int32? Bool? Bool?
592+
─────┼──────────────────────────────────────────────────────────────────────────────────────
593+
1 │ 1 10 1 1 true true
594+
2 │ 2 20 2 6 false false
595+
3 │ 3 30 3 5 true false
596+
4 │ 5 50 4 missing missing missing
597+
5 │ 2 30 missing 3 missing missing
598+
6 │ 3 40 missing 2 missing missing
599+
7 │ 4 40 missing 4 missing missing
600+
```

src/InMemoryDatasets.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ include("join/join.jl")
174174
include("join/join_dict.jl")
175175
include("join/closejoin.jl")
176176
include("join/update.jl")
177+
include("join/compare.jl")
177178
include("join/main.jl")
178179

179180
include("abstractdataset/iteration.jl")

src/dataset/other.jl

Lines changed: 0 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -889,107 +889,6 @@ function dropmissing!(ds::Dataset,
889889
ds
890890
end
891891

892-
function _compare_barrier_function_threaded!(_res, xl, xr, fl, fr, eq_fun)
893-
Threads.@threads for i in 1:length(xl)
894-
_res[i] = eq_fun(fl(xl[i]), fr(xr[i]))
895-
end
896-
_res
897-
end
898-
function _compare_barrier_function!(_res, xl, xr, fl, fr, eq_fun)
899-
for i in 1:length(xl)
900-
_res[i] = eq_fun(fl(xl[i]), fr(xr[i]))
901-
end
902-
_res
903-
end
904-
905-
906-
"""
907-
compare(ds1::AbstractDataset, ds2::AbstractDataset; [on = nothing, eq = isequal, mapformats = false, threads = true])
908-
909-
Compare values of two data sets column by column. It returns a boolean data set which is the result of calling `eq` on each value of
910-
corresponding columns. The `on` keyword can be used to specifiy the pair of columns which is needed to be compared. The `mapformats` keyword
911-
controls whether the actual values or the formatted values should be compared.
912-
913-
```julia
914-
julia> ds1 = Dataset(x = 1:9, y = 9:-1:1);
915-
julia> ds2 = Dataset(x = 1:9, y2 = 9:-1:1, y3 = 1:9);
916-
julia> compare(ds1, ds2, on = [:x=>:x, :y=>:y2])
917-
9×2 Dataset
918-
Row │ x=>x y=>y2
919-
│ identity identity
920-
│ Bool? Bool?
921-
─────┼────────────────────
922-
1 │ true true
923-
2 │ true true
924-
3 │ true true
925-
4 │ true true
926-
5 │ true true
927-
6 │ true true
928-
7 │ true true
929-
8 │ true true
930-
9 │ true true
931-
932-
julia> compare(ds1, ds2, on = [:x=>:x, :y=>:y3])
933-
9×2 Dataset
934-
Row │ x=>x y=>y3
935-
│ identity identity
936-
│ Bool? Bool?
937-
─────┼────────────────────
938-
1 │ true false
939-
2 │ true false
940-
3 │ true false
941-
4 │ true false
942-
5 │ true true
943-
6 │ true false
944-
7 │ true false
945-
8 │ true false
946-
9 │ true false
947-
948-
```
949-
"""
950-
function compare(ds1::AbstractDataset, ds2::AbstractDataset; on = nothing, eq = isequal, mapformats = false, threads = true)
951-
if !(mapformats isa AbstractVector)
952-
mapformats = repeat([mapformats], 2)
953-
else
954-
length(mapformats) !== 2 && throw(ArgumentError("`mapformats` must be a Bool or a vector of Bool with size two"))
955-
end
956-
if on === nothing
957-
left_col_idx = 1:ncol(ds1)
958-
right_col_idx = index(ds2)[names(ds1)]
959-
elseif typeof(on) <: AbstractVector{<:Union{AbstractString, Symbol}}
960-
left_col_idx = index(ds1)[on]
961-
right_col_idx = index(ds2)[names(ds1)[left_col_idx]]
962-
elseif (typeof(on) <: AbstractVector{<:Pair{<:ColumnIndex, <:ColumnIndex}}) || (typeof(on) <: AbstractVector{<:Pair{<:AbstractString, <:AbstractString}})
963-
left_col_idx = index(ds1)[map(x->x.first, on)]
964-
right_col_idx = index(ds2)[map(x->x.second, on)]
965-
else
966-
throw(ArgumentError("`on` keyword must be a vector of column names or a vector of pairs of column names"))
967-
end
968-
969-
nrow(ds1) != nrow(ds2) && throw(ArgumentError("the number of rows for both data sets should be the same"))
970-
res = Dataset()
971-
for j in 1:length(left_col_idx)
972-
_res = allocatecol(Union{Bool, Missing}, nrow(ds1))
973-
fl = identity
974-
if mapformats[1]
975-
fl = getformat(ds1, left_col_idx[j])
976-
end
977-
fr = identity
978-
if mapformats[2]
979-
fr = getformat(ds2, right_col_idx[j])
980-
end
981-
if threads
982-
_compare_barrier_function_threaded!(_res, _columns(ds1)[left_col_idx[j]], _columns(ds2)[right_col_idx[j]], fl, fr, eq)
983-
else
984-
_compare_barrier_function!(_res, _columns(ds1)[left_col_idx[j]], _columns(ds2)[right_col_idx[j]], fl, fr, eq)
985-
end
986-
push!(_columns(res), _res)
987-
push!(index(res), Symbol(names(ds1)[left_col_idx[j]]* "=>" * names(ds2)[right_col_idx[j]]))
988-
end
989-
res
990-
end
991-
992-
993892

994893
"""
995894
describe(ds::AbstractDataset; cols=:, threads = true)

src/join/closejoin.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ end
308308

309309

310310
# border = :nearest | :missing | :none
311-
function _join_closejoin(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, makeunique = false, border = :nearest, mapformats = [true, true], stable = false, alg = HeapSort, accelerate = false, direction = :backward, inplace = false, tol = nothing, allow_exact_match = true, op = nothing, method = :sort, threads = true) where T
311+
function _join_closejoin(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, makeunique = false, border = :nearest, mapformats = [true, true], stable = false, alg = HeapSort, accelerate = false, direction = :backward, inplace = false, tol = nothing, allow_exact_match = true, op = nothing, method = :sort, threads = true, obs_id = false, obs_id_name = :obs_id) where T
312312
isempty(dsl) && return copy(dsl)
313313
if !allow_exact_match
314314
#aem is the function to check allow_exact_match
@@ -380,7 +380,16 @@ function _join_closejoin(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, m
380380
push!(index(newds), new_var_name)
381381
setformat!(newds, index(newds)[new_var_name], getformat(dsr, _names(dsr)[right_cols[j]]))
382382
end
383-
383+
if obs_id
384+
obs_id_name1 = Symbol(obs_id_name, "_left")
385+
obs_id_name2 = Symbol(obs_id_name, "_right")
386+
obs_id_left = allocatecol(T, total_length)
387+
obs_id_right = allocatecol(T, total_length)
388+
obs_id_left .= 1:nrow(dsl)
389+
_fill_right_cols_table_close!(obs_id_right, idx, ranges, total_length, border, missing, direction; nn = direction == :nearest, rnn = view(_columns(dsr)[oncols_right[end]], idx), lnn = _columns(dsl)[oncols_left[end]], tol = tol, aem = aem, op = op, threads = threads)
390+
insertcols!(newds, ncol(newds)+1, obs_id_name1 => obs_id_left, unsupported_copy_cols = false)
391+
insertcols!(newds, ncol(newds)+1, obs_id_name2 => obs_id_right, unsupported_copy_cols = false)
392+
end
384393
if inplace
385394
_modified(_attributes(newds))
386395
end

src/join/compare.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
function _fill_index_compare!(x, r)
2+
@simd for i in r
3+
x[i] = i
4+
end
5+
end
6+
function _compare(dsl, dsr, ::Val{T}; onleft, onright, cols_left, cols_right, check = true, mapformats = false, on_mapformats = [true, true], stable = false, alg = HeapSort, accelerate = false, method = :sort, threads = true, eq = isequal, obs_id_name = :obs_id, multiple_match = false, multiple_match_name = :multiple, drop_obs_id = true) where T
7+
names_left = names(dsl)[cols_left]
8+
names_right = names(dsr)[cols_right]
9+
if !(mapformats isa AbstractVector)
10+
mapformats = repeat([mapformats], 2)
11+
else
12+
length(mapformats) !== 2 && throw(ArgumentError("`mapformats` must be a Bool or a vector of Bool with size two"))
13+
end
14+
15+
if onleft == nothing
16+
n_dsl = nrow(dsl)
17+
n_dsr = nrow(dsr)
18+
total_length = max(n_dsl, n_dsr)
19+
obs_id_left = _missings(T, total_length)
20+
obs_id_right = _missings(T, total_length)
21+
_fill_index_compare!(obs_id_left, 1:n_dsl)
22+
_fill_index_compare!(obs_id_right, 1:n_dsr)
23+
res = Dataset(x1=obs_id_left, x2=obs_id_right, copycols = false)
24+
rename!(res, :x1=>Symbol(obs_id_name, "_left"), :x2=>Symbol(obs_id_name, "_right"))
25+
else
26+
res = outerjoin(dsl[!, onleft], dsr[!, onright], on = onleft .=> onright, check = check, mapformats = on_mapformats, stable = stable, alg = alg, accelerate = accelerate, method = method, threads = threads, obs_id = true, obs_id_name = obs_id_name, multiple_match = multiple_match, multiple_match_name = multiple_match_name)
27+
total_length = nrow(res)
28+
obs_cols = index(res)[[Symbol(obs_id_name, "_left"), Symbol(obs_id_name, "_right")]]
29+
obs_id_left = _columns(res)[obs_cols[1]]
30+
obs_id_right = _columns(res)[obs_cols[2]]
31+
end
32+
_info_cols = ncol(res)
33+
for j in 1:length(cols_left)
34+
fl = identity
35+
if mapformats[1]
36+
fl = getformat(dsl, cols_left[j])
37+
end
38+
fr = identity
39+
if mapformats[2]
40+
fr = getformat(dsr, cols_right[j])
41+
end
42+
_res = allocatecol(Bool, total_length)
43+
_compare_barrier_function!(_res, _columns(dsl)[cols_left[j]], _columns(dsr)[cols_right[j]], fl, fr, eq, obs_id_left, obs_id_right, threads)
44+
45+
push!(_columns(res), _res)
46+
push!(index(res), Symbol(names(dsl)[cols_left[j]]* "=>" * names(dsr)[cols_right[j]]))
47+
end
48+
if drop_obs_id
49+
select!(res, Not([Symbol(obs_id_name, "_left"), Symbol(obs_id_name, "_right")]))
50+
end
51+
res
52+
end
53+
54+
55+
function _compare_barrier_function!(_res, xl, xr, fl, fr, eq_fun, obs_id_left, obs_id_right, threads)
56+
@_threadsfor threads for i in 1:length(_res)
57+
if ismissing(obs_id_left[i]) || ismissing(obs_id_right[i])
58+
_res[i] = missing
59+
else
60+
_res[i] = eq_fun(fl(xl[obs_id_left[i]]), fr(xr[obs_id_right[i]]))
61+
end
62+
end
63+
_res
64+
end

0 commit comments

Comments
 (0)