Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 138 additions & 2 deletions vignettes/extending.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ model = rpart::rpart(mpg ~ ., data = mtcars, xval = 0)
We need to pass the formula notation `mpg ~ .`, the data and the hyperparameters.
To get the hyperparameters, we call `self$param_set$get_values(tag = "train")` and thereby query all parameters that are using during `"train"`.
Then, the dataset is extracted from the `Task`.
Because the learner has the property `"weights"`, we insert the weights of the task if there are any.
Because the learner has the property `"weights"`, we insert the weights of the task if there are any, by calling `$.get_weights()` to to retrieve them.
Then we obtain the formula from the task using `task$formula()` and access the training data using `task$data()`.
Last, we call the upstream function `rpart::rpart()` with the data and pass all hyperparameters via argument `.args` using the `mlr3misc::invoke()` function.
The latter is simply an optimized version of `do.call()` that we use within the `mlr3` ecosystem.
Expand Down Expand Up @@ -295,6 +295,143 @@ importance = function() {
}
```

## Internal Validation and Tuning

Some learners, such as boosting algorithms (e.g., LightGBM), or deep learning models support validation data during the training to monitor the validation performance during training.
This validation data can also be used to internally tune hyperparameters, e.g., the number of boosting iterations in LightGBM or the number of epochs in FastAi, enabling early stopping.

**Internal Validation**

Add the `"validation"` property to the learner. Then, implement the `$internal_valid_scores` active binding to retrieve the validation scores computed during training, and the `$validate` binding to serve as a getter/setter for the internal validation set.

```{r, eval=FALSE}
active = list(
#' @field internal_valid_scores
internal_valid_scores = function() {
self$state$internal_valid_scores
}

#' @field validate
validate = function(rhs) {
if (!missing(rhs)) {
private$.validate = mlr3::assert_validate(rhs)
}
private$.validate
}
)
```

Additionally, implement the private method `$.extract_internal_valid_scores()` which returns the (final) internal validation scores from the model of the `Learner` as a named `list()` of `numeric(1)`.
If the model is not trained yet, this method should return `NULL`. Also, add a private field `.validate = NULL` to store the internal validation set. This will later be accessed or modified via the active binding `$validate()`.

**Internal Tuning**

Internal tuning _can_ rely on internal validation data, but it doesn't necessarily have to. To enable internal tuning and early stopping, annotate the learner with the `"internal_tuning"` property. Then, add the active binding `$internal_tuned_values` which accesses internal hyperparameters that were automatically tuned during training, e.g., last epoch that yielded improvement in `LearnerClassifFastai`. Similar to internal validation, implement the `$internal_tuned_values` active binding and the private method `$.extract_internal_tuned_values()` which should return the internally tuned values from the learner's model as a named `list()`. If the model is not trained yet or internal tuning is disabled, this method should return `NULL`.
For example, `LearnerClassifFastai` supports early stopping by tuning the number of epochs and stopping when no further improvement is observed. This behavior is activated by setting the `"patience"` parameter in the learner’s `ParamSet` to a non-null value. The tuned value can then be accessed by the `eval_protocol` which itself is set in the `.train` function:

```{r, eval=FALSE}
.extract_internal_tuned_values = function() {
if (is.null(self$state$param_vals$patience) || is.null(self$state$eval_protocol)) {
return(NULL)
}
list(n_epoch = max(self$state$eval_protocol$epoch) + 1)
}
```

The parameter that is ought to be tuned by internal tuning must must be tagged with `"internal_tuning"`, which requires to also provide a `in_tune_fn` and `disable_tune_fn`, and should also include a default aggregation function.

```{r}
p_n_epoch = p_int(1L,
tags = c("train", "hotstart", "internal_tuning"),
init = 5L,
aggr = crate(function(x) as.integer(ceiling(mean(unlist(x)))), .parent = topenv()),
in_tune_fn = crate(function(domain, param_vals) {
if (is.null(param_vals$patience)) {
stop("Parameter 'patience' must be set to use internal tuning.")
}
assert_integerish(domain$upper, len = 1L, any.missing = FALSE)
}, .parent = topenv()),
disable_in_tune = list(n_epoch = NULL)
)
```


## Implement a python-powered learner with reticulate

Some learners rely on Python packages like `fastai`, `tabpfn`, or `pycox`. To implement such learners, you need to utilize the `reticulate` R package that provides seamless interoperability between R and Python. It allows you to call Python code, import Python modules, and work with Python modules directly from R.
Type conversion happens for most basic and structured types automatically when passing objects between R and Python. However, not all Python objects have a direct R equivalent. When such objects are returned to R, `reticulate` does not convert them but instead, it creates a reference to the original Python object that remains in memory. These objects can still be used and interacted with directly from R. To manage these references, `reticulate` assigns them the S3 class `python.builtin.object`.

```{r, eval=FALSE}
np = reticulate::import("numpy")
class(np)
#> [1] "python.builtin.module" "python.builtin.object"

class(fastai::TabularDataTable(df = iris))
#> [1] "fastai.tabular.core.TabularPandas" "fastai.tabular.core.Tabular"
#> [3] "fastcore.foundation.CollBase" "fastcore.basics.GetAttr"
#> [5] "fastai.data.core.FilteredBase" "python.builtin.object"
```

Generally, you need to explicitly convert an R object into a Python-compatible object using `reticulate::r_to_py()`, before passing it to a Python function or class method. Furthermore, you want to convert a Python object into an R-native format with `reticulate::py_to_r()`, to use it in downstream R functions or to extract values for inspection or postprocessing. Note, that you must force `NaN` in R objects to make the conversion work, otherwise `reticulate` will not properly convert `NA`s in logical and integer columns to `numpy.na`.

By default, `reticulate` creates and uses a Conda environment called `"r-reticulate"` to manage Python dependencies for R projects. This environment is automatically created the first time you install a Python package via `reticulate::py_install()`.
Unless you explicitly set another Python path (via `use_python()` or `use_condaenv()`), `reticulate` defaults to using `"r-reticulate"`.

In the learner's `.train()` and `.predict()` methods, you will typically either call a Python module directly, e.g., `np = reticulate::import("numpy")`, or use an R wrapper package that internally handles the Python interaction, such as `fastai`. To ensure the required Python modules are available and working correctly, add `assert_python_packages("<python-package>")` at the very beginning of both the `.train()` *and* `.predict()` functions. This checks that the package is installed and ready to use.

Here is a minimal working example showing the `.train()` function of `LearnerClassifTabPFN`:

```{r, eval=FALSE}
.train = function(task) {
assert_python_packages(c("torch", "tabpfn"))
tabpfn = reticulate::import("tabpfn")
pars = self$param_set$get_values(tags = "train")

if (!is.null(pars$device) && pars$device != "auto") {
torch = reticulate::import("torch")
pars$device = torch$device(pars$device)
}
x = as.matrix(task$data(cols = task$feature_names))
x[is.na(x)] = NaN
y = task$truth()

classifier = mlr3misc::invoke(tabpfn$TabPFNClassifier, .args = pars)
x_py = reticulate::r_to_py(x)
y_py = reticulate::r_to_py(y)
fitted = mlr3misc::invoke(classifier$fit, X = x_py, y = y_py)

structure(list(fitted = fitted), class = "tabpfn_model")
}
```

### Marshaling

For learners that wrap Python models, marshaling requires special handling, because the underlying model lives in the Python runtime and cannot be directly serialized with R. Instead, you need to use the `pickle` Python module. It should be available in any Python environment, so you do not need to assert it with `assert_python_packages`. You simply load it with `reticulate::import("pickle")`.
Here's an example of how marshal_model() and unmarshal_model() are implemented for `LearnerClassifTabPFN`:

```{r, eval=FALSE}
marshal_model.tabpfn_model = function(model, inplace = FALSE, ...) {
reticulate::py_require("tabpfn")
reticulate::import("tabpfn")
pickle = reticulate::import("pickle")
pickled = pickle$dumps(model$fitted)
pickled = as.raw(pickled)

structure(list(
marshaled = pickled,
packages = "mlr3extralearners"
), class = c("tabpfn_model_marshaled", "marshaled"))
}

unmarshal_model.tabpfn_model_marshaled = function(model, inplace = FALSE, ...) {
reticulate::py_require("tabpfn")
reticulate::import("tabpfn")
pickle = reticulate::import("pickle")
fitted = pickle$loads(reticulate::r_to_py(model$marshaled))
structure(list(fitted = fitted), class = "tabpfn_model")
}
```

## Testing the learner

Once your learner is created, you should write tests to verify its correctness.
Expand Down Expand Up @@ -389,7 +526,6 @@ test_that("paramtest", {
})
```


## Contributing to mlr3extralearners

When adding a `Learner` to `mlr3extralearners` there are some additional requirements that have to be satisfied:
Expand Down