Skip to content

Commit 4240bbb

Browse files
author
WenkelF
committed
Fixing datamodule unit test and unit test speedup
1 parent 7665b67 commit 4240bbb

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

tests/test_datamodule.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -241,42 +241,42 @@ def test_caching(self):
241241
batch_from_disk = next(iter(cached_train_loader_from_disk))
242242

243243
# Features are the same
244-
assert torch.equal(batch["features"].edge_index, batch_from_ram["features"].edge_index)
245-
assert torch.equal(batch["features"].edge_index, batch_from_disk["features"].edge_index)
244+
np.testing.assert_array_almost_equal(batch["features"].edge_index, batch_from_ram["features"].edge_index)
245+
np.testing.assert_array_almost_equal(batch["features"].edge_index, batch_from_disk["features"].edge_index)
246246

247247
assert batch["features"].num_nodes == batch_from_ram["features"].num_nodes
248248
assert batch["features"].num_nodes == batch_from_disk["features"].num_nodes
249249

250-
assert torch.equal(batch["features"].edge_weight, batch_from_ram["features"].edge_weight)
251-
assert torch.equal(batch["features"].edge_weight, batch_from_disk["features"].edge_weight)
250+
np.testing.assert_array_almost_equal(batch["features"].edge_weight, batch_from_ram["features"].edge_weight)
251+
np.testing.assert_array_almost_equal(batch["features"].edge_weight, batch_from_disk["features"].edge_weight)
252252

253-
assert torch.equal(batch["features"].feat, batch_from_ram["features"].feat)
254-
assert torch.equal(batch["features"].feat, batch_from_disk["features"].feat)
253+
np.testing.assert_array_almost_equal(batch["features"].feat, batch_from_ram["features"].feat)
254+
np.testing.assert_array_almost_equal(batch["features"].feat, batch_from_disk["features"].feat)
255255

256-
assert torch.equal(batch["features"].edge_feat, batch_from_ram["features"].edge_feat)
257-
assert torch.equal(batch["features"].edge_feat, batch_from_disk["features"].edge_feat)
256+
np.testing.assert_array_almost_equal(batch["features"].edge_feat, batch_from_ram["features"].edge_feat)
257+
np.testing.assert_array_almost_equal(batch["features"].edge_feat, batch_from_disk["features"].edge_feat)
258258

259-
assert torch.equal(batch["features"].batch, batch_from_ram["features"].batch)
260-
assert torch.equal(batch["features"].batch, batch_from_disk["features"].batch)
259+
np.testing.assert_array_almost_equal(batch["features"].batch, batch_from_ram["features"].batch)
260+
np.testing.assert_array_almost_equal(batch["features"].batch, batch_from_disk["features"].batch)
261261

262-
assert torch.equal(batch["features"].ptr, batch_from_ram["features"].ptr)
263-
assert torch.equal(batch["features"].ptr, batch_from_disk["features"].ptr)
262+
np.testing.assert_array_almost_equal(batch["features"].ptr, batch_from_ram["features"].ptr)
263+
np.testing.assert_array_almost_equal(batch["features"].ptr, batch_from_disk["features"].ptr)
264264

265265
# Labels are the same
266-
assert torch.equal(batch["labels"].graph_task_1, batch_from_ram["labels"].graph_task_1)
267-
assert torch.equal(batch["labels"].graph_task_1, batch_from_disk["labels"].graph_task_1)
266+
np.testing.assert_array_almost_equal(batch["labels"].graph_task_1, batch_from_ram["labels"].graph_task_1)
267+
np.testing.assert_array_almost_equal(batch["labels"].graph_task_1, batch_from_disk["labels"].graph_task_1)
268268

269-
assert torch.equal(batch["labels"].x, batch_from_ram["labels"].x)
270-
assert torch.equal(batch["labels"].x, batch_from_disk["labels"].x)
269+
np.testing.assert_array_almost_equal(batch["labels"].x, batch_from_ram["labels"].x)
270+
np.testing.assert_array_almost_equal(batch["labels"].x, batch_from_disk["labels"].x)
271271

272-
assert torch.equal(batch["labels"].edge_index, batch_from_ram["labels"].edge_index)
273-
assert torch.equal(batch["labels"].edge_index, batch_from_disk["labels"].edge_index)
272+
np.testing.assert_array_almost_equal(batch["labels"].edge_index, batch_from_ram["labels"].edge_index)
273+
np.testing.assert_array_almost_equal(batch["labels"].edge_index, batch_from_disk["labels"].edge_index)
274274

275-
assert torch.equal(batch["labels"].batch, batch_from_ram["labels"].batch)
276-
assert torch.equal(batch["labels"].batch, batch_from_disk["labels"].batch)
275+
np.testing.assert_array_almost_equal(batch["labels"].batch, batch_from_ram["labels"].batch)
276+
np.testing.assert_array_almost_equal(batch["labels"].batch, batch_from_disk["labels"].batch)
277277

278-
assert torch.equal(batch["labels"].ptr, batch_from_ram["labels"].ptr)
279-
assert torch.equal(batch["labels"].ptr, batch_from_disk["labels"].ptr)
278+
np.testing.assert_array_almost_equal(batch["labels"].ptr, batch_from_ram["labels"].ptr)
279+
np.testing.assert_array_almost_equal(batch["labels"].ptr, batch_from_disk["labels"].ptr)
280280

281281
# Delete the cache if already exist
282282
if exists(TEMP_CACHE_DATA_PATH):
@@ -422,7 +422,7 @@ def test_datamodule_multiple_data_files(self):
422422
"task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs}
423423
}
424424

425-
ds = MultitaskFromSmilesDataModule(task_specific_args)
425+
ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0)
426426
ds.prepare_data()
427427
ds.setup()
428428

@@ -435,7 +435,7 @@ def test_datamodule_multiple_data_files(self):
435435
"task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs}
436436
}
437437

438-
ds = MultitaskFromSmilesDataModule(task_specific_args)
438+
ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0)
439439
ds.prepare_data()
440440
ds.setup()
441441

@@ -448,7 +448,7 @@ def test_datamodule_multiple_data_files(self):
448448
"task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs}
449449
}
450450

451-
ds = MultitaskFromSmilesDataModule(task_specific_args)
451+
ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0)
452452
ds.prepare_data()
453453
ds.setup()
454454

@@ -461,7 +461,7 @@ def test_datamodule_multiple_data_files(self):
461461
"task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs}
462462
}
463463

464-
ds = MultitaskFromSmilesDataModule(task_specific_args)
464+
ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0)
465465
ds.prepare_data()
466466
ds.setup()
467467

0 commit comments

Comments
 (0)