Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions immich_model_exporter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import resource
from pathlib import Path

import numpy as np
import typer
from tenacity import retry, stop_after_attempt, wait_fixed
from typing_extensions import Annotated
Expand Down Expand Up @@ -53,6 +54,8 @@ def export(
output_dir = output_dir / model_name
match model_source:
case ModelSource.MCLIP | ModelSource.OPENCLIP:
rand = np.random.rand(1,77)
np.save("randtextualinput.npy",rand)
output_dir.mkdir(parents=True, exist_ok=True)
onnx_export(hf_model_name, model_source, output_dir, cache=cache)
case ModelSource.INSIGHTFACE:
Expand Down
1 change: 1 addition & 0 deletions immich_model_exporter/exporters/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class SourceMetadata(NamedTuple):
}

RKNN_SOCS = ["rk3566", "rk3568", "rk3576", "rk3588"]
RKNN_BLOCKED_OPS = ["CumSum"]


# glob to delete old UUID blobs when reuploading models
Expand Down
11 changes: 10 additions & 1 deletion immich_model_exporter/exporters/rknn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import Path

from .constants import RKNN_SOCS
from .constants import RKNN_BLOCKED_OPS, RKNN_SOCS


def _export_platform(
Expand Down Expand Up @@ -36,6 +36,15 @@ def _export_platform(

ret = rknn.build(do_quantization=False)

if "textual" in input_path.as_posix():
ret = rknn.accuracy_analysis(inputs=["randtextualinput.npy"])
if ret != 0:
RuntimeError("Accuracy analysis failed!")
analysis_result= open("./snapshot/error_analysis.txt","r")
for ops in RKNN_BLOCKED_OPS:
if ops in analysis_result.read():
raise RuntimeError("ONNX Model contains Unsupported OPs!")

if ret != 0:
raise RuntimeError("Build failed!")

Expand Down