Skip to content

Commit c29da13

Browse files
authored
Merge pull request #460 from bioimage-io/dev
bump spec
2 parents 2b33b26 + 2766274 commit c29da13

26 files changed

+551
-268
lines changed

.github/workflows/build.yaml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,6 @@ jobs:
6969
strategy:
7070
matrix:
7171
include:
72-
- python-version: '3.8'
73-
conda-env: py38
74-
spec: conda
75-
- python-version: '3.8'
76-
conda-env: py38
77-
spec: main
7872
- python-version: '3.9'
7973
conda-env: dev
8074
spec: conda
@@ -174,7 +168,7 @@ jobs:
174168
path: bioimageio_cache
175169
key: ${{matrix.run-expensive-tests && needs.populate-cache.outputs.cache-key || needs.populate-cache.outputs.cache-key-light}}
176170
- name: pytest
177-
run: pytest --disable-pytest-warnings
171+
run: pytest --cov bioimageio --cov-report xml --cov-append --capture no --disable-pytest-warnings
178172
env:
179173
BIOIMAGEIO_CACHE_PATH: bioimageio_cache
180174
RUN_EXPENSIVE_TESTS: ${{ matrix.run-expensive-tests && 'true' || 'false' }}

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,10 @@ may be controlled with the `LOGURU_LEVEL` environment variable.
364364

365365
## Changelog
366366

367+
### 0.9.0 (coming soon)
368+
369+
- update to [bioimageio.spec 0.5.4.3](https://github.com/bioimage-io/spec-bioimage-io/blob/main/changelog.md#bioimageiospec-0543)
370+
367371
### 0.8.0
368372

369373
- breaking: removed `decimals` argument from bioimageio CLI and `bioimageio.core.commands.test()`

bioimageio/core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from bioimageio.spec import (
6+
ValidationSummary,
67
build_description,
78
dump_description,
89
load_dataset_description,
@@ -112,4 +113,5 @@
112113
"test_model",
113114
"test_resource",
114115
"validate_format",
116+
"ValidationSummary",
115117
]

bioimageio/core/_resource_tests.py

Lines changed: 66 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
overload,
2222
)
2323

24+
import xarray as xr
2425
from loguru import logger
2526
from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args
2627

@@ -55,6 +56,7 @@
5556
InstalledPackage,
5657
ValidationDetail,
5758
ValidationSummary,
59+
WarningEntry,
5860
)
5961

6062
from ._prediction_pipeline import create_prediction_pipeline
@@ -510,7 +512,7 @@ def load_description_and_test(
510512

511513
enable_determinism(determinism, weight_formats=weight_formats)
512514
for w in weight_formats:
513-
_test_model_inference(rd, w, devices, **deprecated)
515+
_test_model_inference(rd, w, devices, stop_early=stop_early, **deprecated)
514516
if stop_early and rd.validation_summary.status == "failed":
515517
break
516518

@@ -587,14 +589,16 @@ def _test_model_inference(
587589
model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
588590
weight_format: SupportedWeightsFormat,
589591
devices: Optional[Sequence[str]],
592+
stop_early: bool,
590593
**deprecated: Unpack[DeprecatedKwargs],
591594
) -> None:
592595
test_name = f"Reproduce test outputs from test inputs ({weight_format})"
593596
logger.debug("starting '{}'", test_name)
594-
errors: List[ErrorEntry] = []
597+
error_entries: List[ErrorEntry] = []
598+
warning_entries: List[WarningEntry] = []
595599

596600
def add_error_entry(msg: str, with_traceback: bool = False):
597-
errors.append(
601+
error_entries.append(
598602
ErrorEntry(
599603
loc=("weights", weight_format),
600604
msg=msg,
@@ -603,6 +607,15 @@ def add_error_entry(msg: str, with_traceback: bool = False):
603607
)
604608
)
605609

610+
def add_warning_entry(msg: str):
611+
warning_entries.append(
612+
WarningEntry(
613+
loc=("weights", weight_format),
614+
msg=msg,
615+
type="bioimageio.core",
616+
)
617+
)
618+
606619
try:
607620
inputs = get_test_inputs(model)
608621
expected = get_test_outputs(model)
@@ -622,34 +635,58 @@ def add_error_entry(msg: str, with_traceback: bool = False):
622635
actual = results.members.get(m)
623636
if actual is None:
624637
add_error_entry("Output tensors for test case may not be None")
625-
break
638+
if stop_early:
639+
break
640+
else:
641+
continue
626642

627643
rtol, atol, mismatched_tol = _get_tolerance(
628644
model, wf=weight_format, m=m, **deprecated
629645
)
630-
mismatched = (abs_diff := abs(actual - expected)) > atol + rtol * abs(
631-
expected
632-
)
646+
rtol_value = rtol * abs(expected)
647+
abs_diff = abs(actual - expected)
648+
mismatched = abs_diff > atol + rtol_value
633649
mismatched_elements = mismatched.sum().item()
634-
if mismatched_elements / expected.size > mismatched_tol / 1e6:
635-
r_max_idx = (r_diff := (abs_diff / (abs(expected) + 1e-6))).argmax()
636-
r_max = r_diff[r_max_idx].item()
637-
r_actual = actual[r_max_idx].item()
638-
r_expected = expected[r_max_idx].item()
639-
a_max_idx = abs_diff.argmax()
640-
a_max = abs_diff[a_max_idx].item()
641-
a_actual = actual[a_max_idx].item()
642-
a_expected = expected[a_max_idx].item()
643-
add_error_entry(
644-
f"Output '{m}' disagrees with {mismatched_elements} of"
645-
+ f" {expected.size} expected values."
646-
+ f"\n Max relative difference: {r_max:.2e}"
647-
+ rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
648-
+ f" at {r_max_idx}"
649-
+ f"\n Max absolute difference: {a_max:.2e}"
650-
+ rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {a_max_idx}"
651-
)
652-
break
650+
if not mismatched_elements:
651+
continue
652+
653+
mismatched_ppm = mismatched_elements / expected.size * 1e6
654+
abs_diff[~mismatched] = 0 # ignore non-mismatched elements
655+
656+
r_max_idx = (r_diff := (abs_diff / (abs(expected) + 1e-6))).argmax()
657+
r_max = r_diff[r_max_idx].item()
658+
r_actual = actual[r_max_idx].item()
659+
r_expected = expected[r_max_idx].item()
660+
661+
# Calculate the max absolute difference with the relative tolerance subtracted
662+
abs_diff_wo_rtol: xr.DataArray = xr.ufuncs.maximum(
663+
(abs_diff - rtol_value).data, 0
664+
)
665+
a_max_idx = {
666+
AxisId(k): int(v) for k, v in abs_diff_wo_rtol.argmax().items()
667+
}
668+
669+
a_max = abs_diff[a_max_idx].item()
670+
a_actual = actual[a_max_idx].item()
671+
a_expected = expected[a_max_idx].item()
672+
673+
msg = (
674+
f"Output '{m}' disagrees with {mismatched_elements} of"
675+
+ f" {expected.size} expected values"
676+
+ f" ({mismatched_ppm:.1f} ppm)."
677+
+ f"\n Max relative difference: {r_max:.2e}"
678+
+ rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
679+
+ f" at {r_max_idx}"
680+
+ f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}"
681+
+ rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {a_max_idx}"
682+
)
683+
if mismatched_ppm > mismatched_tol:
684+
add_error_entry(msg)
685+
if stop_early:
686+
break
687+
else:
688+
add_warning_entry(msg)
689+
653690
except Exception as e:
654691
if get_validation_context().raise_errors:
655692
raise e
@@ -660,9 +697,10 @@ def add_error_entry(msg: str, with_traceback: bool = False):
660697
ValidationDetail(
661698
name=test_name,
662699
loc=("weights", weight_format),
663-
status="failed" if errors else "passed",
700+
status="failed" if error_entries else "passed",
664701
recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
665-
errors=errors,
702+
errors=error_entries,
703+
warnings=warning_entries,
666704
)
667705
)
668706

bioimageio/core/_settings.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
from typing import Literal
22

3-
from dotenv import load_dotenv
43
from pydantic import Field
54
from typing_extensions import Annotated
65

76
from bioimageio.spec._internal._settings import Settings as SpecSettings
87

9-
_ = load_dotenv()
10-
118

129
class Settings(SpecSettings):
1310
"""environment variables for bioimageio.spec and bioimageio.core"""

bioimageio/core/backends/keras_backend.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,27 @@
11
import os
2+
import shutil
3+
from pathlib import Path
4+
from tempfile import TemporaryDirectory
25
from typing import Any, Optional, Sequence, Union
36

7+
import h5py # pyright: ignore[reportMissingTypeStubs]
8+
from keras.src.legacy.saving import ( # pyright: ignore[reportMissingTypeStubs]
9+
legacy_h5_format,
10+
)
411
from loguru import logger
512
from numpy.typing import NDArray
613

7-
from bioimageio.spec._internal.io import download
8-
from bioimageio.spec._internal.type_guards import is_list, is_tuple
914
from bioimageio.spec.model import v0_4, v0_5
1015
from bioimageio.spec.model.v0_5 import Version
1116

1217
from .._settings import settings
1318
from ..digest_spec import get_axes_infos
19+
from ..utils._type_guards import is_list, is_tuple
1420
from ._model_adapter import ModelAdapter
1521

1622
os.environ["KERAS_BACKEND"] = settings.keras_backend
1723

24+
1825
# by default, we use the keras integrated with tensorflow
1926
# TODO: check if we should prefer keras
2027
try:
@@ -67,9 +74,18 @@ def __init__(
6774
devices,
6875
)
6976

70-
weight_path = download(model_description.weights.keras_hdf5.source).path
77+
weight_reader = model_description.weights.keras_hdf5.get_reader()
78+
if weight_reader.suffix in (".h5", "hdf5"):
79+
h5_file = h5py.File(weight_reader, mode="r")
80+
self._network = legacy_h5_format.load_model_from_hdf5(h5_file)
81+
else:
82+
with TemporaryDirectory() as temp_dir:
83+
temp_path = Path(temp_dir) / weight_reader.original_file_name
84+
with temp_path.open("wb") as f:
85+
shutil.copyfileobj(weight_reader, f)
86+
87+
self._network = keras.models.load_model(temp_path)
7188

72-
self._network = keras.models.load_model(weight_path)
7389
self._output_axes = [
7490
tuple(a.id for a in get_axes_infos(out))
7591
for out in model_description.outputs

bioimageio/core/backends/onnx_backend.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55
import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs]
66
from numpy.typing import NDArray
77

8-
from bioimageio.spec._internal.type_guards import is_list, is_tuple
98
from bioimageio.spec.model import v0_4, v0_5
10-
from bioimageio.spec.utils import download
119

1210
from ..model_adapters import ModelAdapter
11+
from ..utils._type_guards import is_list, is_tuple
1312

1413

1514
class ONNXModelAdapter(ModelAdapter):
@@ -24,8 +23,8 @@ def __init__(
2423
if model_description.weights.onnx is None:
2524
raise ValueError("No ONNX weights specified for {model_description.name}")
2625

27-
local_path = download(model_description.weights.onnx.source).path
28-
self._session = rt.InferenceSession(local_path.read_bytes())
26+
reader = model_description.weights.onnx.get_reader()
27+
self._session = rt.InferenceSession(reader.read())
2928
onnx_inputs = self._session.get_inputs()
3029
self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]
3130

bioimageio/core/backends/pytorch_backend.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import gc
22
import warnings
33
from contextlib import nullcontext
4-
from io import TextIOWrapper
4+
from io import BytesIO, TextIOWrapper
55
from pathlib import Path
66
from typing import Any, List, Literal, Optional, Sequence, Union
77

@@ -11,12 +11,13 @@
1111
from torch import nn
1212
from typing_extensions import assert_never
1313

14-
from bioimageio.spec._internal.type_guards import is_list, is_ndarray, is_tuple
15-
from bioimageio.spec.common import ZipPath
14+
from bioimageio.spec._internal.version_type import Version
15+
from bioimageio.spec.common import BytesReader, ZipPath
1616
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
1717
from bioimageio.spec.utils import download
1818

1919
from ..digest_spec import import_callable
20+
from ..utils._type_guards import is_list, is_ndarray, is_tuple
2021
from ._model_adapter import ModelAdapter
2122

2223

@@ -73,7 +74,9 @@ def _forward_impl(
7374
if r is None:
7475
result.append(None)
7576
elif isinstance(r, torch.Tensor):
76-
r_np: NDArray[Any] = r.detach().cpu().numpy()
77+
r_np: NDArray[Any] = ( # pyright: ignore[reportUnknownVariableType]
78+
r.detach().cpu().numpy()
79+
)
7780
result.append(r_np)
7881
elif is_ndarray(r):
7982
result.append(r)
@@ -129,34 +132,49 @@ def load_torch_model(
129132
if load_state:
130133
torch_model = load_torch_state_dict(
131134
torch_model,
132-
path=download(weight_spec).path,
135+
path=download(weight_spec),
133136
devices=use_devices,
134137
)
135138
return torch_model
136139

137140

138141
def load_torch_state_dict(
139142
model: nn.Module,
140-
path: Union[Path, ZipPath],
143+
path: Union[Path, ZipPath, BytesReader],
141144
devices: Sequence[torch.device],
142145
) -> nn.Module:
143146
model = model.to(devices[0])
144-
with path.open("rb") as f:
147+
if isinstance(path, (Path, ZipPath)):
148+
ctxt = path.open("rb")
149+
else:
150+
ctxt = nullcontext(BytesIO(path.read()))
151+
152+
with ctxt as f:
145153
assert not isinstance(f, TextIOWrapper)
146-
state = torch.load(f, map_location=devices[0], weights_only=True)
154+
if Version(str(torch.__version__)) < Version("1.13"):
155+
state = torch.load(f, map_location=devices[0])
156+
else:
157+
state = torch.load(f, map_location=devices[0], weights_only=True)
147158

148159
incompatible = model.load_state_dict(state)
149160
if (
150-
incompatible is not None # pyright: ignore[reportUnnecessaryComparison]
151-
and incompatible.missing_keys
161+
isinstance(incompatible, tuple)
162+
and hasattr(incompatible, "missing_keys")
163+
and hasattr(incompatible, "unexpected_keys")
152164
):
153-
logger.warning("Missing state dict keys: {}", incompatible.missing_keys)
165+
if incompatible.missing_keys:
166+
logger.warning("Missing state dict keys: {}", incompatible.missing_keys)
154167

155-
if (
156-
incompatible is not None # pyright: ignore[reportUnnecessaryComparison]
157-
and incompatible.unexpected_keys
158-
):
159-
logger.warning("Unexpected state dict keys: {}", incompatible.unexpected_keys)
168+
if hasattr(incompatible, "unexpected_keys") and incompatible.unexpected_keys:
169+
logger.warning(
170+
"Unexpected state dict keys: {}", incompatible.unexpected_keys
171+
)
172+
else:
173+
logger.warning(
174+
"`model.load_state_dict()` unexpectedly returned: {} "
175+
+ "(expected named tuple with `missing_keys` and `unexpected_keys` attributes)",
176+
(s[:20] + "..." if len(s := str(incompatible)) > 20 else s),
177+
)
160178

161179
return model
162180

0 commit comments

Comments
 (0)