|
| 1 | +#' @title BoTorch SingleTaskGP Regression Learner |
| 2 | +#' @author Marc Becker |
| 3 | +#' @name mlr_learners_regr.botorch_singletaskgp |
| 4 | +#' |
| 5 | +#' @description |
| 6 | +#' Gaussian Process via [botorch](https://botorch.org/) and [gpytorch](https://gpytorch.ai/), using the `SingleTaskGP` model. |
| 7 | +#' Uses \CRANpkg{reticulate} to interface with Python. |
| 8 | +#' |
| 9 | +#' @templateVar id regr.botorch_singletaskgp |
| 10 | +#' @template learner |
| 11 | +#' |
| 12 | +#' @export |
| 13 | +LearnerRegrBotorchSingleTaskGP = R6Class("LearnerRegrBotorchSingleTaskGP", |
| 14 | + inherit = LearnerRegr, |
| 15 | + |
| 16 | + public = list( |
| 17 | + |
| 18 | + #' @description |
| 19 | + #' Creates a new instance of this [R6][R6::R6Class] class. |
| 20 | + initialize = function() { |
| 21 | + param_set = ps( |
| 22 | + device = p_fct(default = "cpu", levels = c("cpu", "cuda"), tags = c("train", "predict")) |
| 23 | + ) |
| 24 | + |
| 25 | + super$initialize( |
| 26 | + id = "regr.botorch_singletaskgp", |
| 27 | + packages = c("mlr3extralearners", "reticulate"), |
| 28 | + feature_types = c("integer", "numeric"), |
| 29 | + predict_types = c("response", "se"), |
| 30 | + param_set = param_set, |
| 31 | + properties = "marshal", |
| 32 | + label = "BoTorch SingleTaskGP", |
| 33 | + man = "mlr3extralearners::mlr_learners_regr.botorch_singletaskgp" |
| 34 | + ) |
| 35 | + }, |
| 36 | + |
| 37 | + #' @description |
| 38 | + #' Marshal the learner's model. |
| 39 | + #' @param ... (any)\cr |
| 40 | + #' Additional arguments passed to [`marshal_model()`]. |
| 41 | + marshal = function(...) { |
| 42 | + mlr3::learner_marshal(.learner = self, ...) |
| 43 | + }, |
| 44 | + |
| 45 | + #' @description |
| 46 | + #' Unmarshal the learner's model. |
| 47 | + #' @param ... (any)\cr |
| 48 | + #' Additional arguments passed to [`unmarshal_model()`]. |
| 49 | + unmarshal = function(...) { |
| 50 | + mlr3::learner_unmarshal(.learner = self, ...) |
| 51 | + } |
| 52 | + ), |
| 53 | + active = list( |
| 54 | + |
| 55 | + #' @field marshaled (`logical(1)`) |
| 56 | + #' Whether the learner has been marshaled. |
| 57 | + marshaled = function() { |
| 58 | + mlr3::learner_marshaled(self) |
| 59 | + } |
| 60 | + ), |
| 61 | + private = list( |
| 62 | + .train = function(task) { |
| 63 | + assert_python_packages(c("torch", "botorch", "gpytorch")) |
| 64 | + torch = reticulate::import("torch") |
| 65 | + botorch = reticulate::import("botorch") |
| 66 | + gpytorch = reticulate::import("gpytorch") |
| 67 | + |
| 68 | + SingleTaskGP = botorch$models$SingleTaskGP |
| 69 | + ExactMarginalLogLikelihood = gpytorch$mlls$ExactMarginalLogLikelihood |
| 70 | + |
| 71 | + pars = self$param_set$get_values(tags = "train") |
| 72 | + device = pars$device |
| 73 | + |
| 74 | + x = as_numeric_matrix(task$data(cols = task$feature_names)) |
| 75 | + y = as.numeric(task$truth()) |
| 76 | + x_py = torch$as_tensor(x, dtype = torch$float64, device = device) |
| 77 | + y_py = torch$as_tensor(matrix(y, ncol = 1), dtype = torch$float64, device = device) |
| 78 | + |
| 79 | + # normalizing is strongly recommended for the SingleTaskGP model |
| 80 | + input_transform = botorch$models$transforms$Normalize(d = ncol(x)) |
| 81 | + |
| 82 | + gp = SingleTaskGP(x_py, y_py, input_transform = input_transform) |
| 83 | + mll = ExactMarginalLogLikelihood(gp$likelihood, gp) |
| 84 | + botorch$fit$fit_gpytorch_mll(mll) |
| 85 | + gp |
| 86 | + }, |
| 87 | + |
| 88 | + .predict = function(task) { |
| 89 | + assert_python_packages(c("torch", "botorch", "gpytorch")) |
| 90 | + torch = reticulate::import("torch") |
| 91 | + pars = self$param_set$get_values(tags = "predict") |
| 92 | + |
| 93 | + # compute the posterior distribution and extract the mean and covariance matrix |
| 94 | + # disable gradient computation with torch.no_grad() for efficiency |
| 95 | + reticulate::py_run_string("def predict_gp(model, x_py): |
| 96 | + import torch |
| 97 | + with torch.no_grad(): |
| 98 | + posterior = model.posterior(x_py) |
| 99 | + mean = posterior.mean.cpu().numpy() |
| 100 | + covar = posterior.mvn.covariance_matrix.cpu().numpy() |
| 101 | + return mean, covar") |
| 102 | + |
| 103 | + gp = self$model |
| 104 | + # change the model to evaluation mode |
| 105 | + gp$eval() |
| 106 | + |
| 107 | + x = as_numeric_matrix(task$data(cols = task$feature_names)) |
| 108 | + x_py = torch$as_tensor(x, dtype = torch$float64, device = pars$device) |
| 109 | + |
| 110 | + posterior = reticulate::py$predict_gp(gp, x_py) |
| 111 | + mean = as.numeric(posterior[[1]]) |
| 112 | + covar = posterior[[2]] |
| 113 | + |
| 114 | + if (self$predict_type == "response") { |
| 115 | + list(response = mean) |
| 116 | + } else { |
| 117 | + sd = sqrt(diag(covar)) |
| 118 | + list(response = mean, se = sd) |
| 119 | + } |
| 120 | + } |
| 121 | + ) |
| 122 | +) |
| 123 | + |
| 124 | +.extralrns_dict$add("regr.botorch_singletaskgp", LearnerRegrBotorchSingleTaskGP) |
| 125 | + |
| 126 | +#' @export |
| 127 | +marshal_model.botorch_gp_model = function(model, inplace = FALSE, ...) { |
| 128 | + reticulate::py_require(c("torch", "botorch", "gpytorch", "pickle")) |
| 129 | + pickle = reticulate::import("pickle") |
| 130 | + pickled_model = pickle$dumps(model$gp) |
| 131 | + structure(list( |
| 132 | + marshaled = as.raw(pickled_model), |
| 133 | + packages = "mlr3extralearners" |
| 134 | + ), class = c("botorch_gp_model_marshaled", "marshaled")) |
| 135 | +} |
| 136 | + |
| 137 | +#' @export |
| 138 | +unmarshal_model.botorch_gp_model_marshaled = function(model, inplace = FALSE, ...) { |
| 139 | + reticulate::py_require(c("torch", "botorch", "gpytorch", "pickle")) |
| 140 | + pickle = reticulate::import("pickle") |
| 141 | + model_obj = pickle$loads(reticulate::r_to_py(model$marshaled)) |
| 142 | + structure(list(gp = model_obj), class = "botorch_gp_model") |
| 143 | +} |
0 commit comments