11from __future__ import annotations
22
3+ import json
34import logging
45import os
6+ import platform
7+ import tarfile
8+ import typing as t
59import urllib .request
610from collections .abc import Sequence
711from dataclasses import dataclass , field
1014
1115from .constants import (
1216 FILENAME_METADATA_JSON ,
13- MIME_APPLICATION_CONFIG ,
17+ MIME_APPLICATION_MLMETADATA ,
1418 MIME_APPLICATION_MLMODEL ,
19+ MIME_BLOB ,
20+ MIME_MANIFEST_CONFIG ,
1521)
1622from .listener import Event , Listener , PushEvent
1723from .model_metadata import ModelMetadata
2026logger = logging .getLogger (__name__ )
2127
2228
@dataclass
class DeferredLayer:
    """A file to be pushed as a layer, optionally built on demand from ``src``.

    ``dest`` is the file named in the layer spec returned by :meth:`as_layer`.
    If ``dest`` does not exist when the object is created, the layer is
    considered *owned*: ``transform`` (when given) is run by :meth:`as_layer`
    to produce it, and ``owned`` stays ``True`` so that callers can tell the
    file is an intermediate they may remove afterwards.
    """

    src: Path  # input file the layer is derived from
    dest: Path  # file referenced by the layer spec (same as src for raw layers)
    media_type: str  # media type appended to the layer spec
    transform: t.Callable[[], None] | None = None  # builds dest from src
    owned: bool = True  # False when dest already existed at construction time

    def __post_init__(self):
        # A destination that is already on disk was not produced by this
        # object, so it is neither rebuilt nor flagged for removal.
        if self.dest.exists():
            self.owned = False
        # Plain instance attribute (deliberately not a dataclass field) so the
        # constructor signature, repr and equality stay unchanged.
        self._materialized = False

    @classmethod
    def raw(cls, src: Path, media_type: str) -> DeferredLayer:
        """Return a layer that references ``src`` as-is, with no transform."""
        return cls(src, src, media_type)

    @classmethod
    def blob(cls, src: Path, gz: bool = False) -> DeferredLayer:
        """Return a layer that packs ``src`` into a tar archive next to it.

        The archive path is ``src`` with its suffix replaced by ``.tar``. With
        ``gz=True`` the archive is gzip-compressed and ``+gzip`` is appended
        to the media type; the file name still ends in ``.tar``.
        The archive is only written when :meth:`as_layer` is called.
        """
        oflag = "w"
        media_type = MIME_BLOB
        if gz:
            oflag += ":gz"
            media_type += "+gzip"

        dest = src.with_suffix(".tar")

        def _tar():
            # Store the entry under its bare file name, not its full path.
            with tarfile.open(dest, oflag) as tf:
                tf.add(src, arcname=src.name)

        return cls(src, dest, media_type, _tar)

    def as_layer(self) -> str:
        """Materialize ``dest`` if needed and return ``"<dest>:<media_type>"``.

        The transform runs at most once per object, and only for owned layers,
        so calling this method repeatedly does not rebuild the file.
        """
        if self.owned and self.transform and not self._materialized:
            self.transform()
            self._materialized = True
        return f"{self.dest}:{self.media_type}"
65+
66+
def get_arch(machine: str | None = None) -> str:
    """Map a machine identifier to the architecture name used in the manifest.

    Args:
        machine: Machine string to translate. When omitted, the value of
            ``platform.machine()`` for the current host is used.

    Returns:
        ``"amd64"`` for ``x86_64``/``amd64`` and ``"arm64"`` for
        ``arm64``/``aarch64``. Matching is case-insensitive because some
        platforms report upper-case names such as ``AMD64``.

    Raises:
        NotImplementedError: If the machine string is not one of the above.
    """
    mac = platform.machine() if machine is None else machine
    normalized = mac.lower()
    if normalized in ("x86_64", "amd64"):
        return "amd64"
    if normalized in ("arm64", "aarch64"):
        return "arm64"
    # Report the value exactly as received to ease troubleshooting.
    msg = f"Unsupported architecture: {mac}"
    raise NotImplementedError(msg)
75+
76+
2377def download_file (uri : str ):
2478 file_name = os .path .basename (uri )
2579 urllib .request .urlretrieve (uri , file_name )
@@ -41,54 +95,83 @@ def push(
4195 self ,
4296 target : str ,
4397 path : Path | str ,
98+ as_artifact : bool = False ,
4499 ** kwargs ,
45100 ):
46101 owns_meta = True
47102 if isinstance (path , str ):
48103 path = Path (path )
49104
50105 meta_path = path .parent / FILENAME_METADATA_JSON
51- if not kwargs and meta_path .exists ():
106+ if meta_path .exists ():
52107 owns_meta = False
53108 logger .warning ("Reusing intermediate metadata files." )
54109 logger .debug (f"{ meta_path } " )
55- with open (meta_path , "r" ) as f :
56- model_metadata = ModelMetadata .from_json (f .read ())
57- elif meta_path .exists ():
58- err = dedent (f"""
59- OMLMD intermediate metadata files found at '{ meta_path } '.
60- Cannot resolve with conflicting keyword args: { kwargs } .
61- You can reuse the existing metadata by omitting any keywords.
62- If that was NOT intended, please REMOVE that file from your environment before re-running.
63-
64- Note for advanced users: if merging keys with existing metadata is desired, you should create a Feature Request upstream: https://github.com/containers/omlmd""" )
65- raise RuntimeError (err )
110+ model_metadata = ModelMetadata (** json .loads (meta_path .read_bytes ()))
111+ if kwargs and ModelMetadata .from_dict (kwargs ) != model_metadata :
112+ err = dedent (f"""
113+ OMLMD intermediate metadata files found at '{ meta_path } '.
114+ Cannot resolve with conflicting keyword args: { kwargs } .
115+ You can reuse the existing metadata by omitting any keywords.
116+ If that was NOT intended, please REMOVE that file from your environment before re-running.
117+
118+ Note for advanced users: if merging keys with existing metadata is desired, you should create a Feature Request upstream: https://github.com/containers/omlmd""" )
119+ raise RuntimeError (err )
66120 else :
67121 model_metadata = ModelMetadata .from_dict (kwargs )
68- meta_path .write_text (model_metadata .to_json ())
122+ meta_path .write_text (json .dumps (model_metadata .to_dict ()))
123+
124+ manifest_path = path .parent / "manifest.json"
125+ model : DeferredLayer | None = None
126+ meta : DeferredLayer | None = None
127+ if not as_artifact :
128+ manifest_path .write_text (
129+ json .dumps (
130+ {
131+ "architecture" : get_arch (),
132+ "os" : "linux" ,
133+ }
134+ )
135+ )
136+ config = DeferredLayer .raw (manifest_path , MIME_MANIFEST_CONFIG )
137+ model = DeferredLayer .blob (path )
138+ meta = DeferredLayer .blob (meta_path , gz = True )
139+ else :
140+ manifest_path .write_text (
141+ json .dumps (
142+ {
143+ "artifactType" : MIME_APPLICATION_MLMODEL ,
144+ }
145+ )
146+ )
147+ config = DeferredLayer .raw (manifest_path , MIME_APPLICATION_MLMODEL )
148+ model = DeferredLayer .raw (path , MIME_APPLICATION_MLMODEL )
149+ meta = DeferredLayer .raw (meta_path , MIME_APPLICATION_MLMETADATA )
150+ meta .owned = owns_meta
69151
70- config = f"{ meta_path } :{ MIME_APPLICATION_CONFIG } "
71- files = [
72- f"{ path } :{ MIME_APPLICATION_MLMODEL } " ,
152+ layers = [
73153 config ,
154+ model ,
155+ meta ,
74156 ]
75157 try :
76- # print(target, files, model_metadata.to_annotations_dict())
77158 result = self ._registry .push (
78159 target = target ,
79- files = files ,
160+ files = [ lay . as_layer () for lay in layers ] ,
80161 manifest_annotations = model_metadata .to_annotations_dict (),
81- manifest_config = config ,
162+ manifest_config = config . as_layer () ,
82163 do_chunked = True ,
83164 )
84- self .notify_listeners (
85- PushEvent .from_response (result , target , model_metadata )
86- )
87- return result
88165 finally :
89- if owns_meta :
166+ for lay in layers :
167+ if lay .owned :
168+ lay .dest .unlink ()
169+ if owns_meta and meta_path .exists ():
90170 meta_path .unlink ()
91171
172+ self .notify_listeners (PushEvent .from_response (result , target , model_metadata ))
173+ return result
174+
92175 def pull (
93176 self , target : str , outdir : Path | str , media_types : Sequence [str ] | None = None
94177 ):
0 commit comments