Skip to content

Commit f5a9249

Browse files
committed
refactor: update model logging approach and clean up utility functions
1 parent 194fd31 commit f5a9249

File tree

4 files changed

+45
-143
lines changed

4 files changed

+45
-143
lines changed

deep-learning/classification-with-keras/notebooks/register-model.ipynb

Lines changed: 40 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@
124124
},
125125
{
126126
"cell_type": "code",
127-
"execution_count": 5,
127+
"execution_count": null,
128128
"id": "006f7912",
129129
"metadata": {},
130130
"outputs": [
@@ -172,7 +172,10 @@
172172
"# ------------------------ Utils Import ------------------------\n",
173173
"sys.path.append(\"../src\")\n",
174174
"from onnx_utils import ModelExportConfig\n",
175-
"from utils import load_config\n"
175+
"from utils import load_config\n",
176+
"\n",
177+
"# ------------------------ MLflow Logger Import ------------------------\n",
178+
"from src.mlflow import Logger"
176179
]
177180
},
178181
{
@@ -243,7 +246,7 @@
243246
},
244247
{
245248
"cell_type": "code",
246-
"execution_count": 9,
249+
"execution_count": null,
247250
"id": "9c6d53c6",
248251
"metadata": {},
249252
"outputs": [
@@ -257,116 +260,39 @@
257260
}
258261
],
259262
"source": [
260-
"class MNISTModel(mlflow.pyfunc.PythonModel):\n",
261-
" def load_context(self, context):\n",
262-
" \"\"\"\n",
263-
" Load model and configuration from artifacts.\n",
264-
" \"\"\"\n",
265-
" try:\n",
266-
" import tensorflow as tf\n",
267-
" import yaml\n",
268-
" \n",
269-
" # Load the model from artifacts\n",
270-
" self.model = tf.keras.models.load_model(context.artifacts[\"mnist_model\"])\n",
271-
" \n",
272-
" # Load configuration from artifacts (like vanilla-rag does)\n",
273-
" config_path = context.artifacts[\"config\"]\n",
274-
" with open(config_path, 'r') as f:\n",
275-
" self.config = yaml.safe_load(f)\n",
276-
" \n",
277-
" logger.info(\"✅ Model and configuration loaded successfully\")\n",
278-
"\n",
279-
" except Exception as e:\n",
280-
" logger.error(f\"❌ Error loading context: {str(e)}\")\n",
281-
" raise\n",
282-
"\n",
283-
" def predict(self, context, model_input, params=None):\n",
284-
" \"\"\"\n",
285-
" Computes the predicted digit, by converting the base64 to a numpy array.\n",
286-
" \"\"\"\n",
287-
" try:\n",
288-
" if isinstance(model_input, pd.DataFrame):\n",
289-
" image_input = model_input.iloc[0, 0]\n",
290-
" elif isinstance(model_input, list):\n",
291-
" image_input = model_input[0]\n",
292-
" else:\n",
293-
" image_input = str(model_input)\n",
294-
" \n",
295-
" base64_array = base64_to_numpy(image_input)\n",
296-
"\n",
297-
" predictions = self.model.predict(base64_array)\n",
298-
" predicted_classes = np.argmax(predictions, axis=1)\n",
299-
" \n",
300-
" return predicted_classes.tolist()\n",
263+
"def log_mnist_model_to_mlflow(model, artifact_path, config_path, demo_folder):\n",
264+
" \"\"\"\n",
265+
" Log MNIST model to MLflow using the new Logger approach.\n",
266+
" \"\"\"\n",
267+
" try:\n",
268+
" # Define input and output schema\n",
269+
" input_schema = Schema([\n",
270+
" ColSpec(\"string\", name=\"digit\"),\n",
271+
" ])\n",
272+
" output_schema = Schema([\n",
273+
" ColSpec(\"long\", name=\"prediction\"),\n",
274+
" ])\n",
301275
" \n",
302-
" except Exception as e:\n",
303-
" logger.error(f\"❌ Error performing prediction: {str(e)}\")\n",
304-
" raise\n",
305-
" \n",
306-
" @classmethod\n",
307-
" def log_model(cls,artifact_path, config_path, demo_folder):\n",
308-
" \"\"\"\n",
309-
" Logs the model to MLflow with appropriate artifacts and schema.\n",
310-
" Now uses in-memory model loading for ONNX export efficiency.\n",
311-
" \"\"\"\n",
276+
" # Define model signature\n",
277+
" signature = ModelSignature(inputs=input_schema, outputs=output_schema)\n",
312278
" \n",
313-
" try:\n",
314-
" sys.path.append(\"../src\")\n",
315-
" from onnx_utils import ModelExportConfig,log_model\n",
316-
" \n",
317-
" # Define input and output schema\n",
318-
" input_schema = Schema([\n",
319-
" ColSpec(\"string\", name=\"digit\"),\n",
320-
" ])\n",
321-
" output_schema = Schema([\n",
322-
" ColSpec(\"long\", name=\"prediction\"),\n",
323-
" ])\n",
324-
" \n",
325-
" # Define model signature\n",
326-
" signature = ModelSignature(inputs=input_schema, outputs=output_schema)\n",
327-
" \n",
328-
" # Save the model to disk for artifacts (still needed for MLflow artifacts)\n",
329-
" model.save(MODEL_PATH)\n",
330-
" \n",
331-
" artifacts = {\n",
332-
" \"mnist_model\": MODEL_PATH,\n",
333-
" \"config\": config_path\n",
334-
" }\n",
335-
" \n",
336-
" if demo_folder and os.path.exists(demo_folder):\n",
337-
" artifacts[\"demo\"] = demo_folder\n",
338-
" logger.info(f\"✅ Demo folder added to artifacts: {demo_folder}\")\n",
339-
" \n",
340-
" \n",
341-
" model_configs = [\n",
342-
" ModelExportConfig(\n",
343-
" model=model, # 🚀 Pre-loaded Keras model object!\n",
344-
" model_name=\"keras_mnist_onnx\", # ONNX file naming\n",
345-
" input_sample = np.random.random((1, 28, 28, 1)).astype(np.float32)\n",
346-
" \n",
347-
" ) \n",
348-
" ]\n",
349-
"\n",
350-
" log_model(\n",
351-
" artifact_path=artifact_path,\n",
352-
" python_model=cls(),\n",
353-
" artifacts=artifacts,\n",
354-
" signature=signature,\n",
355-
" models_to_convert_onnx=model_configs,\n",
356-
" pip_requirements=[\n",
357-
" \"tensorflow>=2.0.0\",\n",
358-
" \"numpy\",\n",
359-
" \"pillow\",\n",
360-
" \"streamlit>=1.28.0\",\n",
361-
" \"pyyaml\"\n",
362-
" ]\n",
363-
" )\n",
364-
" \n",
365-
" logger.info(\"✅ Model and artifacts successfully registered in MLflow\")\n",
279+
" # Save the model to disk for artifacts\n",
280+
" model.save(MODEL_PATH)\n",
281+
" \n",
282+
" # Use the new Logger to log the model\n",
283+
" Logger.log_model(\n",
284+
" artifact_path=artifact_path,\n",
285+
" config_path=config_path,\n",
286+
" demo_folder=demo_folder,\n",
287+
" signature=signature,\n",
288+
" model_path=MODEL_PATH\n",
289+
" )\n",
290+
" \n",
291+
" logger.info(\"✅ Model and artifacts successfully registered in MLflow using new Logger\")\n",
366292
"\n",
367-
" except Exception as e:\n",
368-
" logger.error(f\"❌ Error logging model: {str(e)}\")\n",
369-
" raise"
293+
" except Exception as e:\n",
294+
" logger.error(f\"❌ Error logging model: {str(e)}\")\n",
295+
" raise"
370296
]
371297
},
372298
{
@@ -505,7 +431,7 @@
505431
},
506432
{
507433
"cell_type": "code",
508-
"execution_count": 12,
434+
"execution_count": null,
509435
"id": "326467bc",
510436
"metadata": {},
511437
"outputs": [
@@ -630,8 +556,9 @@
630556
" \n",
631557
" logger.info(f\"📊 Test accuracy: {test_accuracy:.4f}\")\n",
632558
"\n",
633-
" # Log the model to MLflow using vanilla-rag pattern\n",
634-
" MNISTModel.log_model(\n",
559+
" # Log the model to MLflow using new Logger approach\n",
560+
" log_mnist_model_to_mlflow(\n",
561+
" model=model,\n",
635562
" artifact_path=MODEL_NAME,\n",
636563
" config_path=CONFIG_PATH,\n",
637564
" demo_folder=DEMO_FOLDER\n",

deep-learning/classification-with-keras/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ onnx>=1.14.0
22
onnxruntime>=1.15.0
33
numpy>=1.21.0
44
tf2onnx>=1.15.0
5-
mlflow==2.21.2
5+
mlflow==3.1.0 #TODO: remove me when you finish

deep-learning/classification-with-keras/src/mlflow/loader.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,28 +26,22 @@ def _load_pyfunc(data_path: str):
2626
Model: Initialized MNIST model instance ready for prediction
2727
"""
2828
from src.mlflow.model import Model
29-
from src.utils import load_config, get_model_path
29+
from src.utils import load_config
3030

3131
logger.info(f"Loading MNIST Model from artifacts at: {data_path}")
3232

33-
# Set MODEL_ARTIFACTS_PATH for model loading
34-
os.environ["MODEL_ARTIFACTS_PATH"] = data_path
35-
3633
config_path = os.path.join(data_path, "config.yaml")
3734
if not os.path.exists(config_path):
3835
raise FileNotFoundError(f"Configuration file not found at: {config_path}")
3936

4037
config = load_config(config_path)
4138
logger.info("Configuration loaded successfully")
4239

43-
# Get model path from artifacts
44-
model_path = get_model_path("model_keras_mnist.keras")
40+
# The model file is always named "model_keras_mnist.keras" in artifacts
41+
model_path = os.path.join(data_path, "model_keras_mnist.keras")
4542

4643
if not os.path.exists(model_path):
47-
# Fallback to direct path in artifacts
48-
model_path = os.path.join(data_path, "model_keras_mnist.keras")
49-
if not os.path.exists(model_path):
50-
raise FileNotFoundError(f"Model file not found at: {model_path}")
44+
raise FileNotFoundError(f"Model file not found at: {model_path}")
5145

5246
logger.info(f"Model path resolved to: {model_path}")
5347

deep-learning/classification-with-keras/src/utils.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -83,22 +83,3 @@ def get_ports_config(config: Dict[str, Any]) -> Dict[str, Any]:
8383
Ports configuration dictionary.
8484
"""
8585
return config.get("ports", {})
86-
87-
88-
def get_model_path(model_name: str) -> str:
89-
"""
90-
Get the full path to the model file using the artifacts path and model name.
91-
92-
Args:
93-
model_name: Name of the model file or full path (will extract filename)
94-
95-
Returns:
96-
Full path to the model file
97-
"""
98-
# Extract just the filename if model_name contains a path
99-
filename = os.path.basename(model_name)
100-
101-
artifacts_path = os.environ.get("MODEL_ARTIFACTS_PATH", "")
102-
model_path = os.path.join(artifacts_path, filename)
103-
104-
return model_path

0 commit comments

Comments
 (0)