diff --git a/DESCRIPTION b/DESCRIPTION index c89039e39..6780b1180 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -62,6 +62,7 @@ Depends: Imports: backports, checkmate, + cli, data.table, digest, lgr, @@ -103,6 +104,7 @@ Suggests: htmlwidgets, ranger, themis +Remotes: mlr-org/mlr3misc ByteCompile: true Encoding: UTF-8 Config/testthat/edition: 3 diff --git a/NAMESPACE b/NAMESPACE index 05672abde..a153d85e2 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -221,6 +221,7 @@ if (getRversion() >= "4.3.0") S3method(chooseOpsMethod,CnfAtom) if (getRversion() >= "4.3.0") S3method(chooseOpsMethod,CnfClause) if (getRversion() >= "4.3.0") S3method(chooseOpsMethod,CnfFormula) import(checkmate) +import(cli) import(data.table) import(mlr3) import(mlr3misc) diff --git a/R/Graph.R b/R/Graph.R index 366f2332b..d5fdc09d8 100644 --- a/R/Graph.R +++ b/R/Graph.R @@ -414,7 +414,7 @@ Graph = R6Class("Graph", scc = self$edges[, list(sccssors = paste(unique(dst_id), collapse = ",")), by = list(ID = src_id)] lines = scc[prd[lines, on = "ID"], on = "ID"][, c("ID", "State", "sccssors", "prdcssors")] lines[is.na(lines)] = "" - catf("Graph with %s PipeOps:", nrow(lines)) + cat_cli(cli_h1("Graph with {nrow(lines)} PipeOps:")) ## limit column width ## outwidth = getOption("width") %??% 80 # output width we want (default 80) @@ -423,6 +423,15 @@ Graph = R6Class("Graph", with_options(list(datatable.prettyprint.char = collimit), { print(lines, row.names = FALSE) }) + + is_sequential = all(table(self$edges$src_id) <= 1) && all(table(self$edges$dst_id) <= 1) + if(is_sequential) { + ppunit = paste0(self$ids(), collapse = " -> ") + pp = paste0(c("", ppunit, ""), collapse = " -> ") + } else { + pp = "non-sequential" + } + cat_cli(cli_h3("Pipeline: {.strong {pp}}")) } else { cat("Empty Graph.\n") } diff --git a/R/GraphLearner.R b/R/GraphLearner.R index f519e5007..ff85fd669 100644 --- a/R/GraphLearner.R +++ b/R/GraphLearner.R @@ -312,6 +312,18 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner, }, plot = function(html = FALSE, horizontal = FALSE, ...) { private$.graph$plot(html = html, horizontal = horizontal, ...) + }, + print = function() { + super$print(self) + + is_sequential = all(table(self$edges$src_id) <= 1) && all(table(self$edges$dst_id) <= 1) + if(is_sequential) { + ppunit = paste0(self$ids(), collapse = " -> ") + pp = paste0(c("", ppunit, ""), collapse = " -> ") + } else { + pp = "non-sequential" + } + cat_cli(cli_h3("Pipeline: {.strong {pp}}")) } ), active = list( diff --git a/R/PipeOp.R b/R/PipeOp.R index b3487b0d6..a3e4d2a01 100644 --- a/R/PipeOp.R +++ b/R/PipeOp.R @@ -271,18 +271,23 @@ PipeOp = R6Class("PipeOp", print = function(...) { type_table_printout = function(table) { - strings = do.call(sprintf, cbind(fmt = "%s`[%s,%s]", table[, c("name", "train", "predict")])) - strings = strwrap(paste(strings, collapse = ", "), indent = 2, exdent = 2) - if (length(strings) > 6) { - strings = c(strings[1:5], sprintf(" [... (%s lines omitted)]", length(strings) - 5)) + print(head(table, 5L), row.names = FALSE, print.keys = FALSE) + if (nrow(table) > 5L) { + catf("[...] (%i rows omitted)", nrow(table) - 5L) } - gsub("`", " ", paste(strings, collapse = "\n")) } - catf("PipeOp: <%s> (%strained)", self$id, if (self$is_trained) "" else "not ") - catf("values: <%s>", as_short_string(self$param_set$values)) - catf("Input channels :\n%s", type_table_printout(self$input)) - catf("Output channels :\n%s", type_table_printout(self$output)) + msg_h = if (self$is_trained) "" else "not " + cat_cli({ + cli_h1("PipeOp {.cls {self$id}}: {msg_h}trained") + cli_text("Values: {as_short_string(self$param_set$values)}") + cli_h3("{.strong Input channels:}") + }) + type_table_printout(self$input) + cat_cli({ + cli_h3("{.strong Output channels:}") + }) + type_table_printout(self$output) }, train = function(input) { diff --git a/R/multiplicity.R b/R/multiplicity.R index a8643d597..d777dac4a 100644 --- a/R/multiplicity.R +++ b/R/multiplicity.R @@ -56,7 +56,7 @@ print.Multiplicity = function(x, ...) { if (!length(x)) { cat("Empty Multiplicity.\n") } else { - cat("Multiplicity:\n") + cli_h2("Multiplicity:") print(unclass(x), ...) } } diff --git a/R/zzz.R b/R/zzz.R index a4333c8a8..34b3aa4d5 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -1,5 +1,6 @@ #' @import data.table #' @import checkmate +#' @import cli #' @import mlr3 #' @import paradox #' @import mlr3misc diff --git a/inst/testthat/helper_functions.R b/inst/testthat/helper_functions.R index f05d640fe..c8056552a 100644 --- a/inst/testthat/helper_functions.R +++ b/inst/testthat/helper_functions.R @@ -97,7 +97,7 @@ expect_pipeop = function(po, check_ps_default_values = TRUE) { expect_class(po$param_set, "ParamSet", label = label) expect_list(po$param_set$values, names = "unique", label = label) expect_flag(po$is_trained, label = label) - expect_output(print(po), "PipeOp:", label = label) + expect_output(print(po), "PipeOp", label = label) expect_character(po$packages, any.missing = FALSE, unique = TRUE, label = label) expect_function(po$train, nargs = 1) expect_function(po$predict, nargs = 1) diff --git a/tests/testthat/_snaps/PipeOp.md b/tests/testthat/_snaps/PipeOp.md new file mode 100644 index 000000000..880a6aeed --- /dev/null +++ b/tests/testthat/_snaps/PipeOp.md @@ -0,0 +1,107 @@ +# PipeOp printer + + Code + print(PipeOpNOP$new()) + Output + + -- PipeOp : not trained --------------------------------------------------- + Values: list() + + -- Input channels: + name train predict + + input * * + + -- Output channels: + name train predict + + output * * + +--- + + Code + print(PipeOpDebugMulti$new(3, 4)) + Output + + -- PipeOp : not trained ------------------------------------------- + Values: list() + + -- Input channels: + name train predict + + input_1 * * + input_2 * * + input_3 * * + + -- Output channels: + name train predict + + output_1 * * + output_2 * * + output_3 * * + output_4 * * + +--- + + Code + print(PipeOpDebugMulti$new(100, 0)) + Output + + -- PipeOp : not trained ------------------------------------------- + Values: list() + + -- Input channels: + name train predict + + input_1 * * + input_2 * * + input_3 * * + input_4 * * + input_5 * * + [...] (95 rows omitted) + + -- Output channels: + name train predict + + output_ * * + +--- + + Code + print(PipeOpBranch$new(c("odin", "dva", "tri"))) + Output + + -- PipeOp : not trained ------------------------------------------------ + Values: selection=odin + + -- Input channels: + name train predict + + input * * + + -- Output channels: + name train predict + + odin * * + dva * * + tri * * + +--- + + Code + print(PipeOpLearner$new(mlr_learners$get("classif.debug"))) + Output + + -- PipeOp : not trained ----------------------------------------- + Values: list() + + -- Input channels: + name train predict + + input TaskClassif TaskClassif + + -- Output channels: + name train predict + + output NULL PredictionClassif + diff --git a/tests/testthat/test_Graph.R b/tests/testthat/test_Graph.R index 0e8501643..b3ac9ee93 100644 --- a/tests/testthat/test_Graph.R +++ b/tests/testthat/test_Graph.R @@ -20,15 +20,15 @@ test_that("linear graph", { expect_graph(g) - expect_output(print(g), "Graph with 2 PipeOps.*subsample.*UNTRAINED.*pca.*UNTRAINED") - - + expect_output(print(g), "Graph with 2 PipeOps:") + expect_output(print(g), ".*subsample.*UNTRAINED.*pca.*UNTRAINED") inputs = mlr_tasks$get("iris") x = g$train(inputs) expect_task(x[[1]]) - expect_output(print(g), "Graph with 2 PipeOps.*subsample.*list.*pca.*prcomp") + expect_output(print(g), "Graph with 2 PipeOps") + expect_output(print(g), ".*subsample.*list.*pca.*prcomp") out = g$predict(inputs) expect_task(x[[1]]) diff --git a/tests/testthat/test_PipeOp.R b/tests/testthat/test_PipeOp.R index 6a030929b..daca39329 100644 --- a/tests/testthat/test_PipeOp.R +++ b/tests/testthat/test_PipeOp.R @@ -9,7 +9,7 @@ test_that("PipeOp - General functions", { expect_false(po_1$is_trained) expect_class(po_1$param_set, "ParamSet") expect_list(po_1$param_set$values, names = "unique") - expect_output(print(po_1), "PipeOp:") + expect_output(print(po_1), "PipeOp") expect_equal(po_1$packages, "mlr3pipelines") expect_null(po_1$state) assert_subset(po_1$tags, mlr_reflections$pipeops$valid_tags) @@ -29,22 +29,11 @@ test_that("PipeOp - simple tests with PipeOpScale", { }) test_that("PipeOp printer", { - expect_output(print(PipeOpNOP$new()), - "PipeOp.*.*not trained.*values.*list().*Input channels.*input \\[\\*,\\*\\]\n.*Output channels.*output \\[\\*,\\*\\]$") - - - expect_output(print(PipeOpDebugMulti$new(3, 4)), - "PipeOp.*.*not trained.*values.*list().*Input channels.*input_1 \\[\\*,\\*\\], input_2 \\[\\*,\\*\\], input_3 \\[\\*,\\*\\]\n.*Output channels.*output_1 \\[\\*,\\*\\], output_2 \\[\\*,\\*\\], output_3 \\[\\*,\\*\\], output_4 \\[\\*,\\*\\]$") - - - expect_output(print(PipeOpDebugMulti$new(100, 0)), - "\\[\\.\\.\\. \\([0-9]+ lines omitted\\)\\]") - - expect_output(print(PipeOpBranch$new(c("odin", "dva", "tri"))), - "Output channels.*odin \\[\\*,\\*\\], dva \\[\\*,\\*\\], tri \\[\\*,\\*\\]$") - - expect_output(print(PipeOpLearner$new(mlr_learners$get("classif.debug"))), - "PipeOp.*.*Input channels.*input \\[TaskClassif,TaskClassif\\]\nOutput channels.*output \\[NULL,PredictionClassif\\]$") + expect_snapshot(print(PipeOpNOP$new())) + expect_snapshot(print(PipeOpDebugMulti$new(3, 4))) + expect_snapshot(print(PipeOpDebugMulti$new(100, 0))) + expect_snapshot(print(PipeOpBranch$new(c("odin", "dva", "tri")))) + expect_snapshot(print(PipeOpLearner$new(mlr_learners$get("classif.debug")))) }) test_that("Prevent creation of PipeOps with no channels", { diff --git a/tests/testthat/test_multiplicities.R b/tests/testthat/test_multiplicities.R index cbb362fd7..4e0c05530 100644 --- a/tests/testthat/test_multiplicities.R +++ b/tests/testthat/test_multiplicities.R @@ -10,7 +10,7 @@ test_that("Multiplicity class and methods", { expect_multiplicity(assert_multiplicity(nmp)) expect_multiplicity(assert_multiplicity(nmp, check_nesting = TRUE)) expect_error(assert_multiplicity(as.Multiplicity(list(0, Multiplicity(0))), .var.name = "y", check_nesting = TRUE), regexp = "Inconsistent multiplicity nesting level") - expect_output(print(mp), regexp = "Multiplicity:") + expect_message(print(mp), regexp = "Multiplicity:") expect_output(print(Multiplicity()), regexp = "Empty Multiplicity.") })