@@ -10,7 +10,7 @@ use crate::{
10
10
cad_op:: { Group , OpArg , OpKclValue , Operation } ,
11
11
kcl_value:: FunctionSource ,
12
12
memory,
13
- types:: RuntimeType ,
13
+ types:: { DisplayType , RuntimeType , UnitSubsts } ,
14
14
} ,
15
15
parsing:: ast:: types:: { CallExpressionKw , DefaultParamVal , FunctionExpression , Node , Program , Type } ,
16
16
std:: StdFn ,
@@ -215,7 +215,7 @@ impl Node<CallExpressionKw> {
215
215
216
216
// Clone the function so that we can use a mutable reference to
217
217
// exec_state.
218
- let func: KclValue = fn_name. get_result ( exec_state, ctx) . await ?. clone ( ) ;
218
+ let func = fn_name. get_result ( exec_state, ctx) . await ?. clone ( ) ;
219
219
220
220
let Some ( fn_src) = func. as_function ( ) else {
221
221
return Err ( KclError :: new_semantic ( KclErrorDetails :: new (
@@ -261,7 +261,7 @@ impl Node<CallExpressionKw> {
261
261
let args = Args :: new ( fn_args, unlabeled, callsite, exec_state, ctx. clone ( ) ) ;
262
262
263
263
let return_value = fn_src
264
- . call_kw ( Some ( fn_name. to_string ( ) ) , exec_state, ctx, args, callsite)
264
+ . call ( Some ( fn_name. to_string ( ) ) , exec_state, ctx, args, callsite)
265
265
. await
266
266
. map_err ( |e| {
267
267
// Add the call expression to the source ranges.
@@ -291,7 +291,7 @@ impl Node<CallExpressionKw> {
291
291
}
292
292
293
293
impl FunctionDefinition < ' _ > {
294
- pub async fn call_kw (
294
+ pub async fn call (
295
295
& self ,
296
296
fn_name : Option < String > ,
297
297
exec_state : & mut ExecState ,
@@ -324,7 +324,7 @@ impl FunctionDefinition<'_> {
324
324
) ;
325
325
}
326
326
327
- let args = type_check_params_kw ( fn_name. as_deref ( ) , self , args, exec_state) ?;
327
+ let ( args, unit_substs ) = type_check_params ( fn_name. as_deref ( ) , self , args, exec_state) ?;
328
328
329
329
// Don't early return until the stack frame is popped!
330
330
self . body . prep_mem ( exec_state) ;
@@ -367,11 +367,13 @@ impl FunctionDefinition<'_> {
367
367
None
368
368
} ;
369
369
370
+ exec_state. mut_unit_stack ( ) . push ( unit_substs) ;
370
371
let mut result = match & self . body {
371
372
FunctionBody :: Rust ( f) => f ( exec_state, args) . await . map ( Some ) ,
372
373
FunctionBody :: Kcl ( f, _) => {
373
- if let Err ( e) = assign_args_to_params_kw ( self , args, exec_state) {
374
+ if let Err ( e) = assign_args_to_params ( self , args, exec_state) {
374
375
exec_state. mut_stack ( ) . pop_env ( ) ;
376
+ exec_state. mut_unit_stack ( ) . pop ( ) ;
375
377
return Err ( e) ;
376
378
}
377
379
@@ -386,6 +388,7 @@ impl FunctionDefinition<'_> {
386
388
} ;
387
389
388
390
exec_state. mut_stack ( ) . pop_env ( ) ;
391
+ exec_state. mut_unit_stack ( ) . pop ( ) ;
389
392
390
393
if let Some ( mut op) = op {
391
394
op. set_std_lib_call_is_error ( result. is_err ( ) ) ;
@@ -405,7 +408,21 @@ impl FunctionDefinition<'_> {
405
408
update_memory_for_tags_of_geometry ( result, exec_state) ?;
406
409
}
407
410
408
- coerce_result_type ( result, self , exec_state)
411
+ let ret_ty = self
412
+ . return_type
413
+ . as_ref ( )
414
+ . map ( |ret_ty| {
415
+ let mut ty = RuntimeType :: from_parsed ( ret_ty. inner . clone ( ) , exec_state, ret_ty. as_source_range ( ) )
416
+ . map_err ( |e| KclError :: new_semantic ( e. into ( ) ) ) ?;
417
+ ty. subst_units ( unit_substs) ;
418
+ Ok :: < _ , KclError > ( ty)
419
+ } )
420
+ . transpose ( ) ?;
421
+ if let Some ( ret_ty) = & ret_ty {
422
+ coerce_result_type ( result, ret_ty, exec_state)
423
+ } else {
424
+ result
425
+ }
409
426
}
410
427
411
428
// Postcondition: result.is_some() if function is not in the standard library.
@@ -427,7 +444,7 @@ impl FunctionBody<'_> {
427
444
}
428
445
429
446
impl FunctionSource {
430
- pub async fn call_kw (
447
+ pub async fn call (
431
448
& self ,
432
449
fn_name : Option < String > ,
433
450
exec_state : & mut ExecState ,
@@ -436,7 +453,7 @@ impl FunctionSource {
436
453
callsite : SourceRange ,
437
454
) -> Result < Option < KclValue > , KclError > {
438
455
let def: FunctionDefinition = self . into ( ) ;
439
- def. call_kw ( fn_name, exec_state, ctx, args, callsite) . await
456
+ def. call ( fn_name, exec_state, ctx, args, callsite) . await
440
457
}
441
458
}
442
459
@@ -551,7 +568,12 @@ fn update_memory_for_tags_of_geometry(result: &mut KclValue, exec_state: &mut Ex
551
568
Ok ( ( ) )
552
569
}
553
570
554
- fn type_err_str ( expected : & Type , found : & KclValue , source_range : & SourceRange , exec_state : & mut ExecState ) -> String {
571
+ fn type_err_str (
572
+ expected : & impl DisplayType ,
573
+ found : & KclValue ,
574
+ source_range : & SourceRange ,
575
+ exec_state : & mut ExecState ,
576
+ ) -> String {
555
577
fn strip_backticks ( s : & str ) -> & str {
556
578
let mut result = s;
557
579
if s. starts_with ( '`' ) {
@@ -563,8 +585,8 @@ fn type_err_str(expected: &Type, found: &KclValue, source_range: &SourceRange, e
563
585
result
564
586
}
565
587
566
- let expected_human = expected. human_friendly_type ( ) ;
567
- let expected_ty = expected. to_string ( ) ;
588
+ let expected_human = expected. human_friendly_string ( ) ;
589
+ let expected_ty = expected. src_string ( ) ;
568
590
let expected_str =
569
591
if expected_human == expected_ty || expected_human == format ! ( "a value with type `{expected_ty}`" ) {
570
592
format ! ( "a value with type `{expected_ty}`" )
@@ -581,20 +603,20 @@ fn type_err_str(expected: &Type, found: &KclValue, source_range: &SourceRange, e
581
603
582
604
let mut result = format ! ( "{expected_str}, but found {found_str}." ) ;
583
605
584
- if found. is_unknown_number ( ) {
606
+ if found. is_unknown_number ( exec_state ) {
585
607
exec_state. clear_units_warnings ( source_range) ;
586
608
result. push_str ( "\n The found value is a number but has incomplete units information. You can probably fix this error by specifying the units using type ascription, e.g., `len: mm` or `(a * b): deg`." ) ;
587
609
}
588
610
589
611
result
590
612
}
591
613
592
- fn type_check_params_kw (
614
+ fn type_check_params (
593
615
fn_name : Option < & str > ,
594
616
fn_def : & FunctionDefinition < ' _ > ,
595
617
mut args : Args < Sugary > ,
596
618
exec_state : & mut ExecState ,
597
- ) -> Result < Args < Desugared > , KclError > {
619
+ ) -> Result < ( Args < Desugared > , UnitSubsts ) , KclError > {
598
620
let mut result = Args :: new_no_args ( args. source_range , args. ctx ) ;
599
621
600
622
// If it's possible the input arg was meant to be labelled and we probably don't want to use
@@ -609,6 +631,10 @@ fn type_check_params_kw(
609
631
args. labeled . insert ( label. unwrap ( ) , arg) ;
610
632
}
611
633
634
+ // Collect substitutions for `number(Length)` or `number(Angle)`. See docs on execution::types::UnitSubsts
635
+ // for details.
636
+ let mut unit_substs = UnitSubsts :: default ( ) ;
637
+
612
638
// Apply the `a == a: a` shorthand by desugaring unlabeled args into labeled ones.
613
639
let ( labeled_unlabeled, unlabeled_unlabeled) = args. unlabeled . into_iter ( ) . partition ( |( l, _) | {
614
640
if let Some ( l) = l
@@ -657,20 +683,28 @@ fn type_check_params_kw(
657
683
} else if args. unlabeled . len ( ) == 1 {
658
684
let mut arg = args. unlabeled . pop ( ) . unwrap ( ) . 1 ;
659
685
if let Some ( ty) = ty {
660
- let rty = RuntimeType :: from_parsed ( ty. clone ( ) , exec_state, arg. source_range )
686
+ let mut rty = RuntimeType :: from_parsed ( ty. clone ( ) , exec_state, arg. source_range )
661
687
. map_err ( |e| KclError :: new_semantic ( e. into ( ) ) ) ?;
662
- arg. value = arg. value . coerce ( & rty, true , exec_state) . map_err ( |_| {
663
- KclError :: new_argument ( KclErrorDetails :: new (
664
- format ! (
665
- "The input argument of {} requires {}" ,
666
- fn_name
667
- . map( |n| format!( "`{n}`" ) )
668
- . unwrap_or_else( || "this function" . to_owned( ) ) ,
669
- type_err_str( ty, & arg. value, & arg. source_range, exec_state) ,
670
- ) ,
671
- vec ! [ arg. source_range] ,
672
- ) )
673
- } ) ?;
688
+ rty. subst_units ( unit_substs) ;
689
+
690
+ let ( value, substs) = arg
691
+ . value
692
+ . coerce_and_find_unit_substs ( & rty, true , exec_state)
693
+ . map_err ( |_| {
694
+ KclError :: new_argument ( KclErrorDetails :: new (
695
+ format ! (
696
+ "The input argument of {} requires {}" ,
697
+ fn_name
698
+ . map( |n| format!( "`{n}`" ) )
699
+ . unwrap_or_else( || "this function" . to_owned( ) ) ,
700
+ type_err_str( ty, & arg. value, & arg. source_range, exec_state) ,
701
+ ) ,
702
+ vec ! [ arg. source_range] ,
703
+ ) )
704
+ } ) ?;
705
+
706
+ arg. value = value;
707
+ unit_substs = unit_substs. or ( substs) ;
674
708
}
675
709
result. unlabeled = vec ! [ ( None , arg) ]
676
710
} else {
@@ -748,11 +782,13 @@ fn type_check_params_kw(
748
782
// For optional args, passing None should be the same as not passing an arg.
749
783
if !( def. is_some ( ) && matches ! ( arg. value, KclValue :: KclNone { .. } ) ) {
750
784
if let Some ( ty) = ty {
751
- let rty = RuntimeType :: from_parsed ( ty. clone ( ) , exec_state, arg. source_range )
785
+ let mut rty = RuntimeType :: from_parsed ( ty. clone ( ) , exec_state, arg. source_range )
752
786
. map_err ( |e| KclError :: new_semantic ( e. into ( ) ) ) ?;
753
- arg. value = arg
787
+ rty. subst_units ( unit_substs) ;
788
+
789
+ let ( value, substs) = arg
754
790
. value
755
- . coerce (
791
+ . coerce_and_find_unit_substs (
756
792
& rty,
757
793
true ,
758
794
exec_state,
@@ -771,6 +807,9 @@ fn type_check_params_kw(
771
807
vec ! [ arg. source_range] ,
772
808
) )
773
809
} ) ?;
810
+
811
+ arg. value = value;
812
+ unit_substs = unit_substs. or ( substs) ;
774
813
}
775
814
result. labeled . insert ( label, arg) ;
776
815
}
@@ -789,10 +828,10 @@ fn type_check_params_kw(
789
828
}
790
829
}
791
830
792
- Ok ( result)
831
+ Ok ( ( result, unit_substs ) )
793
832
}
794
833
795
- fn assign_args_to_params_kw (
834
+ fn assign_args_to_params (
796
835
fn_def : & FunctionDefinition < ' _ > ,
797
836
args : Args < Desugared > ,
798
837
exec_state : & mut ExecState ,
@@ -848,26 +887,20 @@ fn assign_args_to_params_kw(
848
887
849
888
fn coerce_result_type (
850
889
result : Result < Option < KclValue > , KclError > ,
851
- fn_def : & FunctionDefinition < ' _ > ,
890
+ return_ty : & RuntimeType ,
852
891
exec_state : & mut ExecState ,
853
892
) -> Result < Option < KclValue > , KclError > {
854
893
if let Ok ( Some ( val) ) = result {
855
- if let Some ( ret_ty) = & fn_def. return_type {
856
- let ty = RuntimeType :: from_parsed ( ret_ty. inner . clone ( ) , exec_state, ret_ty. as_source_range ( ) )
857
- . map_err ( |e| KclError :: new_semantic ( e. into ( ) ) ) ?;
858
- let val = val. coerce ( & ty, true , exec_state) . map_err ( |_| {
859
- KclError :: new_type ( KclErrorDetails :: new (
860
- format ! (
861
- "This function requires its result to be {}" ,
862
- type_err_str( ret_ty, & val, & ( & val) . into( ) , exec_state)
863
- ) ,
864
- ret_ty. as_source_ranges ( ) ,
865
- ) )
866
- } ) ?;
867
- Ok ( Some ( val) )
868
- } else {
869
- Ok ( Some ( val) )
870
- }
894
+ let val = val. coerce ( return_ty, true , exec_state) . map_err ( |_| {
895
+ KclError :: new_type ( KclErrorDetails :: new (
896
+ format ! (
897
+ "This function requires its result to be {}" ,
898
+ type_err_str( return_ty, & val, & ( & val) . into( ) , exec_state)
899
+ ) ,
900
+ val. into ( ) ,
901
+ ) )
902
+ } ) ?;
903
+ Ok ( Some ( val) )
871
904
} else {
872
905
result
873
906
}
@@ -1016,9 +1049,8 @@ mod test {
1016
1049
pipe_value : None ,
1017
1050
_status : std:: marker:: PhantomData ,
1018
1051
} ;
1019
-
1020
- let actual = assign_args_to_params_kw ( & ( & func_src) . into ( ) , args, & mut exec_state)
1021
- . map ( |_| exec_state. mod_local . stack ) ;
1052
+ let actual =
1053
+ assign_args_to_params ( & ( & func_src) . into ( ) , args, & mut exec_state) . map ( |_| exec_state. mod_local . stack ) ;
1022
1054
assert_eq ! (
1023
1055
actual, expected,
1024
1056
"failed test '{test_name}':\n got {actual:?}\n but expected\n {expected:?}"
0 commit comments