@@ -481,7 +481,7 @@ def log_precision_recall(
481
481
Trial Component as an output artifact. If False will be an input artifact.
482
482
483
483
Raises:
484
- ValueError: If mismatch between y_true and predicted_probabilities.
484
+ ValueError: If length mismatch between y_true and predicted_probabilities.
485
485
"""
486
486
487
487
if len (y_true ) != len (predicted_probabilities ):
@@ -542,7 +542,7 @@ def log_roc_curve(
542
542
"""
543
543
544
544
if len (y_true ) != len (y_score ):
545
- raise ValueError ("Mismatch between actual labels and predicted scores." )
545
+ raise ValueError ("Length mismatch between actual labels and predicted scores." )
546
546
547
547
get_module ("sklearn" )
548
548
from sklearn .metrics import roc_curve , auc
@@ -561,6 +561,50 @@ def log_roc_curve(
561
561
}
562
562
self ._log_graph_artifact (title , data , "ROCCurve" , output_artifact )
563
563
564
+ def log_confusion_matrix (
565
+ self ,
566
+ y_true ,
567
+ y_pred ,
568
+ title = None ,
569
+ output_artifact = True ,
570
+ ):
571
+ """Log a confusion matrix artifact which will be displayed in
572
+ studio. Requires sklearn.
573
+
574
+ Note that this method must be run from a SageMaker context such as studio or training job
575
+ due to restrictions on the CreateArtifact API.
576
+
577
+ Examples
578
+ .. code-block:: python
579
+
580
+ y_true = [2, 0, 2, 2, 0, 1]
581
+ y_pred = [0, 0, 2, 2, 0, 2]
582
+
583
+ my_tracker.log_confusion_matrix(y_true, y_pred)
584
+
585
+
586
+ Args:
587
+ y_true (array): True labels. If labels are not binary then positive_label should be given.
588
+ y_pred (array): Predicted labels.
589
+ title (str, optional): Title of the graph, Defaults to none.
590
+ output_artifact (boolean, optional): Determines if the artifact is associated with the
591
+ Trial Component as an output artifact. If False will be an input artifact.
592
+
593
+ Raises:
594
+ ValueError: If length mismatch between y_true and y_pred.
595
+ """
596
+
597
+ if len (y_true ) != len (y_pred ):
598
+ raise ValueError ("Length mismatch between actual labels and predicted labels." )
599
+
600
+ get_module ("sklearn" )
601
+ from sklearn .metrics import confusion_matrix
602
+
603
+ matrix = confusion_matrix (y_true , y_pred )
604
+
605
+ data = {"type" : "ConfusionMatrix" , "version" : 0 , "title" : title , "confusionMatrix" : matrix .tolist ()}
606
+ self ._log_graph_artifact (title , data , "ConfusionMatrix" , output_artifact )
607
+
564
608
def _log_graph_artifact (self , name , data , graph_type , output_artifact ):
565
609
"""Logs an artifact by uploading data to S3, creating an artifact, and associating that
566
610
artifact with the tracker's Trial Component.
0 commit comments