We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents f9fbbab + 3e436c7 commit 2f3c23cCopy full SHA for 2f3c23c
algoperf/workloads/ogbg/metrics.py
@@ -37,6 +37,7 @@ def compute(self):
37
labels = values['labels']
38
logits = values['logits']
39
mask = values['mask']
40
+ sigmoid = jax.nn.sigmoid
41
42
if USE_PYTORCH_DDP:
43
# Sync labels, logits, and masks across devices.
@@ -49,9 +50,14 @@ def compute(self):
49
50
all_values[idx] = torch.cat(all_tensors).cpu().numpy()
51
labels, logits, mask = all_values
52
53
+ def sigmoid_np(x):
54
+ return 1 / (1 + np.exp(-x))
55
+
56
+ sigmoid = sigmoid_np
57
58
mask = mask.astype(bool)
59
- probs = jax.nn.sigmoid(logits)
60
+ probs = sigmoid(logits)
61
num_tasks = labels.shape[1]
62
average_precisions = np.full(num_tasks, np.nan)
63
0 commit comments