Skip to content

Commit 2f3c23c

Browse files
Merge pull request #871 from davidtweedle/ogbg_fix
Update metrics.py - fix for ogbg pytorch
2 parents f9fbbab + 3e436c7 commit 2f3c23c

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

algoperf/workloads/ogbg/metrics.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def compute(self):
3737
labels = values['labels']
3838
logits = values['logits']
3939
mask = values['mask']
40+
sigmoid = jax.nn.sigmoid
4041

4142
if USE_PYTORCH_DDP:
4243
# Sync labels, logits, and masks across devices.
@@ -49,9 +50,14 @@ def compute(self):
4950
all_values[idx] = torch.cat(all_tensors).cpu().numpy()
5051
labels, logits, mask = all_values
5152

53+
def sigmoid_np(x):
54+
return 1 / (1 + np.exp(-x))
55+
56+
sigmoid = sigmoid_np
57+
5258
mask = mask.astype(bool)
5359

54-
probs = jax.nn.sigmoid(logits)
60+
probs = sigmoid(logits)
5561
num_tasks = labels.shape[1]
5662
average_precisions = np.full(num_tasks, np.nan)
5763

0 commit comments

Comments
 (0)