@@ -3,6 +3,8 @@
 import numpy as np
 import torch.nn as nn
 from numpy.testing import assert_array_equal
+
+from functorch import vmap
 from torch.utils.data import TensorDataset, DataLoader
 
 import torchensemble
@@ -302,3 +304,79 @@ def test_predict():
     with pytest.raises(ValueError) as excinfo:
         model.predict([X_test])  # list
     assert "The type of input X should be one of" in str(excinfo.value)
+
+
+@pytest.mark.parametrize("clf", all_clf)
+def test_clf_vectorize_same_output(clf):
+    """
+    Check that inference with and without ``vectorize`` produces the
+    same output for all classifiers.
+    """
+    epochs = 2
+    n_estimators = 2
+
+    model = clf(estimator=MLP_clf, n_estimators=n_estimators, cuda=False)
+
+    # Optimizer
+    model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4)
+
+    # Prepare data
+    train = TensorDataset(X_train, y_train_clf)
+    train_loader = DataLoader(train, batch_size=2, shuffle=False)
+    test = TensorDataset(X_test, y_test_clf)
+    test_loader = DataLoader(test, batch_size=2, shuffle=False)
+
+    # Train
+    model.fit(train_loader, epochs=epochs, test_loader=test_loader)
+
+    fmodel, params, buffers = model.vectorize()
+
+    with torch.no_grad():
+        for data, _ in test_loader:
+            vmap_output = vmap(fmodel, in_dims=(0, 0, None))(
+                params, buffers, data
+            )
+            pytorch_output = [
+                estimator(data) for estimator in model.estimators_
+            ]
+            assert torch.allclose(
+                vmap_output, torch.stack(pytorch_output), atol=1e-3, rtol=1e-5
+            )
+
+@pytest.mark.parametrize("reg", all_reg)
+def test_reg_vectorize_same_output(reg):
+    """
+    Check that inference with and without ``vectorize`` produces the
+    same output for all regressors.
+    """
+    epochs = 2
+    n_estimators = 2
+
+    model = reg(estimator=MLP_reg, n_estimators=n_estimators, cuda=False)
+
+    # Optimizer
+    model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4)
+
+    # Prepare data
+    train = TensorDataset(X_train, y_train_reg)
+    train_loader = DataLoader(train, batch_size=2, shuffle=False)
+    test = TensorDataset(X_test, y_test_reg)
+    test_loader = DataLoader(test, batch_size=2, shuffle=False)
+
+    # Train
+    model.fit(train_loader, epochs=epochs, test_loader=test_loader)
+
+    fmodel, params, buffers = model.vectorize()
+
+    with torch.no_grad():
+        for data, _ in test_loader:
+            vmap_output = vmap(fmodel, in_dims=(0, 0, None))(
+                params, buffers, data
+            )
+            pytorch_output = [
+                estimator(data) for estimator in model.estimators_
+            ]
+            assert torch.allclose(
+                vmap_output, torch.stack(pytorch_output), atol=1e-3, rtol=1e-5
+            )
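
Note on the pattern under test: functorch's documented ensembling recipe stacks each estimator's parameters and buffers along a new leading dimension and maps a single functional forward pass over that dimension with vmap, broadcasting the shared input via in_dims=(0, 0, None). The sketch below reproduces that recipe with plain functorch; that torchensemble's vectorize() returns the same (fmodel, params, buffers) triple produced by combine_state_for_ensemble is an assumption inferred from the calls above, not confirmed from its implementation.

# Minimal sketch of the vmap-over-ensemble pattern, assuming only functorch.
# The Linear estimators here are stand-ins for the trained base estimators.
import torch
import torch.nn as nn
from functorch import combine_state_for_ensemble, vmap

estimators = [nn.Linear(4, 2) for _ in range(3)]
fmodel, params, buffers = combine_state_for_ensemble(estimators)

x = torch.randn(8, 4)
with torch.no_grad():
    # in_dims=(0, 0, None): map over stacked params/buffers, share the input.
    batched = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)

# Equivalent to stacking per-estimator forward passes, as the tests assert.
expected = torch.stack([estimator(x) for estimator in estimators])
print(batched.shape)  # torch.Size([3, 8, 2]): one output per estimator
print(torch.allclose(batched, expected, atol=1e-6))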