Skip to content

Commit 3366043

Browse files
committed
fix: Dict::get
1 parent f6145d0 commit 3366043

File tree

9 files changed

+45
-23
lines changed

9 files changed

+45
-23
lines changed

crates/erg_compiler/context/initialize/classes.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2359,14 +2359,13 @@ impl Context {
23592359
)));
23602360
dict_.register_builtin_const(FUNC_AS_RECORD, Visibility::BUILTIN_PUBLIC, None, as_record);
23612361
let Def = type_q(TY_DEFAULT);
2362+
let K = type_q(TY_K);
2363+
let V = type_q(TY_V);
23622364
let get_t = no_var_fn_met(
2363-
dict_t.clone(),
2364-
vec![kw(KW_KEY, T.clone())],
2365+
dict! { K.clone() => V.clone() }.into(),
2366+
vec![kw(KW_KEY, K.clone())],
23652367
vec![kw_default(KW_DEFAULT, Def.clone(), NoneType)],
2366-
or(
2367-
proj_call(D.clone(), FUNDAMENTAL_GETITEM, vec![ty_tp(T.clone())]),
2368-
Def,
2369-
),
2368+
or(V.clone(), Def),
23702369
)
23712370
.quantify();
23722371
dict_.register_py_builtin(FUNC_GET, get_t, Some(FUNC_GET), 9);

crates/erg_compiler/context/initialize/const_func.rs

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -239,21 +239,21 @@ pub(crate) fn sub_vdict_get<'d>(
239239
) -> Option<&'d ValueObj> {
240240
let mut matches = vec![];
241241
for (k, v) in dict.iter() {
242-
match (key, k) {
243-
(ValueObj::Type(idx), ValueObj::Type(kt))
244-
if ctx.subtype_of(&idx.typ().lower_bounded(), &kt.typ().lower_bounded()) =>
242+
if key == k {
243+
return Some(v);
244+
}
245+
match (ctx.convert_value_into_type(key.clone()), ctx.convert_value_into_type(k.clone())) {
246+
(Ok(idx), Ok(kt))
247+
if ctx.subtype_of(&idx.lower_bounded(), &kt.lower_bounded()) /*|| dict.len() == 1*/ =>
245248
{
246249
matches.push((idx, kt, v));
247250
}
248-
(idx, k) if idx == k => {
249-
return Some(v);
250-
}
251251
_ => {}
252252
}
253253
}
254254
for (idx, kt, v) in matches.into_iter() {
255255
let list = UndoableLinkedList::new();
256-
match ctx.undoable_sub_unify(idx.typ(), kt.typ(), &(), &list, None) {
256+
match ctx.undoable_sub_unify(&idx, &kt, &(), &list, None) {
257257
Ok(_) => {
258258
return Some(v);
259259
}
@@ -272,21 +272,24 @@ pub(crate) fn sub_tpdict_get<'d>(
272272
) -> Option<&'d TyParam> {
273273
let mut matches = vec![];
274274
for (k, v) in dict.iter() {
275-
match (<&Type>::try_from(key), <&Type>::try_from(k)) {
275+
if key == k {
276+
return Some(v);
277+
}
278+
match (
279+
ctx.convert_tp_into_type(key.clone()),
280+
ctx.convert_tp_into_type(k.clone()),
281+
) {
276282
(Ok(idx), Ok(kt))
277283
if ctx.subtype_of(&idx.lower_bounded(), &kt.lower_bounded()) || dict.len() == 1 =>
278284
{
279285
matches.push((idx, kt, v));
280286
}
281-
(_, _) if key == k => {
282-
return Some(v);
283-
}
284287
_ => {}
285288
}
286289
}
287290
for (idx, kt, v) in matches.into_iter() {
288291
let list = UndoableLinkedList::new();
289-
match ctx.undoable_sub_unify(idx, kt, &(), &list, None) {
292+
match ctx.undoable_sub_unify(&idx, &kt, &(), &list, None) {
290293
Ok(_) => {
291294
return Some(v);
292295
}

crates/erg_compiler/context/initialize/funcs.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,8 +1077,9 @@ impl Context {
10771077
);
10781078
let E = mono_q(TY_E, subtypeof(mono(EQ)));
10791079
let E2 = mono_q(TY_E, subtypeof(mono(IRREGULAR_EQ)));
1080-
let op_t = bin_op(E.clone(), E, Bool).quantify()
1081-
& bin_op(E2.clone(), E2.clone(), E2.proj(OUTPUT)).quantify();
1080+
let op_t = (bin_op(E.clone(), E, Bool).quantify()
1081+
& bin_op(E2.clone(), E2.clone(), E2.proj(OUTPUT)).quantify())
1082+
.with_default_intersec_index(0);
10821083
self.register_builtin_py_impl(
10831084
OP_EQ,
10841085
op_t.clone(),

crates/erg_compiler/lower.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,17 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
12211221
Some(guard(namespace, target, to))
12221222
}
12231223
TokenKind::Symbol if &op.content[..] == "isinstance" => {
1224-
let to = self.module.context.expr_to_type(rhs.clone()).ok()?;
1224+
// isinstance(x, (T, U)) => x: T or U
1225+
let to = if let ast::Expr::Tuple(ast::Tuple::Normal(tys)) = rhs {
1226+
tys.elems.pos_args.iter().fold(Type::Never, |acc, ex| {
1227+
let Ok(ty) = self.module.context.expr_to_type(ex.expr.clone()) else {
1228+
return acc;
1229+
};
1230+
self.module.context.union(&acc, &ty)
1231+
})
1232+
} else {
1233+
self.module.context.expr_to_type(rhs.clone()).ok()?
1234+
};
12251235
Some(guard(namespace, target, to))
12261236
}
12271237
TokenKind::IsOp | TokenKind::DblEq => {

tests/should_err/mut_dict.er

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@ dict = !d
44
dict.insert! "b", 2
55
_ = dict.get("a") == "a" # ERR
66
_ = dict.get("b") == "a" # ERR
7-
_ = dict.get("c") # ERR
7+
_ = dict.get("c") # OK
8+
_ = dict["b"] # OK
9+
_ = dict["c"] # ERR

tests/should_err/refinement.er

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,6 @@ _: {I: Int | (I < 5 or I != 3) and I != 4} = 4 # ERR
1313

1414
check _: {S: Str | S.replace("abc", "") == ""} = None
1515
check "abcd" # ERR
16+
17+
dic as Dict({{111}: {222}}) = {111: 222}
18+
_ = dic[333] # ERR

tests/should_ok/dict.er

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ for! {"a": 1, "b": 2}.values(), i =>
66
dic = { "a": 1, "b": 2 }
77
assert dic.concat({ "c": 3 }) == { "a": 1, "b": 2, "c": 3 }
88
assert dic.diff({ "a": 1 }) == { "b": 2 }
9+
assert dic.get("a"+"b", 3) == 3
910
rec = dic.as_record()
1011
assert rec.a == 1 and rec.b == 2

tests/should_ok/refinement.er

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@ _: {I: Int | I < 5 or I != 3 and I != 4} = 4
66

77
check _: {S: Str | S.replace("abc", "") == ""} = None
88
check "abc"
9+
10+
dic as Dict({{111}: {222}}) = {111: 222}
11+
_: {222} = dic[111]

tests/test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ fn exec_recursive_fn_err() -> Result<(), ()> {
766766

767767
#[test]
768768
fn exec_refinement_err() -> Result<(), ()> {
769-
expect_failure("tests/should_err/refinement.er", 0, 9)
769+
expect_failure("tests/should_err/refinement.er", 0, 10)
770770
}
771771

772772
#[test]

0 commit comments

Comments
 (0)