Skip to content

Commit 9df55c0

Browse files
committed
add no skill to P/R curve graph artifact
1 parent aeeace7 commit 9df55c0

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

src/smexperiments/tracker.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,7 @@ def log_precision_recall(
467467
positive_label=None,
468468
title=None,
469469
output_artifact=True,
470+
no_skill=None,
470471
):
471472
"""Log a precision recall graph artifact which will be displayed in studio.
472473
Requires sklearn. Not yet supported by studio.
@@ -479,8 +480,9 @@ def log_precision_recall(
479480
480481
y_true = [0, 0, 1, 1]
481482
y_scores = [0.1, 0.4, 0.35, 0.8]
483+
no_skill = len(y_true[y_true==1]) / len(y_true)
482484
483-
my_tracker._log_precision_recall(y_true, y_scores)
485+
my_tracker._log_precision_recall(y_true, y_scores, no_skill=no_skill)
484486
485487
Args:
486488
y_true (array): True labels. If labels are not binary then positive_label should be given.
@@ -489,6 +491,9 @@ def log_precision_recall(
489491
title (str, optional): Title of the graph, Defaults to none.
490492
output_artifact (boolean, optional): Determines if the artifact is associated with the
491493
Trial Component as an output artifact. If False will be an input artifact.
494+
no_skill (int): The precision threshold under which the classifier cannot discriminate
495+
between the classes and would predict a random class or a constant class in
496+
all cases.
492497
493498
Raises:
494499
ValueError: If length mismatch between y_true and predicted_probabilities.
@@ -516,6 +521,7 @@ def log_precision_recall(
516521
"precision": precision.tolist(),
517522
"recall": recall.tolist(),
518523
"averagePrecisionScore": ap,
524+
"noSkill": no_skill,
519525
}
520526
self._log_graph_artifact(title, data, "PrecisionRecallCurve", output_artifact)
521527

tests/unit/artifact_schemas/precision_recall_curve_v0.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
"averagePrecisionScore": {
2929
"description": "AP summarizes a precision-recall curve as the weighted mean of precisions achieved at each threshold, with the increase in recall from the previous threshold used as the weight.",
3030
"type": "number"
31+
},
32+
"noSkill": {
33+
"description": "The precision threshold under which the classifier cannot discriminate between the classes and would predict a random class or a constant class in all cases.",
34+
"type": "number"
3135
}
3236
},
3337
"additionalProperties": false

tests/unit/test_tracker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,10 +341,11 @@ def test_log_pr_curve(under_test):
341341

342342
y_true = [0, 0, 1, 1]
343343
y_scores = [0.1, 0.4, 0.35, 0.8]
344+
no_skill = 0.1
344345

345346
under_test._artifact_uploader.upload_object_artifact.return_value = ("s3uri_value", "etag_value")
346347

347-
under_test.log_precision_recall(y_true, y_scores, title="TestPRCurve")
348+
under_test.log_precision_recall(y_true, y_scores, title="TestPRCurve", no_skill=no_skill)
348349

349350
expected_data = {
350351
"type": "PrecisionRecallCurve",
@@ -353,6 +354,7 @@ def test_log_pr_curve(under_test):
353354
"precision": [0.6666666666666666, 0.5, 1.0, 1.0],
354355
"recall": [1.0, 0.5, 0.5, 0.0],
355356
"averagePrecisionScore": 0.8333333333333333,
357+
"noSkill": 0.1,
356358
}
357359
under_test._artifact_uploader.upload_object_artifact.assert_called_with(
358360
"TestPRCurve", expected_data, file_extension="json"

0 commit comments

Comments
 (0)