Skip to content

Commit 99a7d32

Browse files
committed
In VNV, use typejoin rather than promote_type
1 parent bb83d93 commit 99a7d32

File tree

1 file changed

+34
-28
lines changed

1 file changed

+34
-28
lines changed

src/varnamedvector.jl

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,17 @@ julia> vnv = setindex!!(vnv, [0.0, 0.0, 0.0, 0.0], @varname(x));
6565
julia> vnv = setindex!!(vnv, reshape(1:6, (2,3)), @varname(y));
6666
6767
julia> vnv.vals
68-
10-element Vector{Float64}:
68+
10-element Vector{Real}:
6969
0.0
7070
0.0
7171
0.0
7272
0.0
73-
1.0
74-
2.0
75-
3.0
76-
4.0
77-
5.0
78-
6.0
73+
1
74+
2
75+
3
76+
4
77+
5
78+
6
7979
```
8080
8181
The `varnames`, `ranges`, and `varname_to_index` fields keep track of which value belongs to
@@ -94,17 +94,17 @@ value, rather than resizing `vnv.vals`, some elements in `vnv.vals` are marked a
9494
julia> vnv = update!!(vnv, [46.0, 48.0], @varname(x));
9595
9696
julia> vnv.vals
97-
10-element Vector{Float64}:
97+
10-element Vector{Real}:
9898
46.0
9999
48.0
100100
0.0
101101
0.0
102-
1.0
103-
2.0
104-
3.0
105-
4.0
106-
5.0
107-
6.0
102+
1
103+
2
104+
3
105+
4
106+
5
107+
6
108108
109109
julia> println(vnv.num_inactive);
110110
Dict(1 => 2)
@@ -116,20 +116,20 @@ like `setindex!` and `getindex!` rather than directly accessing `vnv.vals`.
116116
117117
```jldoctest varnamedvector-struct
118118
julia> vnv[@varname(x)]
119-
2-element Vector{Float64}:
119+
2-element Vector{Real}:
120120
46.0
121121
48.0
122122
123123
julia> getindex_internal(vnv, :)
124-
8-element Vector{Float64}:
124+
8-element Vector{Real}:
125125
46.0
126126
48.0
127-
1.0
128-
2.0
129-
3.0
130-
4.0
131-
5.0
132-
6.0
127+
1
128+
2
129+
3
130+
4
131+
5
132+
6
133133
```
134134
"""
135135
struct VarNamedVector{
@@ -864,9 +864,15 @@ function loosen_types!!(
864864
if KNew <: K && VNew <: V && TNew <: T
865865
return vnv
866866
else
867-
vn_type = promote_type(K, KNew)
868-
val_type = promote_type(V, VNew)
869-
transform_type = promote_type(T, TNew)
867+
# We could use promote_type here, instead of typejoin. However, that would e.g.
868+
# cause Ints to be converted to Float64s, since
869+
# promote_type(Int, Float64) == Float64, which can cause problems. See
870+
# https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188.
871+
# Base.promote_typejoin would be like typejoin, but creates Unions out of Nothing
872+
# and Missing, rather than falling back on Any. However, it's not exported.
873+
vn_type = typejoin(K, KNew)
874+
val_type = typejoin(V, VNew)
875+
transform_type = typejoin(T, TNew)
870876
# This function would work the same way if the first if statement a few lines above
871877
# was skipped, and we only checked for the below condition. However, the first one
872878
# is constant propagated away at compile time (at least on Julia v1.11.7), whereas
@@ -1171,20 +1177,20 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector)
11711177
# Determine `eltype` of `vals`.
11721178
T_left = eltype(left_vnv.vals)
11731179
T_right = eltype(right_vnv.vals)
1174-
T = promote_type(T_left, T_right)
1180+
T = typejoin(T_left, T_right)
11751181

11761182
# Determine `eltype` of `varnames`.
11771183
V_left = eltype(left_vnv.varnames)
11781184
V_right = eltype(right_vnv.varnames)
1179-
V = promote_type(V_left, V_right)
1185+
V = typejoin(V_left, V_right)
11801186
if !(V <: VarName)
11811187
V = VarName
11821188
end
11831189

11841190
# Determine `eltype` of `transforms`.
11851191
F_left = eltype(left_vnv.transforms)
11861192
F_right = eltype(right_vnv.transforms)
1187-
F = promote_type(F_left, F_right)
1193+
F = typejoin(F_left, F_right)
11881194

11891195
# Allocate.
11901196
varname_to_index = Dict{V,Int}()

0 commit comments

Comments
 (0)