@@ -241,42 +241,42 @@ def test_caching(self):
241241 batch_from_disk = next (iter (cached_train_loader_from_disk ))
242242
243243 # Features are the same
244- assert torch . equal (batch ["features" ].edge_index , batch_from_ram ["features" ].edge_index )
245- assert torch . equal (batch ["features" ].edge_index , batch_from_disk ["features" ].edge_index )
244+ np . testing . assert_array_almost_equal (batch ["features" ].edge_index , batch_from_ram ["features" ].edge_index )
245+ np . testing . assert_array_almost_equal (batch ["features" ].edge_index , batch_from_disk ["features" ].edge_index )
246246
247247 assert batch ["features" ].num_nodes == batch_from_ram ["features" ].num_nodes
248248 assert batch ["features" ].num_nodes == batch_from_disk ["features" ].num_nodes
249249
250- assert torch . equal (batch ["features" ].edge_weight , batch_from_ram ["features" ].edge_weight )
251- assert torch . equal (batch ["features" ].edge_weight , batch_from_disk ["features" ].edge_weight )
250+ np . testing . assert_array_almost_equal (batch ["features" ].edge_weight , batch_from_ram ["features" ].edge_weight )
251+ np . testing . assert_array_almost_equal (batch ["features" ].edge_weight , batch_from_disk ["features" ].edge_weight )
252252
253- assert torch . equal (batch ["features" ].feat , batch_from_ram ["features" ].feat )
254- assert torch . equal (batch ["features" ].feat , batch_from_disk ["features" ].feat )
253+ np . testing . assert_array_almost_equal (batch ["features" ].feat , batch_from_ram ["features" ].feat )
254+ np . testing . assert_array_almost_equal (batch ["features" ].feat , batch_from_disk ["features" ].feat )
255255
256- assert torch . equal (batch ["features" ].edge_feat , batch_from_ram ["features" ].edge_feat )
257- assert torch . equal (batch ["features" ].edge_feat , batch_from_disk ["features" ].edge_feat )
256+ np . testing . assert_array_almost_equal (batch ["features" ].edge_feat , batch_from_ram ["features" ].edge_feat )
257+ np . testing . assert_array_almost_equal (batch ["features" ].edge_feat , batch_from_disk ["features" ].edge_feat )
258258
259- assert torch . equal (batch ["features" ].batch , batch_from_ram ["features" ].batch )
260- assert torch . equal (batch ["features" ].batch , batch_from_disk ["features" ].batch )
259+ np . testing . assert_array_almost_equal (batch ["features" ].batch , batch_from_ram ["features" ].batch )
260+ np . testing . assert_array_almost_equal (batch ["features" ].batch , batch_from_disk ["features" ].batch )
261261
262- assert torch . equal (batch ["features" ].ptr , batch_from_ram ["features" ].ptr )
263- assert torch . equal (batch ["features" ].ptr , batch_from_disk ["features" ].ptr )
262+ np . testing . assert_array_almost_equal (batch ["features" ].ptr , batch_from_ram ["features" ].ptr )
263+ np . testing . assert_array_almost_equal (batch ["features" ].ptr , batch_from_disk ["features" ].ptr )
264264
265265 # Labels are the same
266- assert torch . equal (batch ["labels" ].graph_task_1 , batch_from_ram ["labels" ].graph_task_1 )
267- assert torch . equal (batch ["labels" ].graph_task_1 , batch_from_disk ["labels" ].graph_task_1 )
266+ np . testing . assert_array_almost_equal (batch ["labels" ].graph_task_1 , batch_from_ram ["labels" ].graph_task_1 )
267+ np . testing . assert_array_almost_equal (batch ["labels" ].graph_task_1 , batch_from_disk ["labels" ].graph_task_1 )
268268
269- assert torch . equal (batch ["labels" ].x , batch_from_ram ["labels" ].x )
270- assert torch . equal (batch ["labels" ].x , batch_from_disk ["labels" ].x )
269+ np . testing . assert_array_almost_equal (batch ["labels" ].x , batch_from_ram ["labels" ].x )
270+ np . testing . assert_array_almost_equal (batch ["labels" ].x , batch_from_disk ["labels" ].x )
271271
272- assert torch . equal (batch ["labels" ].edge_index , batch_from_ram ["labels" ].edge_index )
273- assert torch . equal (batch ["labels" ].edge_index , batch_from_disk ["labels" ].edge_index )
272+ np . testing . assert_array_almost_equal (batch ["labels" ].edge_index , batch_from_ram ["labels" ].edge_index )
273+ np . testing . assert_array_almost_equal (batch ["labels" ].edge_index , batch_from_disk ["labels" ].edge_index )
274274
275- assert torch . equal (batch ["labels" ].batch , batch_from_ram ["labels" ].batch )
276- assert torch . equal (batch ["labels" ].batch , batch_from_disk ["labels" ].batch )
275+ np . testing . assert_array_almost_equal (batch ["labels" ].batch , batch_from_ram ["labels" ].batch )
276+ np . testing . assert_array_almost_equal (batch ["labels" ].batch , batch_from_disk ["labels" ].batch )
277277
278- assert torch . equal (batch ["labels" ].ptr , batch_from_ram ["labels" ].ptr )
279- assert torch . equal (batch ["labels" ].ptr , batch_from_disk ["labels" ].ptr )
278+ np . testing . assert_array_almost_equal (batch ["labels" ].ptr , batch_from_ram ["labels" ].ptr )
279+ np . testing . assert_array_almost_equal (batch ["labels" ].ptr , batch_from_disk ["labels" ].ptr )
280280
281281 # Delete the cache if already exist
282282 if exists (TEMP_CACHE_DATA_PATH ):
@@ -422,7 +422,7 @@ def test_datamodule_multiple_data_files(self):
422422 "task" : {"task_level" : "graph" , "label_cols" : ["score" ], "smiles_col" : "SMILES" , ** task_kwargs }
423423 }
424424
425- ds = MultitaskFromSmilesDataModule (task_specific_args )
425+ ds = MultitaskFromSmilesDataModule (task_specific_args , featurization_n_jobs = 0 )
426426 ds .prepare_data ()
427427 ds .setup ()
428428
@@ -435,7 +435,7 @@ def test_datamodule_multiple_data_files(self):
435435 "task" : {"task_level" : "graph" , "label_cols" : ["score" ], "smiles_col" : "SMILES" , ** task_kwargs }
436436 }
437437
438- ds = MultitaskFromSmilesDataModule (task_specific_args )
438+ ds = MultitaskFromSmilesDataModule (task_specific_args , featurization_n_jobs = 0 )
439439 ds .prepare_data ()
440440 ds .setup ()
441441
@@ -448,7 +448,7 @@ def test_datamodule_multiple_data_files(self):
448448 "task" : {"task_level" : "graph" , "label_cols" : ["score" ], "smiles_col" : "SMILES" , ** task_kwargs }
449449 }
450450
451- ds = MultitaskFromSmilesDataModule (task_specific_args )
451+ ds = MultitaskFromSmilesDataModule (task_specific_args , featurization_n_jobs = 0 )
452452 ds .prepare_data ()
453453 ds .setup ()
454454
@@ -461,7 +461,7 @@ def test_datamodule_multiple_data_files(self):
461461 "task" : {"task_level" : "graph" , "label_cols" : ["score" ], "smiles_col" : "SMILES" , ** task_kwargs }
462462 }
463463
464- ds = MultitaskFromSmilesDataModule (task_specific_args )
464+ ds = MultitaskFromSmilesDataModule (task_specific_args , featurization_n_jobs = 0 )
465465 ds .prepare_data ()
466466 ds .setup ()
467467
0 commit comments