Skip to content

Commit a0dc002

Browse files
committed
update function signature
1 parent d4b1a8c commit a0dc002

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

torchensemble/_constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,8 @@
174174

175175

176176
__vectorize_doc = """
177-
Return the vectorization result of the ensemble using functorch.
177+
Return the vectorization result of the ensemble using functorch. Details
178+
available at `functorch model ensembling <https://pytorch.org/functorch/stable/notebooks/ensembling.html>`_.
178179
179180
Returns
180181
-------

torchensemble/voting.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,10 @@ def evaluate(self, test_loader, return_loss=False):
307307
def predict(self, *x):
308308
return super().predict(*x)
309309

310+
@torchensemble_model_doc(item="vectorize")
311+
def vectorize(self):
312+
return super().vectorize()
313+
310314

311315
@torchensemble_model_doc(
312316
"""Implementation on the NeuralForestClassifier.""", "tree_ensemble_model"
@@ -374,6 +378,10 @@ def fit(
374378
save_dir=save_dir,
375379
)
376380

381+
@torchensemble_model_doc(item="vectorize")
382+
def vectorize(self):
383+
return super().vectorize()
384+
377385

378386
@torchensemble_model_doc("""Implementation on the VotingRegressor.""", "model")
379387
class VotingRegressor(BaseRegressor):
@@ -559,6 +567,10 @@ def evaluate(self, test_loader):
559567
def predict(self, *x):
560568
return super().predict(*x)
561569

570+
@torchensemble_model_doc(item="vectorize")
571+
def vectorize(self):
572+
return super().vectorize()
573+
562574

563575
@torchensemble_model_doc(
564576
"""Implementation on the NeuralForestRegressor.""", "tree_ensemble_model"
@@ -620,3 +632,7 @@ def fit(
620632
save_model=save_model,
621633
save_dir=save_dir,
622634
)
635+
636+
@torchensemble_model_doc(item="vectorize")
637+
def vectorize(self):
638+
return super().vectorize()

0 commit comments

Comments
 (0)