21
21
overload ,
22
22
)
23
23
24
+ import xarray as xr
24
25
from loguru import logger
25
26
from typing_extensions import NotRequired , TypedDict , Unpack , assert_never , get_args
26
27
55
56
InstalledPackage ,
56
57
ValidationDetail ,
57
58
ValidationSummary ,
59
+ WarningEntry ,
58
60
)
59
61
60
62
from ._prediction_pipeline import create_prediction_pipeline
@@ -510,7 +512,7 @@ def load_description_and_test(
510
512
511
513
enable_determinism (determinism , weight_formats = weight_formats )
512
514
for w in weight_formats :
513
- _test_model_inference (rd , w , devices , ** deprecated )
515
+ _test_model_inference (rd , w , devices , stop_early = stop_early , ** deprecated )
514
516
if stop_early and rd .validation_summary .status == "failed" :
515
517
break
516
518
@@ -587,14 +589,16 @@ def _test_model_inference(
587
589
model : Union [v0_4 .ModelDescr , v0_5 .ModelDescr ],
588
590
weight_format : SupportedWeightsFormat ,
589
591
devices : Optional [Sequence [str ]],
592
+ stop_early : bool ,
590
593
** deprecated : Unpack [DeprecatedKwargs ],
591
594
) -> None :
592
595
test_name = f"Reproduce test outputs from test inputs ({ weight_format } )"
593
596
logger .debug ("starting '{}'" , test_name )
594
- errors : List [ErrorEntry ] = []
597
+ error_entries : List [ErrorEntry ] = []
598
+ warning_entries : List [WarningEntry ] = []
595
599
596
600
def add_error_entry (msg : str , with_traceback : bool = False ):
597
- errors .append (
601
+ error_entries .append (
598
602
ErrorEntry (
599
603
loc = ("weights" , weight_format ),
600
604
msg = msg ,
@@ -603,6 +607,15 @@ def add_error_entry(msg: str, with_traceback: bool = False):
603
607
)
604
608
)
605
609
610
+ def add_warning_entry (msg : str ):
611
+ warning_entries .append (
612
+ WarningEntry (
613
+ loc = ("weights" , weight_format ),
614
+ msg = msg ,
615
+ type = "bioimageio.core" ,
616
+ )
617
+ )
618
+
606
619
try :
607
620
inputs = get_test_inputs (model )
608
621
expected = get_test_outputs (model )
@@ -622,34 +635,58 @@ def add_error_entry(msg: str, with_traceback: bool = False):
622
635
actual = results .members .get (m )
623
636
if actual is None :
624
637
add_error_entry ("Output tensors for test case may not be None" )
625
- break
638
+ if stop_early :
639
+ break
640
+ else :
641
+ continue
626
642
627
643
rtol , atol , mismatched_tol = _get_tolerance (
628
644
model , wf = weight_format , m = m , ** deprecated
629
645
)
630
- mismatched = ( abs_diff := abs ( actual - expected )) > atol + rtol * abs (
631
- expected
632
- )
646
+ rtol_value = rtol * abs (expected )
647
+ abs_diff = abs ( actual - expected )
648
+ mismatched = abs_diff > atol + rtol_value
633
649
mismatched_elements = mismatched .sum ().item ()
634
- if mismatched_elements / expected .size > mismatched_tol / 1e6 :
635
- r_max_idx = (r_diff := (abs_diff / (abs (expected ) + 1e-6 ))).argmax ()
636
- r_max = r_diff [r_max_idx ].item ()
637
- r_actual = actual [r_max_idx ].item ()
638
- r_expected = expected [r_max_idx ].item ()
639
- a_max_idx = abs_diff .argmax ()
640
- a_max = abs_diff [a_max_idx ].item ()
641
- a_actual = actual [a_max_idx ].item ()
642
- a_expected = expected [a_max_idx ].item ()
643
- add_error_entry (
644
- f"Output '{ m } ' disagrees with { mismatched_elements } of"
645
- + f" { expected .size } expected values."
646
- + f"\n Max relative difference: { r_max :.2e} "
647
- + rf" (= \|{ r_actual :.2e} - { r_expected :.2e} \|/\|{ r_expected :.2e} + 1e-6\|)"
648
- + f" at { r_max_idx } "
649
- + f"\n Max absolute difference: { a_max :.2e} "
650
- + rf" (= \|{ a_actual :.7e} - { a_expected :.7e} \|) at { a_max_idx } "
651
- )
652
- break
650
+ if not mismatched_elements :
651
+ continue
652
+
653
+ mismatched_ppm = mismatched_elements / expected .size * 1e6
654
+ abs_diff [~ mismatched ] = 0 # ignore non-mismatched elements
655
+
656
+ r_max_idx = (r_diff := (abs_diff / (abs (expected ) + 1e-6 ))).argmax ()
657
+ r_max = r_diff [r_max_idx ].item ()
658
+ r_actual = actual [r_max_idx ].item ()
659
+ r_expected = expected [r_max_idx ].item ()
660
+
661
+ # Calculate the max absolute difference with the relative tolerance subtracted
662
+ abs_diff_wo_rtol : xr .DataArray = xr .ufuncs .maximum (
663
+ (abs_diff - rtol_value ).data , 0
664
+ )
665
+ a_max_idx = {
666
+ AxisId (k ): int (v ) for k , v in abs_diff_wo_rtol .argmax ().items ()
667
+ }
668
+
669
+ a_max = abs_diff [a_max_idx ].item ()
670
+ a_actual = actual [a_max_idx ].item ()
671
+ a_expected = expected [a_max_idx ].item ()
672
+
673
+ msg = (
674
+ f"Output '{ m } ' disagrees with { mismatched_elements } of"
675
+ + f" { expected .size } expected values"
676
+ + f" ({ mismatched_ppm :.1f} ppm)."
677
+ + f"\n Max relative difference: { r_max :.2e} "
678
+ + rf" (= \|{ r_actual :.2e} - { r_expected :.2e} \|/\|{ r_expected :.2e} + 1e-6\|)"
679
+ + f" at { r_max_idx } "
680
+ + f"\n Max absolute difference not accounted for by relative tolerance: { a_max :.2e} "
681
+ + rf" (= \|{ a_actual :.7e} - { a_expected :.7e} \|) at { a_max_idx } "
682
+ )
683
+ if mismatched_ppm > mismatched_tol :
684
+ add_error_entry (msg )
685
+ if stop_early :
686
+ break
687
+ else :
688
+ add_warning_entry (msg )
689
+
653
690
except Exception as e :
654
691
if get_validation_context ().raise_errors :
655
692
raise e
@@ -660,9 +697,10 @@ def add_error_entry(msg: str, with_traceback: bool = False):
660
697
ValidationDetail (
661
698
name = test_name ,
662
699
loc = ("weights" , weight_format ),
663
- status = "failed" if errors else "passed" ,
700
+ status = "failed" if error_entries else "passed" ,
664
701
recommended_env = get_conda_env (entry = dict (model .weights )[weight_format ]),
665
- errors = errors ,
702
+ errors = error_entries ,
703
+ warnings = warning_entries ,
666
704
)
667
705
)
668
706
0 commit comments