@@ -10,9 +10,8 @@ use cairo_lang_semantic::{
10
10
} ;
11
11
use cairo_lang_syntax:: node:: TypedStablePtr ;
12
12
use cairo_lang_syntax:: node:: ast:: ExprPtr ;
13
- use cairo_lang_utils:: iterators:: zip_eq3;
14
13
use cairo_lang_utils:: ordered_hash_map:: OrderedHashMap ;
15
- use itertools:: Itertools ;
14
+ use itertools:: { Itertools , zip_eq } ;
16
15
use num_bigint:: BigInt ;
17
16
use salsa:: Database ;
18
17
@@ -162,6 +161,16 @@ pub fn create_node_for_patterns<'db>(
162
161
}
163
162
}
164
163
164
+ #[ derive( Clone , Default ) ]
165
+ struct VariantInfo < ' a , ' db > {
166
+ // The list of the indices of the patterns that match this variant.
167
+ filter : FilteredPatterns ,
168
+ // The list of the inner patterns.
169
+ // For example, a pattern `A(B(x))` will add the (inner) pattern `B(x)` to the vector at the
170
+ // index of the variant `A`.
171
+ inner_patterns : Vec < PatternOption < ' a , ' db > > ,
172
+ }
173
+
165
174
/// Creates an [EnumMatch] node for the given `input_var` and `patterns`.
166
175
fn create_node_for_enum < ' db > (
167
176
params : CreateNodeParams < ' db , ' _ , ' _ > ,
@@ -172,34 +181,27 @@ fn create_node_for_enum<'db>(
172
181
let CreateNodeParams { ctx, graph, patterns, build_node_callback, location } = params;
173
182
let concrete_variants = ctx. db . concrete_enum_variants ( concrete_enum_id) . unwrap ( ) ;
174
183
175
- // Maps variant index to the list of the indices of the patterns that match it.
176
- let mut variant_to_pattern_indices = vec ! [ FilteredPatterns :: default ( ) ; concrete_variants. len( ) ] ;
177
-
178
- // Maps variant index to the list of the inner patterns.
179
- // For example, a pattern `A(B(x))` will add the (inner) pattern `B(x)` to the vector at the
180
- // index of the variant `A`.
181
- let mut variant_to_inner_patterns: Vec < Vec < PatternOption < ' _ , ' db > > > =
182
- vec ! [ vec![ ] ; concrete_variants. len( ) ] ;
184
+ // Maps variant index to the [VariantInfo].
185
+ let mut variants = vec ! [ VariantInfo :: default ( ) ; concrete_variants. len( ) ] ;
183
186
184
187
for ( idx, pattern) in patterns. iter ( ) . enumerate ( ) {
185
188
match pattern {
186
189
Some ( semantic:: Pattern :: EnumVariant ( PatternEnumVariant {
187
190
variant,
188
- inner_pattern,
191
+ inner_pattern : inner_pattern_id ,
189
192
..
190
193
} ) ) => {
191
- variant_to_pattern_indices[ variant. idx ] . add ( idx) ;
192
- variant_to_inner_patterns[ variant. idx ]
193
- . push ( inner_pattern. map ( |inner_pattern| get_pattern ( ctx, inner_pattern) ) ) ;
194
+ let inner_pattern =
195
+ inner_pattern_id. map ( |inner_pattern| get_pattern ( ctx, inner_pattern) ) ;
196
+ variants[ variant. idx ] . filter . add ( idx) ;
197
+ variants[ variant. idx ] . inner_patterns . push ( inner_pattern) ;
194
198
}
195
199
Some ( semantic:: Pattern :: Otherwise ( ..) ) | None => {
196
- // Add `idx` to all the variants.
197
- for pattern_indices in variant_to_pattern_indices. iter_mut ( ) {
198
- pattern_indices. add ( idx) ;
199
- }
200
- // Add the `_` pattern (represented by `None`) to all the variants.
201
- for inner_patterns in variant_to_inner_patterns. iter_mut ( ) {
202
- inner_patterns. push ( None ) ;
200
+ for variant_info in variants. iter_mut ( ) {
201
+ // Add `idx` to all the variants.
202
+ variant_info. filter . add ( idx) ;
203
+ // Add the `_` pattern (represented by `None`) to all the variants.
204
+ variant_info. inner_patterns . push ( None ) ;
203
205
}
204
206
}
205
207
Some ( semantic:: Pattern :: Variable ( ..) ) => unreachable ! ( ) ,
@@ -221,30 +223,29 @@ fn create_node_for_enum<'db>(
221
223
}
222
224
223
225
// Create a node in the graph for each variant.
224
- let variants =
225
- zip_eq3 ( concrete_variants, variant_to_pattern_indices, variant_to_inner_patterns)
226
- . map ( |( concrete_variant, pattern_indices, inner_patterns) | {
227
- let inner_var = graph
228
- . new_var ( wrap_in_snapshots ( ctx. db , concrete_variant. ty , n_snapshots) , location) ;
229
- let node = create_node_for_patterns (
230
- CreateNodeParams {
231
- ctx,
232
- graph,
233
- patterns : & inner_patterns,
234
- build_node_callback : & mut |graph, pattern_indices_inner, path| {
235
- build_node_callback (
236
- graph,
237
- pattern_indices_inner. lift ( & pattern_indices) ,
238
- format ! ( "{}({path})" , concrete_variant. id. name( ctx. db) ) ,
239
- )
240
- } ,
241
- location,
226
+ let variants = zip_eq ( concrete_variants, variants)
227
+ . map ( |( concrete_variant, variant_info) | {
228
+ let inner_var = graph
229
+ . new_var ( wrap_in_snapshots ( ctx. db , concrete_variant. ty , n_snapshots) , location) ;
230
+ let node = create_node_for_patterns (
231
+ CreateNodeParams {
232
+ ctx,
233
+ graph,
234
+ patterns : & variant_info. inner_patterns ,
235
+ build_node_callback : & mut |graph, pattern_indices_inner, path| {
236
+ build_node_callback (
237
+ graph,
238
+ pattern_indices_inner. lift ( & variant_info. filter ) ,
239
+ format ! ( "{}({path})" , concrete_variant. id. name( ctx. db) ) ,
240
+ )
242
241
} ,
243
- inner_var,
244
- ) ;
245
- ( concrete_variant, node, inner_var)
246
- } )
247
- . collect_vec ( ) ;
242
+ location,
243
+ } ,
244
+ inner_var,
245
+ ) ;
246
+ ( concrete_variant, node, inner_var)
247
+ } )
248
+ . collect_vec ( ) ;
248
249
249
250
// Optimization: If all the variants lead to the same node, and the inner variables are not
250
251
// used, there is no need to do the match.
0 commit comments