Skip to content

Commit 11499bf

Browse files
committed
feature: add log confusion matrix
1 parent 817a97b commit 11499bf

File tree

2 files changed

+71
-2
lines changed

2 files changed

+71
-2
lines changed

src/smexperiments/tracker.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def log_precision_recall(
481481
Trial Component as an output artifact. If False will be an input artifact.
482482
483483
Raises:
484-
ValueError: If mismatch between y_true and predicted_probabilities.
484+
ValueError: If length mismatch between y_true and predicted_probabilities.
485485
"""
486486

487487
if len(y_true) != len(predicted_probabilities):
@@ -542,7 +542,7 @@ def log_roc_curve(
542542
"""
543543

544544
if len(y_true) != len(y_score):
545-
raise ValueError("Mismatch between actual labels and predicted scores.")
545+
raise ValueError("Length mismatch between actual labels and predicted scores.")
546546

547547
get_module("sklearn")
548548
from sklearn.metrics import roc_curve, auc
@@ -561,6 +561,50 @@ def log_roc_curve(
561561
}
562562
self._log_graph_artifact(title, data, "ROCCurve", output_artifact)
563563

564+
def log_confusion_matrix(
    self,
    y_true,
    y_pred,
    title=None,
    output_artifact=True,
):
    """Log a confusion matrix artifact which will be displayed in studio. Requires sklearn.

    Note that this method must be run from a SageMaker context such as studio or training job
    due to restrictions on the CreateArtifact API.

    Examples
        .. code-block:: python

            y_true = [2, 0, 2, 2, 0, 1]
            y_pred = [0, 0, 2, 2, 0, 2]

            my_tracker.log_confusion_matrix(y_true, y_pred)

    Args:
        y_true (array): True labels.
        y_pred (array): Predicted labels.
        title (str, optional): Title of the graph. Defaults to None.
        output_artifact (boolean, optional): Determines if the artifact is associated with the
            Trial Component as an output artifact. If False will be an input artifact.

    Raises:
        ValueError: If length mismatch between y_true and y_pred.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("Length mismatch between actual labels and predicted labels.")

    # Verify sklearn is importable up front so the failure message is clear,
    # then import lazily to keep sklearn an optional dependency.
    get_module("sklearn")
    from sklearn.metrics import confusion_matrix

    matrix = confusion_matrix(y_true, y_pred)

    # "version" is the payload schema version consumed by the Studio renderer;
    # the matrix is converted to plain lists so the payload is JSON-serializable.
    data = {
        "type": "ConfusionMatrix",
        "version": 0,
        "title": title,
        "confusionMatrix": matrix.tolist(),
    }
    self._log_graph_artifact(title, data, "ConfusionMatrix", output_artifact)
607+
564608
def _log_graph_artifact(self, name, data, graph_type, output_artifact):
565609
"""Logs an artifact by uploading data to S3, creating an artifact, and associating that
566610
artifact with the tracker's Trial Component.

tests/unit/test_tracker.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,31 @@ def test_log_pr_curve(under_test):
368368
)
369369

370370

371+
def test_log_confusion_matrix(under_test):
    """log_confusion_matrix uploads the matrix payload and records a lineage artifact."""
    y_true = [2, 0, 2, 2, 0, 1]
    y_pred = [0, 0, 2, 2, 0, 2]

    under_test._artifact_uploader.upload_object_artifact.return_value = ("s3uri_value", "etag_value")

    under_test.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix")

    expected_data = {
        "type": "ConfusionMatrix",
        "version": 0,
        "title": "TestConfusionMatrix",
        "confusionMatrix": [[2, 0, 0], [0, 0, 1], [1, 0, 2]],
    }

    under_test._artifact_uploader.upload_object_artifact.assert_called_with(
        "TestConfusionMatrix", expected_data, file_extension="json"
    )

    # Fix: the original *invoked* the mock (add_input_artifact(...)) instead of asserting
    # on it, so the lineage association was never actually checked. output_artifact
    # defaults to True, so the tracker should register an output artifact here.
    under_test._lineage_artifact_tracker.add_output_artifact.assert_called_with(
        "TestConfusionMatrix", "s3uri_value", "etag_value", "ConfusionMatrix"
    )
394+
395+
371396
def test_resolve_artifact_name():
372397
file_names = {
373398
"a": "a",

0 commit comments

Comments
 (0)