@@ -524,6 +524,17 @@ private Struct getRangeType(RangeExpr re) {
524524 result instanceof RangeToInclusiveStruct
525525}
526526
527+ private predicate bodyReturns ( Expr body , Expr e ) {
528+ exists ( ReturnExpr re , Callable c |
529+ e = re .getExpr ( ) and
530+ c = re .getEnclosingCallable ( )
531+ |
532+ body = c .( Function ) .getBody ( )
533+ or
534+ body = c .( ClosureExpr ) .getBody ( )
535+ )
536+ }
537+
527538/**
528539 * Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
529540 * of `n2` at `prefix2` and type information should propagate in both directions
@@ -540,9 +551,11 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
540551 let .getInitializer ( ) = n2
541552 )
542553 or
543- n1 = n2 .( IfExpr ) .getABranch ( )
544- or
545- n1 = n2 .( MatchExpr ) .getAnArm ( ) .getExpr ( )
554+ n1 =
555+ any ( MatchExpr me |
556+ n1 = me .getAnArm ( ) .getExpr ( ) and
557+ me .getNumberOfArms ( ) = 1
558+ )
546559 or
547560 exists ( LetExpr let |
548561 n1 = let .getScrutinee ( ) and
@@ -573,6 +586,9 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
573586 n1 = n2 .( MacroExpr ) .getMacroCall ( ) .getMacroCallExpansion ( )
574587 or
575588 n1 = n2 .( MacroPat ) .getMacroCall ( ) .getMacroCallExpansion ( )
589+ or
590+ bodyReturns ( n1 , n2 ) and
591+ strictcount ( Expr e | bodyReturns ( n1 , e ) ) = 1
576592 )
577593 or
578594 (
@@ -606,8 +622,12 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
606622 )
607623 )
608624 or
609- // an array list expression (`[1, 2, 3]`) has the type of the first (any) element
610- n1 .( ArrayListExpr ) .getExpr ( _) = n2 and
625+ // an array list expression (`[1, 2, 3]`) has the type of the element
626+ n1 =
627+ any ( ArrayListExpr ale |
628+ ale .getAnExpr ( ) = n2 and
629+ ale .getNumberOfExprs ( ) = 1
630+ ) and
611631 prefix1 = TypePath:: singleton ( TArrayTypeParameter ( ) ) and
612632 prefix2 .isEmpty ( )
613633 or
@@ -635,6 +655,61 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
635655 prefix2 .isEmpty ( )
636656}
637657
658+ /**
659+ * Holds if `child` is a child of `parent`, and the Rust compiler applies [least
660+ * upper bound (LUB) coercion](1) to infer the type of `parent` from the type of
661+ * `child`.
662+ *
663+ * In this case, we want type information to only flow from `child` to `parent`,
664+ * to avoid (a) either having to model LUB coercions, or (b) risk combinatorial
665+ * explosion in inferred types.
666+ *
667+ * [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound
668+ */
669+ private predicate lubCoercion ( AstNode parent , AstNode child , TypePath prefix ) {
670+ child = parent .( IfExpr ) .getABranch ( ) and
671+ prefix .isEmpty ( )
672+ or
673+ parent =
674+ any ( MatchExpr me |
675+ child = me .getAnArm ( ) .getExpr ( ) and
676+ me .getNumberOfArms ( ) > 1
677+ ) and
678+ prefix .isEmpty ( )
679+ or
680+ parent =
681+ any ( ArrayListExpr ale |
682+ child = ale .getAnExpr ( ) and
683+ ale .getNumberOfExprs ( ) > 1
684+ ) and
685+ prefix = TypePath:: singleton ( TArrayTypeParameter ( ) )
686+ or
687+ bodyReturns ( parent , child ) and
688+ strictcount ( Expr e | bodyReturns ( parent , e ) ) > 1 and
689+ prefix .isEmpty ( )
690+ }
691+
692+ /**
693+ * Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
694+ * of `n2` at `prefix2`, but type information should only propagate from `n1` to
695+ * `n2`.
696+ */
697+ private predicate typeEqualityNonSymmetric (
698+ AstNode n1 , TypePath prefix1 , AstNode n2 , TypePath prefix2
699+ ) {
700+ lubCoercion ( n2 , n1 , prefix2 ) and
701+ prefix1 .isEmpty ( )
702+ or
703+ exists ( AstNode mid , TypePath prefixMid , TypePath suffix |
704+ typeEquality ( n1 , prefixMid , mid , prefix2 ) or
705+ typeEquality ( mid , prefix2 , n1 , prefixMid )
706+ |
707+ lubCoercion ( mid , n2 , suffix ) and
708+ not lubCoercion ( mid , n1 , _) and
709+ prefix1 = prefixMid .append ( suffix )
710+ )
711+ }
712+
638713pragma [ nomagic]
639714private Type inferTypeEquality ( AstNode n , TypePath path ) {
640715 exists ( TypePath prefix1 , AstNode n2 , TypePath prefix2 , TypePath suffix |
@@ -644,6 +719,8 @@ private Type inferTypeEquality(AstNode n, TypePath path) {
644719 typeEquality ( n , prefix1 , n2 , prefix2 )
645720 or
646721 typeEquality ( n2 , prefix2 , n , prefix1 )
722+ or
723+ typeEqualityNonSymmetric ( n2 , prefix2 , n , prefix1 )
647724 )
648725}
649726
0 commit comments