Skip to content

Commit 4afc38c

Browse files
committed
add unit tests
1 parent eda9009 commit 4afc38c

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

torchensemble/tests/test_all_models.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import numpy as np
44
import torch.nn as nn
55
from numpy.testing import assert_array_equal
6+
7+
from functorch import vmap
68
from torch.utils.data import TensorDataset, DataLoader
79

810
import torchensemble
@@ -302,3 +304,79 @@ def test_predict():
302304
with pytest.raises(ValueError) as excinfo:
303305
model.predict([X_test]) # list
304306
assert "The type of input X should be one of" in str(excinfo.value)
307+
308+
309+
@pytest.mark.parametrize("clf", all_clf)
310+
def test_clf_vectorize_same_output(clf):
311+
"""
312+
This unit test checks the inference with/without vectorize for all
313+
classifiers.
314+
"""
315+
epochs = 2
316+
n_estimators = 2
317+
318+
model = clf(estimator=MLP_clf, n_estimators=n_estimators, cuda=False)
319+
320+
# Optimizer
321+
model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4)
322+
323+
# Prepare data
324+
train = TensorDataset(X_train, y_train_clf)
325+
train_loader = DataLoader(train, batch_size=2, shuffle=False)
326+
test = TensorDataset(X_test, y_test_clf)
327+
test_loader = DataLoader(test, batch_size=2, shuffle=False)
328+
329+
# Train
330+
model.fit(train_loader, epochs=epochs, test_loader=test_loader)
331+
332+
fmodel, params, buffers = model.vectorize()
333+
334+
with torch.no_grad():
335+
for idx, (data, target) in enumerate(test_loader):
336+
vmap_output = vmap(fmodel, in_dims=(0, 0, None))(
337+
params, buffers, data
338+
)
339+
pytorch_output = [
340+
estimator(data) for estimator in model.estimators_
341+
]
342+
assert torch.allclose(
343+
vmap_output, torch.stack(pytorch_output), atol=1e-3, rtol=1e-5
344+
)
345+
346+
347+
@pytest.mark.parametrize("reg", all_reg)
348+
def test_reg_vectorize_same_output(reg):
349+
"""
350+
This unit test checks the inference with/without vectorize for all
351+
classifiers.
352+
"""
353+
epochs = 2
354+
n_estimators = 2
355+
356+
model = reg(estimator=MLP_reg, n_estimators=n_estimators, cuda=False)
357+
358+
# Optimizer
359+
model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4)
360+
361+
# Prepare data
362+
train = TensorDataset(X_train, y_train_reg)
363+
train_loader = DataLoader(train, batch_size=2, shuffle=False)
364+
test = TensorDataset(X_test, y_test_reg)
365+
test_loader = DataLoader(test, batch_size=2, shuffle=False)
366+
367+
# Train
368+
model.fit(train_loader, epochs=epochs, test_loader=test_loader)
369+
370+
fmodel, params, buffers = model.vectorize()
371+
372+
with torch.no_grad():
373+
for idx, (data, target) in enumerate(test_loader):
374+
vmap_output = vmap(fmodel, in_dims=(0, 0, None))(
375+
params, buffers, data
376+
)
377+
pytorch_output = [
378+
estimator(data) for estimator in model.estimators_
379+
]
380+
assert torch.allclose(
381+
vmap_output, torch.stack(pytorch_output), atol=1e-3, rtol=1e-5
382+
)

0 commit comments

Comments
 (0)