Skip to content

Commit 27ede53

Browse files
committed
fix: refinement type assert cast bug
1 parent 1762588 commit 27ede53

File tree

5 files changed

+51
-6
lines changed

5 files changed

+51
-6
lines changed

crates/erg_common/triple.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,14 @@ impl<T> Triple<T, T> {
141141
Triple::Ok(a) | Triple::Err(a) => Some(a),
142142
}
143143
}
144+
145+
pub fn merge_or(self, default: T) -> T {
146+
match self {
147+
Triple::None => default,
148+
Triple::Ok(ok) => ok,
149+
Triple::Err(err) => err,
150+
}
151+
}
144152
}
145153

146154
impl<T, E: std::error::Error> Triple<T, E> {

crates/erg_compiler/context/compare.rs

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,6 +1282,7 @@ impl Context {
12821282
/// union(Array(Int, 2), Array(Str, 3)) == Array(Int, 2) or Array(Int, 3)
12831283
/// union({ .a = Int }, { .a = Str }) == { .a = Int or Str }
12841284
/// union({ .a = Int }, { .a = Int; .b = Int }) == { .a = Int }
1285+
/// union((A and B) or C) == (A or C) and (B or C)
12851286
/// ```
12861287
pub(crate) fn union(&self, lhs: &Type, rhs: &Type) -> Type {
12871288
if lhs == rhs {
@@ -1345,6 +1346,16 @@ impl Context {
13451346
_ => self.simple_union(lhs, rhs),
13461347
},
13471348
(other, or @ Or(_, _)) | (or @ Or(_, _), other) => self.union_add(or, other),
1349+
// (A and B) or C ==> (A or C) and (B or C)
1350+
(and_t @ And(_, _), other) | (other, and_t @ And(_, _)) => {
1351+
let ands = and_t.ands();
1352+
let mut t = Type::Obj;
1353+
for branch in ands.iter() {
1354+
let union = self.union(branch, other);
1355+
t = and(t, union);
1356+
}
1357+
t
1358+
}
13481359
(t, Type::Never) | (Type::Never, t) => t.clone(),
13491360
// Array({1, 2}, 2), Array({3, 4}, 2) ==> Array({1, 2, 3, 4}, 2)
13501361
(
@@ -1497,12 +1508,6 @@ impl Context {
14971508
self.intersection(&fv.crack(), other)
14981509
}
14991510
(Refinement(l), Refinement(r)) => Type::Refinement(self.intersection_refinement(l, r)),
1500-
(other, Refinement(refine)) | (Refinement(refine), other) => {
1501-
let other = other.clone().into_refinement();
1502-
let intersec = self.intersection_refinement(&other, refine);
1503-
self.try_squash_refinement(intersec)
1504-
.unwrap_or_else(Type::Refinement)
1505-
}
15061511
(Structural(l), Structural(r)) => self.intersection(l, r).structuralize(),
15071512
(Guard(l), Guard(r)) => {
15081513
if l.namespace == r.namespace && l.target == r.target {
@@ -1527,6 +1532,26 @@ impl Context {
15271532
(other, and @ And(_, _)) | (and @ And(_, _), other) => {
15281533
self.intersection_add(and, other)
15291534
}
1535+
// (A or B) and C == (A and C) or (B and C)
1536+
(or_t @ Or(_, _), other) | (other, or_t @ Or(_, _)) => {
1537+
let ors = or_t.ors();
1538+
let mut t = Type::Never;
1539+
for branch in ors.iter() {
1540+
let isec = self.intersection(branch, other);
1541+
if branch.is_unbound_var() {
1542+
t = or(t, isec);
1543+
} else {
1544+
t = self.union(&t, &isec);
1545+
}
1546+
}
1547+
t
1548+
}
1549+
(other, Refinement(refine)) | (Refinement(refine), other) => {
1550+
let other = other.clone().into_refinement();
1551+
let intersec = self.intersection_refinement(&other, refine);
1552+
self.try_squash_refinement(intersec)
1553+
.unwrap_or_else(Type::Refinement)
1554+
}
15301555
// overloading
15311556
(l, r) if l.is_subr() && r.is_subr() => and(lhs.clone(), rhs.clone()),
15321557
_ => self.simple_intersection(lhs, rhs),

crates/erg_compiler/context/inquire.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3657,6 +3657,7 @@ impl Context {
36573657
/// ```erg
36583658
/// recover_typarams(Int, Nat) == Nat
36593659
/// recover_typarams(Array!(Int, _), Array(Nat, 2)) == Array!(Nat, 2)
3660+
/// recover_typarams(Str or NoneType, {"a", "b"}) == {"a", "b"}
36603661
/// ```
36613662
/// ```erg
36623663
/// # REVIEW: should be?

crates/erg_compiler/tests/infer.er

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,11 @@ c_new x, y = C.new x, y
3030
C = Class Int
3131
C.
3232
new x, y = Self x + y
33+
34+
val!() =
35+
for! [{ "a": "b" }], (pkg as {Str: Str}) =>
36+
x = pkg.get("a", "c")
37+
assert x in {"b"}
38+
val!::return x
39+
"d"
40+
val = val!()

crates/erg_compiler/tests/test.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ fn _test_infer_types() -> Result<(), ()> {
8787
let c_new_t = func2(add_r, r, c.clone()).quantify();
8888
module.context.assert_var_type("c_new", &c_new_t)?;
8989
module.context.assert_attr_type(&c, "new", &c_new_t)?;
90+
module
91+
.context
92+
.assert_var_type("val", &v_enum(set! { "b".into(), "d".into() }))?;
9093
Ok(())
9194
}
9295

0 commit comments

Comments
 (0)