diff --git a/R/Graph.R b/R/Graph.R index 366f2332b..483634de5 100644 --- a/R/Graph.R +++ b/R/Graph.R @@ -459,6 +459,10 @@ Graph = R6Class("Graph", graph_load_namespaces(self, "predict") graph_reduce(self, input, "predict", single_input) }, + hotstart = function(input, single_input = TRUE) { + graph_load_namespaces(self, "train") + graph_reduce(self, input, "hotstart", single_input) + }, help = function(help_type = getOption("help_type")) { parts = strsplit(self$man, split = "::", fixed = TRUE)[[1]] match.fun("help")(parts[[2]], package = parts[[1]], help_type = help_type) diff --git a/R/GraphLearner.R b/R/GraphLearner.R index f519e5007..a19f65bb9 100644 --- a/R/GraphLearner.R +++ b/R/GraphLearner.R @@ -454,6 +454,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner, } on.exit({self$graph$state = NULL}) + self$graph$train(task) state = self$graph$state class(state) = c("graph_learner_model", class(state)) @@ -466,6 +467,23 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner, assert_list(prediction, types = "Prediction", len = 1, .var.name = sprintf("Prediction returned by Graph %s", self$id)) prediction[[1]] + }, + + .hotstart = function(task) { + if (!is.null(get0("validate", self))) { + some_pipeops_validate = some(pos_with_property(self, "validation"), function(po) !is.null(po$validate)) + if (!some_pipeops_validate) { + lg$warn("GraphLearner '%s' specifies a validation set, but none of its PipeOps use it.", self$id) + } + } + + on.exit({self$graph$state = NULL}) + # copy hotstart state to graph + self$graph$state = self$state$model + self$graph$hotstart(task) + state = self$graph$state + class(state) = c("graph_learner_model", class(state)) + state } ) ) diff --git a/R/PipeOp.R b/R/PipeOp.R index b3487b0d6..456561387 100644 --- a/R/PipeOp.R +++ b/R/PipeOp.R @@ -348,6 +348,11 @@ PipeOp = R6Class("PipeOp", output = check_types(self, output, "output", "predict") output }, + hotstart = function(input) { + # default for all pipops is to just train them + # pipeops that can do hotstarting should overload this method + self$train(input) + }, help = function(help_type = getOption("help_type")) { parts = strsplit(self$man, split = "::", fixed = TRUE)[[1]] match.fun("help")(parts[[2]], package = parts[[1]], help_type = help_type) diff --git a/R/PipeOpLearner.R b/R/PipeOpLearner.R index 0849480af..a5fcc8e1b 100644 --- a/R/PipeOpLearner.R +++ b/R/PipeOpLearner.R @@ -112,6 +112,18 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp, output = data.table(name = "output", train = "NULL", predict = out_type), tags = "learner", packages = learner$packages, properties = properties ) + }, + hotstart = function(input) { + on.exit({private$.learner$state = NULL}) + # copy state to learner + learner = private$.learner + learner$state = self$state + + # train learner with hotstarting + train_result = mlr3:::learner_train(learner, task = input[[1]], train_row_ids = NULL, mode = "hotstart") + self$state = train_result$learner$state + + list(NULL) } ), active = list(