-
-
Notifications
You must be signed in to change notification settings - Fork 28
open up PipeOpLearnerCV to all resampling methods #513
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
2efdb99
b8e1deb
a90616c
7c3c301
10066ed
d7f8969
b114dd8
c7b8a3c
fe1624b
2f235a6
2f99db8
6431bd9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
#' @title Aggregate Features Row-Wise | ||
#' | ||
#' @usage NULL | ||
#' @name mlr_pipeops_aggregate | ||
#' @format [`R6Class`] object inheriting from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`]. | ||
#' | ||
#' @description | ||
#' Aggregates features row-wise based on multiple observations indicated via a column of role `row_reference` according to expressions given as formulas. | ||
#' Typically used after [`PipeOpLearnerCV`] and prior to [`PipeOpFeatureUnion`] if the resampling method returned multiple predictions per row id. | ||
#' However, note that not all [`Resampling`][mlr3::Resampling] methods result in at least one prediction per original row id. | ||
#' | ||
#' @section Construction: | ||
#' ``` | ||
#' PipeOpAggregate$new(id = "aggregate", param_vals = list()) | ||
#' ``` | ||
#' * `id` :: `character(1)`\cr | ||
#' Identifier of resulting object, default `"aggregate"`. | ||
#' * `param_vals` :: named `list`\cr | ||
#' List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. Default `list()`. | ||
#' | ||
#' @section Input and Output Channels: | ||
#' Input and output channels are inherited from [`PipeOpTaskPreproc`]. | ||
# | ||
#' The output is a [`Task`][mlr3::Task] with the same target as the input [`Task`][mlr3::Task], with features aggregated as specified. | ||
#' | ||
#' @section State: | ||
#' The `$state` is a named `list` with the `$state` elements inherited from [`PipeOpTaskPreproc`]. | ||
#' | ||
#' @section Parameters: | ||
#' The parameters are the parameters inherited from [`PipeOpTaskPreproc`], as well as: | ||
#' * `aggregation` :: named `list` of `formula`\cr | ||
#' Expressions for how features should be aggregated, in the form of `formula`. | ||
#' Each element of the list is a `formula` with the name of the element naming the feature to aggregate and the formula expression determining the result. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe this shouldn't be a named list of formulae, but just a single formula naming a data.table expression, such as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. alternatively: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
#' Each formula is evaluated within [`data.table`] environments of the [`Task`][mlr3::Task] that contain all features split via the `by` argument (see below). | ||
#' Initialized to `list()`, i.e., no aggregation is performed. | ||
#' * `by` :: `character(1)` | `NULL`\cr | ||
#' Column indicating the `row_reference` column of the [`Task`][mlr3::Task] that should be the row-wise basis for the aggregation. | ||
#' Initialized to `NULL`, i.e., no aggregation is performed. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe also |
||
#' | ||
#' @section Internals: | ||
#' A `formula` created using the `~` operator always contains a reference to the `environment` in which | ||
#' the `formula` is created. This makes it possible to use variables in the `~`-expressions that both | ||
#' reference either column names or variable names. | ||
#' | ||
#' @section Fields: | ||
#' Only fields inherited from [`PipeOpTaskPreproc`]/[`PipeOp`]. | ||
#' | ||
#' @section Methods: | ||
#' Only methods inherited from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`]. | ||
#' | ||
#' @family PipeOps | ||
#' @seealso https://mlr3book.mlr-org.com/list-pipeops.html | ||
#' @include PipeOpTaskPreproc.R | ||
#' @export | ||
#' @examples | ||
#' library("mlr3") | ||
#' calculate_mode = function(x) { | ||
#' unique_x = unique(x) | ||
#' unique_x[which.max(tabulate(match(x, unique_x)))] | ||
#' } | ||
#' | ||
#' task = tsk("iris") | ||
#' learner = lrn("classif.rpart") | ||
#' | ||
#' lrnloo_po = po("learner_cv", learner, rsmp("loo")) | ||
#' nop = mlr_pipeops$get("nop") | ||
#' agg_po = po("aggregate", | ||
#' aggregation = list( | ||
#' classif.rpart.response = ~ calculate_mode(classif.rpart.response) | ||
#' ), | ||
#' by = "pre.classif.rpart") | ||
#' | ||
#' graph = gunion(list( | ||
#' lrnloo_po %>>% agg_po, | ||
#' nop | ||
#' )) %>>% po("featureunion") | ||
#' | ||
#' graph$train(task) | ||
#' | ||
#' graph$pipeops$classif.rpart$learner$predict_type = "prob" | ||
#' graph$param_set$values$aggregate.aggregation = list( | ||
#' classif.rpart.prob.setosa = ~ mean(classif.rpart.prob.setosa), | ||
#' classif.rpart.prob.versicolor = ~ mean(classif.rpart.prob.versicolor), | ||
#' classif.rpart.prob.virginica = ~ mean(classif.rpart.prob.virginica) | ||
#' ) | ||
#' graph$train(task) | ||
PipeOpAggregate = R6Class("Aggregate", | ||
inherit = PipeOpTaskPreprocSimple, | ||
public = list( | ||
initialize = function(id = "aggregate", param_vals = list()) { | ||
ps = ParamSet$new(params = list( | ||
ParamUty$new("aggregation", tags = c("train", "predict", "required"), custom_check = check_aggregation_formulae), | ||
ParamUty$new("by", tags = c("train", "predict", "required"), custom_check = function(x) check_string(x, null.ok = TRUE)) | ||
)) | ||
ps$values = list(aggregation = list(), by = NULL) | ||
super$initialize(id, ps, param_vals = param_vals, tags = "ensemble") | ||
} | ||
), | ||
private = list( | ||
.transform = function(task) { | ||
|
||
if (length(self$param_set$values$aggregation) == 0L || is.null(self$param_set$values$by)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. empty aggregation should not be allowed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. empty |
||
return(task) # early exit | ||
} | ||
|
||
assert_set_equal(names(self$param_set$values$aggregation), task$feature_names) | ||
assert_choice(self$param_set$values$by, choices = task$col_roles$row_reference) | ||
|
||
taskdata = task$data(cols = c(task$feature_names, task$col_roles$row_reference)) | ||
taskdata_split = split(taskdata, by = self$param_set$values$by) | ||
|
||
newdata = unique(task$data(cols = c(task$target_names, task$col_roles$row_reference[match(task$col_roles$row_reference, self$param_set$values$by)])), by = self$param_set$values$by) | ||
|
||
nms = names(self$param_set$values$aggregation) | ||
for (i in seq_along(nms)) { | ||
frm = self$param_set$values$aggregation[[i]] | ||
set(newdata, j = nms[i], value = unlist(map(taskdata_split, .f = function(split) eval(frm[[2L]], envir = split, enclos = environment(frm))))) | ||
} | ||
setnames(newdata, old = self$param_set$values$by, new = task$backend$primary_key) | ||
|
||
# get task_type from mlr_reflections and call constructor | ||
constructor = get(mlr_reflections$task_types[["task"]][chmatch(task$task_type, table = mlr_reflections$task_types[["type"]], nomatch = 0L)][[1L]]) | ||
newtask = invoke(constructor$new, id = task$id, backend = as_data_backend(newdata, primary_key = task$backend$primary_key), target = task$target_names, .args = task$extra_args) | ||
newtask$extra_args = task$extra_args | ||
|
||
newtask | ||
} | ||
) | ||
) | ||
|
||
mlr_pipeops$add("aggregate", PipeOpAggregate) | ||
|
||
# check the `aggregation` parameter of PipeOpAggregate | ||
# @param x [list] whatever `aggregation` is being set to | ||
# checks that `aggregation` is | ||
# * a named list of `formula` | ||
# * that each element has only a rhs | ||
check_aggregation_formulae = function(x) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. by now we can use |
||
check_list(x, types = "formula", names = "unique") %check&&% | ||
Reduce(`%check&&%`, lapply(x, function(xel) { | ||
if (length(xel) != 2L) { | ||
return(sprintf("formula %s must not have a left hand side.", | ||
deparse(xel, nlines = 1L, width.cutoff = 500L))) | ||
} | ||
TRUE | ||
}), TRUE) | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe we don't need to restrict to that colrole?