Skip to content

Commit 9edcc9c

Browse files
authored
feat: add botorch gp learners (#459)
Thank you for contributing a learner to the mlr3 ecosystem. Please make sure that: - [x] The added learner(s) are sufficiently tested - [x] All the CI tests are passing (including the CodeFactor) - [x] You ran `devtools::document()` - [x] You updated the `NEWS.md` field to include the addition of the learner - [x] You did not modify anything **not** related to the new learner - [x] You are listed as a contributor in the `DESCRIPTION` of the R package
1 parent 743e286 commit 9edcc9c

17 files changed

+638
-37
lines changed

DESCRIPTION

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,16 @@ Authors@R: c(
2323
person("Lukas", "Burk", , "[email protected]", role = "ctb",
2424
comment = c(ORCID = "0000-0001-7528-3795")),
2525
person("Lona", "Koers", , "[email protected]", role = "ctb"),
26+
person("Anna", "Nazarova", , "[email protected]", role = "ctb"),
2627
person("Baisu", "Zhou", , "[email protected]", role = "ctb"),
27-
person("Nikolai", "German", , "[email protected]", role = "ctb"),
28-
person("Anna", "Nazarova", , "[email protected]", role = "ctb")
28+
person("Marc", "Becker", , "[email protected]", role = "ctb",
29+
comment = c(ORCID = "0000-0002-8115-0400")),
30+
person("Nikolai", "German", , "[email protected]", role = "ctb")
2931
)
3032
Description: Extra learners for use in mlr3.
3133
License: LGPL-3
32-
URL: https://mlr3extralearners.mlr-org.com, https://github.com/mlr-org/mlr3extralearners
34+
URL: https://mlr3extralearners.mlr-org.com,
35+
https://github.com/mlr-org/mlr3extralearners
3336
BugReports: https://github.com/mlr-org/mlr3extralearners/issues
3437
Depends:
3538
mlr3 (>= 1.0.0),
@@ -47,8 +50,8 @@ Imports:
4750
Suggests:
4851
abess,
4952
actuar,
50-
adabag,
5153
ada,
54+
adabag,
5255
aorsf (>= 0.1.5),
5356
apcluster,
5457
BART (>= 2.9.4),
@@ -68,9 +71,9 @@ Suggests:
6871
distr6,
6972
earth,
7073
evtree,
74+
ExhaustiveSearch,
7175
fastai,
7276
flexsurv (>= 2.3),
73-
ExhaustiveSearch,
7477
FNN,
7578
formattable,
7679
future,
@@ -104,24 +107,24 @@ Suggests:
104107
partykit (>= 1.2-21),
105108
penalized (>= 0.9-52),
106109
pendensity,
107-
polyreg,
108110
plugdensity,
111+
polyreg,
109112
prioritylasso (>= 0.3.1),
110113
pseudo,
111114
qgam,
112115
randomForest,
113116
randomForestSRC (>= 3.4.1),
114117
ranger (>= 0.17.0),
115118
remotes,
116-
rotationForest,
117119
rFerns,
118120
riskRegression,
119121
rJava,
122+
rmarkdown,
123+
rotationForest,
120124
rpart,
121125
rsm,
122126
rvest,
123127
RWeka,
124-
rmarkdown,
125128
sandwich,
126129
sda,
127130
set6,
@@ -146,10 +149,10 @@ Remotes:
146149
mlr-org/mlr3misc,
147150
mlr-org/mlr3proba,
148151
RaphaelS1/survivalmodels,
152+
rstudio/reticulate,
149153
xoopR/distr6,
150154
xoopR/param6,
151-
xoopR/set6,
152-
rstudio/reticulate
155+
xoopR/set6
153156
Config/Needs/website: rmarkdown
154157
Config/testthat/edition: 3
155158
Config/testthat/parallel: false

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# Generated by roxygen2: do not edit by hand
22

33
S3method(marshal_model,Weka_classifier)
4+
S3method(marshal_model,botorch_gp_model)
45
S3method(marshal_model,fastai_model)
56
S3method(marshal_model,tabpfn_model)
67
S3method(marshal_model,xgboost_cox_model)
78
S3method(unmarshal_model,Weka_classifier_marshaled)
9+
S3method(unmarshal_model,botorch_gp_model_marshaled)
810
S3method(unmarshal_model,fastai_model_marshaled)
911
S3method(unmarshal_model,tabpfn_model_marshaled)
1012
S3method(unmarshal_model,xgboost_cox_model_marshaled)
@@ -87,6 +89,8 @@ export(LearnerDensSpline)
8789
export(LearnerRegrAbess)
8890
export(LearnerRegrBart)
8991
export(LearnerRegrBlockForest)
92+
export(LearnerRegrBotorchMixedSingleTaskGP)
93+
export(LearnerRegrBotorchSingleTaskGP)
9094
export(LearnerRegrBrnn)
9195
export(LearnerRegrBst)
9296
export(LearnerRegrCForest)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
- `LearnerClassifRferns`
2222
- `LearnerClassifNeuralnet`
2323
- `LearnerRegrBrnn`
24+
- `LearnerRegrBotorchSingleTaskGP`
25+
- `LearnerRegrBotorchMixedSingleTaskGP`
2426

2527
* Add new `control_custom_fun` parameter in `surv.aorsf`
2628
* New function `learner_is_runnable()` to check whether the
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#' @title BoTorch MixedSingleTaskGP Regression Learner
#' @author Marc Becker
#' @name mlr_learners_regr.botorch_mixedsingletaskgp
#'
#' @description
#' Gaussian Process via [botorch](https://botorch.org/) and [gpytorch](https://gpytorch.ai/), using the `MixedSingleTaskGP`.
#' Uses \CRANpkg{reticulate} to interface with Python.
#'
#' Factor and logical features are converted to integers and passed to the model as categorical dimensions.
#' Training fails if the task contains neither a factor nor a logical feature.
#'
#' The fitted model is stored as a list of class `"botorch_gp_model"` whose element `gp` holds the Python model object,
#' so that the `marshal_model()` method for this class is dispatched when the learner is marshaled.
#'
#' @templateVar id regr.botorch_mixedsingletaskgp
#' @template learner
#'
#' @export
LearnerRegrBotorchMixedSingleTaskGP = R6Class("LearnerRegrBotorchMixedSingleTaskGP",
  inherit = LearnerRegr,

  public = list(

    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      # `device` is needed at predict time as well: the new data must be placed
      # on the same device as the fitted model
      param_set = ps(
        device = p_fct(default = "cpu", levels = c("cpu", "cuda"), tags = c("train", "predict"))
      )

      super$initialize(
        id = "regr.botorch_mixedsingletaskgp",
        packages = c("mlr3extralearners", "reticulate"),
        feature_types = c("integer", "numeric", "logical", "factor"),
        predict_types = c("response", "se"),
        param_set = param_set,
        properties = "marshal",
        label = "BoTorch MixedSingleTaskGP",
        man = "mlr3extralearners::mlr_learners_regr.botorch_mixedsingletaskgp"
      )
    },

    #' @description
    #' Marshal the learner's model.
    #' @param ... (any)\cr
    #'   Additional arguments passed to [`marshal_model()`].
    marshal = function(...) {
      mlr3::learner_marshal(.learner = self, ...)
    },

    #' @description
    #' Unmarshal the learner's model.
    #' @param ... (any)\cr
    #'   Additional arguments passed to [`unmarshal_model()`].
    unmarshal = function(...) {
      mlr3::learner_unmarshal(.learner = self, ...)
    }
  ),

  active = list(

    #' @field marshaled (`logical(1)`)
    #' Whether the learner has been marshaled.
    marshaled = function() {
      mlr3::learner_marshaled(self)
    }
  ),

  private = list(
    .train = function(task) {
      assert_python_packages(c("torch", "botorch", "gpytorch"))
      torch = reticulate::import("torch")
      botorch = reticulate::import("botorch")
      gpytorch = reticulate::import("gpytorch")
      MixedSingleTaskGP = botorch$models$gp_regression_mixed$MixedSingleTaskGP
      ExactMarginalLogLikelihood = gpytorch$mlls$ExactMarginalLogLikelihood

      pars = self$param_set$get_values(tags = "train")
      device = pars$device

      x = task$data(cols = task$feature_names)
      y = as.numeric(task$truth())

      # positions of the factor and logical features; vapply() keeps the result
      # a logical vector even when there are no columns
      cols = which(vapply(x, function(col) is.factor(col) || is.logical(col), logical(1L)))

      if (!length(cols)) {
        stop("At least one logical or categorical feature is required")
      }

      # convert factors and logicals to integers
      x[, (cols) := lapply(.SD, as.integer), .SDcols = cols]
      x = as_numeric_matrix(x)
      x_py = torch$as_tensor(x, dtype = torch$float64, device = device)
      y_py = torch$as_tensor(matrix(y, ncol = 1), dtype = torch$float64, device = device)

      # 0-based categorical dimensions
      cat_dims = reticulate::r_to_py(unname(as.list(cols - 1L)))

      gp = MixedSingleTaskGP(x_py, y_py, cat_dims = cat_dims)
      mll = ExactMarginalLogLikelihood(gp$likelihood, gp)
      botorch$fit$fit_gpytorch_mll(mll)

      # wrap the Python object so that marshal_model.botorch_gp_model() is dispatched
      structure(list(gp = gp), class = "botorch_gp_model")
    },

    .predict = function(task) {
      assert_python_packages(c("torch", "botorch", "gpytorch"))
      torch = reticulate::import("torch")
      pars = self$param_set$get_values(tags = "predict")

      # compute the posterior distribution and extract the mean and covariance matrix
      # disable gradient computation with torch.no_grad() for efficiency
      reticulate::py_run_string("def predict_gp(model, x_py):
        import torch
        with torch.no_grad():
          posterior = model.posterior(x_py)
          mean = posterior.mean.cpu().numpy()
          covar = posterior.mvn.covariance_matrix.cpu().numpy()
        return mean, covar")

      # unwrap the Python model from the "botorch_gp_model" container
      gp = self$model$gp
      # change the model to evaluation mode
      gp$eval()

      x = task$data(cols = task$feature_names)
      # convert factors and logicals to integers, exactly as during training
      cols = which(vapply(x, function(col) is.factor(col) || is.logical(col), logical(1L)))
      x[, (cols) := lapply(.SD, as.integer), .SDcols = cols]
      x = as_numeric_matrix(x)
      x_py = torch$as_tensor(x, dtype = torch$float64, device = pars$device)

      posterior = reticulate::py$predict_gp(gp, x_py)
      mean = as.numeric(posterior[[1]])
      covar = posterior[[2]]

      if (self$predict_type == "response") {
        list(response = mean)
      } else {
        # the standard error is the square root of the posterior variance,
        # i.e. of the diagonal of the posterior covariance matrix
        sd = sqrt(diag(covar))
        list(response = mean, se = sd)
      }
    }
  )
)

.extralrns_dict$add("regr.botorch_mixedsingletaskgp", LearnerRegrBotorchMixedSingleTaskGP)
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
#' @title BoTorch SingleTaskGP Regression Learner
#' @author Marc Becker
#' @name mlr_learners_regr.botorch_singletaskgp
#'
#' @description
#' Gaussian Process via [botorch](https://botorch.org/) and [gpytorch](https://gpytorch.ai/), using the `SingleTaskGP` model.
#' Uses \CRANpkg{reticulate} to interface with Python.
#'
#' The features are normalized with botorch's `Normalize` input transform.
#'
#' The fitted model is stored as a list of class `"botorch_gp_model"` whose element `gp` holds the Python model object,
#' so that the `marshal_model()` method for this class is dispatched when the learner is marshaled.
#'
#' @templateVar id regr.botorch_singletaskgp
#' @template learner
#'
#' @export
LearnerRegrBotorchSingleTaskGP = R6Class("LearnerRegrBotorchSingleTaskGP",
  inherit = LearnerRegr,

  public = list(

    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      param_set = ps(
        device = p_fct(default = "cpu", levels = c("cpu", "cuda"), tags = c("train", "predict"))
      )

      super$initialize(
        id = "regr.botorch_singletaskgp",
        packages = c("mlr3extralearners", "reticulate"),
        feature_types = c("integer", "numeric"),
        predict_types = c("response", "se"),
        param_set = param_set,
        properties = "marshal",
        label = "BoTorch SingleTaskGP",
        man = "mlr3extralearners::mlr_learners_regr.botorch_singletaskgp"
      )
    },

    #' @description
    #' Marshal the learner's model.
    #' @param ... (any)\cr
    #'   Additional arguments passed to [`marshal_model()`].
    marshal = function(...) {
      mlr3::learner_marshal(.learner = self, ...)
    },

    #' @description
    #' Unmarshal the learner's model.
    #' @param ... (any)\cr
    #'   Additional arguments passed to [`unmarshal_model()`].
    unmarshal = function(...) {
      mlr3::learner_unmarshal(.learner = self, ...)
    }
  ),

  active = list(

    #' @field marshaled (`logical(1)`)
    #' Whether the learner has been marshaled.
    marshaled = function() {
      mlr3::learner_marshaled(self)
    }
  ),

  private = list(
    .train = function(task) {
      assert_python_packages(c("torch", "botorch", "gpytorch"))
      torch = reticulate::import("torch")
      botorch = reticulate::import("botorch")
      gpytorch = reticulate::import("gpytorch")

      SingleTaskGP = botorch$models$SingleTaskGP
      ExactMarginalLogLikelihood = gpytorch$mlls$ExactMarginalLogLikelihood

      pars = self$param_set$get_values(tags = "train")
      device = pars$device

      x = as_numeric_matrix(task$data(cols = task$feature_names))
      y = as.numeric(task$truth())
      x_py = torch$as_tensor(x, dtype = torch$float64, device = device)
      y_py = torch$as_tensor(matrix(y, ncol = 1), dtype = torch$float64, device = device)

      # normalizing is strongly recommended for the SingleTaskGP model
      input_transform = botorch$models$transforms$Normalize(d = ncol(x))

      gp = SingleTaskGP(x_py, y_py, input_transform = input_transform)
      mll = ExactMarginalLogLikelihood(gp$likelihood, gp)
      botorch$fit$fit_gpytorch_mll(mll)

      # wrap the Python object so that marshal_model.botorch_gp_model() is dispatched
      structure(list(gp = gp), class = "botorch_gp_model")
    },

    .predict = function(task) {
      assert_python_packages(c("torch", "botorch", "gpytorch"))
      torch = reticulate::import("torch")
      pars = self$param_set$get_values(tags = "predict")

      # compute the posterior distribution and extract the mean and covariance matrix
      # disable gradient computation with torch.no_grad() for efficiency
      reticulate::py_run_string("def predict_gp(model, x_py):
        import torch
        with torch.no_grad():
          posterior = model.posterior(x_py)
          mean = posterior.mean.cpu().numpy()
          covar = posterior.mvn.covariance_matrix.cpu().numpy()
        return mean, covar")

      # unwrap the Python model from the "botorch_gp_model" container
      gp = self$model$gp
      # change the model to evaluation mode
      gp$eval()

      x = as_numeric_matrix(task$data(cols = task$feature_names))
      x_py = torch$as_tensor(x, dtype = torch$float64, device = pars$device)

      posterior = reticulate::py$predict_gp(gp, x_py)
      mean = as.numeric(posterior[[1]])
      covar = posterior[[2]]

      if (self$predict_type == "response") {
        list(response = mean)
      } else {
        # the standard error is the square root of the posterior variance,
        # i.e. of the diagonal of the posterior covariance matrix
        sd = sqrt(diag(covar))
        list(response = mean, se = sd)
      }
    }
  )
)

.extralrns_dict$add("regr.botorch_singletaskgp", LearnerRegrBotorchSingleTaskGP)
125+
126+
# Marshal a BoTorch GP model: the Python object stored in `model$gp` is
# serialized with Python's `pickle` module and returned as a raw vector inside
# an object of class c("botorch_gp_model_marshaled", "marshaled").
# `inplace` and `...` are accepted for compatibility with the generic and are
# not used.
#' @export
marshal_model.botorch_gp_model = function(model, inplace = FALSE, ...) {
  # `pickle` is part of the Python standard library and is therefore imported
  # directly instead of being declared as a package requirement
  reticulate::py_require(c("torch", "botorch", "gpytorch"))
  pickle = reticulate::import("pickle")
  pickled_model = pickle$dumps(model$gp)
  structure(list(
    marshaled = as.raw(pickled_model),
    packages = "mlr3extralearners"
  ), class = c("botorch_gp_model_marshaled", "marshaled"))
}
136+
137+
# Unmarshal a BoTorch GP model: the raw vector in `model$marshaled` is
# deserialized with Python's `pickle` module and the resulting Python object is
# returned as element `gp` of a list of class "botorch_gp_model".
# `inplace` and `...` are accepted for compatibility with the generic and are
# not used.
#' @export
unmarshal_model.botorch_gp_model_marshaled = function(model, inplace = FALSE, ...) {
  # `pickle` is part of the Python standard library and is therefore imported
  # directly instead of being declared as a package requirement
  reticulate::py_require(c("torch", "botorch", "gpytorch"))
  pickle = reticulate::import("pickle")
  # NOTE(security): unpickling can execute arbitrary code; only unmarshal
  # models that come from a trusted source
  model_obj = pickle$loads(reticulate::r_to_py(model$marshaled))
  structure(list(gp = model_obj), class = "botorch_gp_model")
}

0 commit comments

Comments
 (0)