| 
124 | 124 |   },  | 
125 | 125 |   {  | 
126 | 126 |    "cell_type": "code",  | 
127 |  | -   "execution_count": 5,  | 
 | 127 | +   "execution_count": null,  | 
128 | 128 |    "id": "006f7912",  | 
129 | 129 |    "metadata": {},  | 
130 | 130 |    "outputs": [  | 
 | 
172 | 172 |     "# ------------------------ Utils Import ------------------------\n",  | 
173 | 173 |     "sys.path.append(\"../src\")\n",  | 
174 | 174 |     "from onnx_utils import ModelExportConfig\n",  | 
175 |  | -    "from utils import load_config\n"  | 
 | 175 | +    "from utils import load_config\n",  | 
 | 176 | +    "\n",  | 
 | 177 | +    "# ------------------------ MLflow Logger Import ------------------------\n",  | 
 | 178 | +    "from src.mlflow import Logger"  | 
176 | 179 |    ]  | 
177 | 180 |   },  | 
178 | 181 |   {  | 
 | 
243 | 246 |   },  | 
244 | 247 |   {  | 
245 | 248 |    "cell_type": "code",  | 
246 |  | -   "execution_count": 9,  | 
 | 249 | +   "execution_count": null,  | 
247 | 250 |    "id": "9c6d53c6",  | 
248 | 251 |    "metadata": {},  | 
249 | 252 |    "outputs": [  | 
 | 
257 | 260 |     }  | 
258 | 261 |    ],  | 
259 | 262 |    "source": [  | 
260 |  | -    "class MNISTModel(mlflow.pyfunc.PythonModel):\n",  | 
261 |  | -    "    def load_context(self, context):\n",  | 
262 |  | -    "        \"\"\"\n",  | 
263 |  | -    "        Load model and configuration from artifacts.\n",  | 
264 |  | -    "        \"\"\"\n",  | 
265 |  | -    "        try:\n",  | 
266 |  | -    "            import tensorflow as tf\n",  | 
267 |  | -    "            import yaml\n",  | 
268 |  | -    "            \n",  | 
269 |  | -    "            # Load the model from artifacts\n",  | 
270 |  | -    "            self.model = tf.keras.models.load_model(context.artifacts[\"mnist_model\"])\n",  | 
271 |  | -    "            \n",  | 
272 |  | -    "            # Load configuration from artifacts (like vanilla-rag does)\n",  | 
273 |  | -    "            config_path = context.artifacts[\"config\"]\n",  | 
274 |  | -    "            with open(config_path, 'r') as f:\n",  | 
275 |  | -    "                self.config = yaml.safe_load(f)\n",  | 
276 |  | -    "            \n",  | 
277 |  | -    "            logger.info(\"✅ Model and configuration loaded successfully\")\n",  | 
278 |  | -    "\n",  | 
279 |  | -    "        except Exception as e:\n",  | 
280 |  | -    "            logger.error(f\"❌ Error loading context: {str(e)}\")\n",  | 
281 |  | -    "            raise\n",  | 
282 |  | -    "\n",  | 
283 |  | -    "    def predict(self, context, model_input, params=None):\n",  | 
284 |  | -    "        \"\"\"\n",  | 
285 |  | -    "        Computes the predicted digit, by converting the base64 to a numpy array.\n",  | 
286 |  | -    "        \"\"\"\n",  | 
287 |  | -    "        try:\n",  | 
288 |  | -    "            if isinstance(model_input, pd.DataFrame):\n",  | 
289 |  | -    "                image_input = model_input.iloc[0, 0]\n",  | 
290 |  | -    "            elif isinstance(model_input, list):\n",  | 
291 |  | -    "                image_input = model_input[0]\n",  | 
292 |  | -    "            else:\n",  | 
293 |  | -    "                image_input = str(model_input)\n",  | 
294 |  | -    "                \n",  | 
295 |  | -    "            base64_array = base64_to_numpy(image_input)\n",  | 
296 |  | -    "\n",  | 
297 |  | -    "            predictions = self.model.predict(base64_array)\n",  | 
298 |  | -    "            predicted_classes = np.argmax(predictions, axis=1)\n",  | 
299 |  | -    "            \n",  | 
300 |  | -    "            return predicted_classes.tolist()\n",  | 
 | 263 | +    "def log_mnist_model_to_mlflow(model, artifact_path, config_path, demo_folder):\n",  | 
 | 264 | +    "    \"\"\"\n",  | 
 | 265 | +    "    Log MNIST model to MLflow using the new Logger approach.\n",  | 
 | 266 | +    "    \"\"\"\n",  | 
 | 267 | +    "    try:\n",  | 
 | 268 | +    "        # Define input and output schema\n",  | 
 | 269 | +    "        input_schema = Schema([\n",  | 
 | 270 | +    "            ColSpec(\"string\", name=\"digit\"),\n",  | 
 | 271 | +    "        ])\n",  | 
 | 272 | +    "        output_schema = Schema([\n",  | 
 | 273 | +    "            ColSpec(\"long\", name=\"prediction\"),\n",  | 
 | 274 | +    "        ])\n",  | 
301 | 275 |     "        \n",  | 
302 |  | -    "        except Exception as e:\n",  | 
303 |  | -    "            logger.error(f\"❌ Error performing prediction: {str(e)}\")\n",  | 
304 |  | -    "            raise\n",  | 
305 |  | -    "    \n",  | 
306 |  | -    "    @classmethod\n",  | 
307 |  | -    "    def log_model(cls,artifact_path, config_path, demo_folder):\n",  | 
308 |  | -    "        \"\"\"\n",  | 
309 |  | -    "        Logs the model to MLflow with appropriate artifacts and schema.\n",  | 
310 |  | -    "        Now uses in-memory model loading for ONNX export efficiency.\n",  | 
311 |  | -    "        \"\"\"\n",  | 
 | 276 | +    "        # Define model signature\n",  | 
 | 277 | +    "        signature = ModelSignature(inputs=input_schema, outputs=output_schema)\n",  | 
312 | 278 |     "        \n",  | 
313 |  | -    "        try:\n",  | 
314 |  | -    "            sys.path.append(\"../src\")\n",  | 
315 |  | -    "            from onnx_utils import ModelExportConfig,log_model\n",  | 
316 |  | -    "            \n",  | 
317 |  | -    "            # Define input and output schema\n",  | 
318 |  | -    "            input_schema = Schema([\n",  | 
319 |  | -    "                ColSpec(\"string\", name=\"digit\"),\n",  | 
320 |  | -    "            ])\n",  | 
321 |  | -    "            output_schema = Schema([\n",  | 
322 |  | -    "                ColSpec(\"long\", name=\"prediction\"),\n",  | 
323 |  | -    "            ])\n",  | 
324 |  | -    "            \n",  | 
325 |  | -    "            # Define model signature\n",  | 
326 |  | -    "            signature = ModelSignature(inputs=input_schema, outputs=output_schema)\n",  | 
327 |  | -    "            \n",  | 
328 |  | -    "            # Save the model to disk for artifacts (still needed for MLflow artifacts)\n",  | 
329 |  | -    "            model.save(MODEL_PATH)\n",  | 
330 |  | -    "          \n",  | 
331 |  | -    "            artifacts = {\n",  | 
332 |  | -    "                \"mnist_model\": MODEL_PATH,\n",  | 
333 |  | -    "                \"config\": config_path\n",  | 
334 |  | -    "            }\n",  | 
335 |  | -    "            \n",  | 
336 |  | -    "            if demo_folder and os.path.exists(demo_folder):\n",  | 
337 |  | -    "                artifacts[\"demo\"] = demo_folder\n",  | 
338 |  | -    "                logger.info(f\"✅ Demo folder added to artifacts: {demo_folder}\")\n",  | 
339 |  | -    "      \n",  | 
340 |  | -    "            \n",  | 
341 |  | -    "            model_configs = [\n",  | 
342 |  | -    "                ModelExportConfig(\n",  | 
343 |  | -    "                    model=model,                            # 🚀 Pre-loaded Keras model object!\n",  | 
344 |  | -    "                    model_name=\"keras_mnist_onnx\",          # ONNX file naming\n",  | 
345 |  | -    "                    input_sample = np.random.random((1, 28, 28, 1)).astype(np.float32)\n",  | 
346 |  | -    "                  \n",  | 
347 |  | -    "                )    \n",  | 
348 |  | -    "            ]\n",  | 
349 |  | -    "\n",  | 
350 |  | -    "            log_model(\n",  | 
351 |  | -    "                    artifact_path=artifact_path,\n",  | 
352 |  | -    "                    python_model=cls(),\n",  | 
353 |  | -    "                    artifacts=artifacts,\n",  | 
354 |  | -    "                    signature=signature,\n",  | 
355 |  | -    "                    models_to_convert_onnx=model_configs,\n",  | 
356 |  | -    "                    pip_requirements=[\n",  | 
357 |  | -    "                    \"tensorflow>=2.0.0\",\n",  | 
358 |  | -    "                    \"numpy\",\n",  | 
359 |  | -    "                    \"pillow\",\n",  | 
360 |  | -    "                    \"streamlit>=1.28.0\",\n",  | 
361 |  | -    "                    \"pyyaml\"\n",  | 
362 |  | -    "                    ]\n",  | 
363 |  | -    "                )\n",  | 
364 |  | -    "            \n",  | 
365 |  | -    "            logger.info(\"✅ Model and artifacts successfully registered in MLflow\")\n",  | 
 | 279 | +    "        # Save the model to disk for artifacts\n",  | 
 | 280 | +    "        model.save(MODEL_PATH)\n",  | 
 | 281 | +    "        \n",  | 
 | 282 | +    "        # Use the new Logger to log the model\n",  | 
 | 283 | +    "        Logger.log_model(\n",  | 
 | 284 | +    "            artifact_path=artifact_path,\n",  | 
 | 285 | +    "            config_path=config_path,\n",  | 
 | 286 | +    "            demo_folder=demo_folder,\n",  | 
 | 287 | +    "            signature=signature,\n",  | 
 | 288 | +    "            model_path=MODEL_PATH\n",  | 
 | 289 | +    "        )\n",  | 
 | 290 | +    "        \n",  | 
 | 291 | +    "        logger.info(\"✅ Model and artifacts successfully registered in MLflow using new Logger\")\n",  | 
366 | 292 |     "\n",  | 
367 |  | -    "        except Exception as e:\n",  | 
368 |  | -    "            logger.error(f\"❌ Error logging model: {str(e)}\")\n",  | 
369 |  | -    "            raise"  | 
 | 293 | +    "    except Exception as e:\n",  | 
 | 294 | +    "        logger.error(f\"❌ Error logging model: {str(e)}\")\n",  | 
 | 295 | +    "        raise"  | 
370 | 296 |    ]  | 
371 | 297 |   },  | 
372 | 298 |   {  | 
 | 
505 | 431 |   },  | 
506 | 432 |   {  | 
507 | 433 |    "cell_type": "code",  | 
508 |  | -   "execution_count": 12,  | 
 | 434 | +   "execution_count": null,  | 
509 | 435 |    "id": "326467bc",  | 
510 | 436 |    "metadata": {},  | 
511 | 437 |    "outputs": [  | 
 | 
630 | 556 |     "    \n",  | 
631 | 557 |     "    logger.info(f\"📊 Test accuracy: {test_accuracy:.4f}\")\n",  | 
632 | 558 |     "\n",  | 
633 |  | -    "    # Log the model to MLflow using vanilla-rag pattern\n",  | 
634 |  | -    "    MNISTModel.log_model(\n",  | 
 | 559 | +    "    # Log the model to MLflow using new Logger approach\n",  | 
 | 560 | +    "    log_mnist_model_to_mlflow(\n",  | 
 | 561 | +    "        model=model,\n",  | 
635 | 562 |     "        artifact_path=MODEL_NAME,\n",  | 
636 | 563 |     "        config_path=CONFIG_PATH,\n",  | 
637 | 564 |     "        demo_folder=DEMO_FOLDER\n",  | 
 | 
0 commit comments