Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Depends:
Imports:
backports,
checkmate,
cli,
data.table,
digest,
lgr,
Expand Down Expand Up @@ -103,6 +104,7 @@ Suggests:
htmlwidgets,
ranger,
themis
Remotes: mlr-org/mlr3misc
ByteCompile: true
Encoding: UTF-8
Config/testthat/edition: 3
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ if (getRversion() >= "4.3.0") S3method(chooseOpsMethod,CnfAtom)
if (getRversion() >= "4.3.0") S3method(chooseOpsMethod,CnfClause)
if (getRversion() >= "4.3.0") S3method(chooseOpsMethod,CnfFormula)
import(checkmate)
import(cli)
import(data.table)
import(mlr3)
import(mlr3misc)
Expand Down
11 changes: 10 additions & 1 deletion R/Graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ Graph = R6Class("Graph",
scc = self$edges[, list(sccssors = paste(unique(dst_id), collapse = ",")), by = list(ID = src_id)]
lines = scc[prd[lines, on = "ID"], on = "ID"][, c("ID", "State", "sccssors", "prdcssors")]
lines[is.na(lines)] = ""
catf("Graph with %s PipeOps:", nrow(lines))
cat_cli(cli_h1("Graph with {nrow(lines)} PipeOps:"))
## limit column width ##

outwidth = getOption("width") %??% 80 # output width we want (default 80)
Expand All @@ -423,6 +423,15 @@ Graph = R6Class("Graph",
with_options(list(datatable.prettyprint.char = collimit), {
print(lines, row.names = FALSE)
})

is_sequential = all(table(self$edges$src_id) <= 1) && all(table(self$edges$dst_id) <= 1)
if(is_sequential) {
ppunit = paste0(self$ids(), collapse = " -> ")
pp = paste0(c("<INPUT>", ppunit, "<OUTPUT>"), collapse = " -> ")
} else {
pp = "non-sequential"
}
cat_cli(cli_h3("Pipeline: {.strong {pp}}"))
} else {
cat("Empty Graph.\n")
}
Expand Down
12 changes: 12 additions & 0 deletions R/GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,18 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
},
plot = function(html = FALSE, horizontal = FALSE, ...) {
private$.graph$plot(html = html, horizontal = horizontal, ...)
},
print = function() {
super$print(self)

is_sequential = all(table(self$edges$src_id) <= 1) && all(table(self$edges$dst_id) <= 1)
if(is_sequential) {
ppunit = paste0(self$ids(), collapse = " -> ")
pp = paste0(c("<INPUT>", ppunit, "<OUTPUT>"), collapse = " -> ")
} else {
pp = "non-sequential"
}
cat_cli(cli_h3("Pipeline: {.strong {pp}}"))
}
),
active = list(
Expand Down
23 changes: 14 additions & 9 deletions R/PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -271,18 +271,23 @@ PipeOp = R6Class("PipeOp",

print = function(...) {
type_table_printout = function(table) {
strings = do.call(sprintf, cbind(fmt = "%s`[%s,%s]", table[, c("name", "train", "predict")]))
strings = strwrap(paste(strings, collapse = ", "), indent = 2, exdent = 2)
if (length(strings) > 6) {
strings = c(strings[1:5], sprintf(" [... (%s lines omitted)]", length(strings) - 5))
print(head(table, 5L), row.names = FALSE, print.keys = FALSE)
if (nrow(table) > 5L) {
catf("[...] (%i rows omitted)", nrow(table) - 5L)
}
gsub("`", " ", paste(strings, collapse = "\n"))
}

catf("PipeOp: <%s> (%strained)", self$id, if (self$is_trained) "" else "not ")
catf("values: <%s>", as_short_string(self$param_set$values))
catf("Input channels <name [train type, predict type]>:\n%s", type_table_printout(self$input))
catf("Output channels <name [train type, predict type]>:\n%s", type_table_printout(self$output))
msg_h = if (self$is_trained) "" else "not "
cat_cli({
cli_h1("PipeOp {.cls {self$id}}: {msg_h}trained")
cli_text("Values: {as_short_string(self$param_set$values)}")
cli_h3("{.strong Input channels:}")
})
type_table_printout(self$input)
cat_cli({
cli_h3("{.strong Output channels:}")
})
type_table_printout(self$output)
},

train = function(input) {
Expand Down
2 changes: 1 addition & 1 deletion R/multiplicity.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ print.Multiplicity = function(x, ...) {
if (!length(x)) {
cat("Empty Multiplicity.\n")
} else {
cat("Multiplicity:\n")
cli_h2("Multiplicity:")
print(unclass(x), ...)
}
}
Expand Down
1 change: 1 addition & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#' @import data.table
#' @import checkmate
#' @import cli
#' @import mlr3
#' @import paradox
#' @import mlr3misc
Expand Down
2 changes: 1 addition & 1 deletion inst/testthat/helper_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ expect_pipeop = function(po, check_ps_default_values = TRUE) {
expect_class(po$param_set, "ParamSet", label = label)
expect_list(po$param_set$values, names = "unique", label = label)
expect_flag(po$is_trained, label = label)
expect_output(print(po), "PipeOp:", label = label)
expect_output(print(po), "PipeOp", label = label)
expect_character(po$packages, any.missing = FALSE, unique = TRUE, label = label)
expect_function(po$train, nargs = 1)
expect_function(po$predict, nargs = 1)
Expand Down
107 changes: 107 additions & 0 deletions tests/testthat/_snaps/PipeOp.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# PipeOp printer

Code
print(PipeOpNOP$new())
Output

-- PipeOp <nop>: not trained ---------------------------------------------------
Values: list()

-- Input channels:
name train predict
<char> <char> <char>
input * *

-- Output channels:
name train predict
<char> <char> <char>
output * *

---

Code
print(PipeOpDebugMulti$new(3, 4))
Output

-- PipeOp <debug.multi>: not trained -------------------------------------------
Values: list()

-- Input channels:
name train predict
<char> <char> <char>
input_1 * *
input_2 * *
input_3 * *

-- Output channels:
name train predict
<char> <char> <char>
output_1 * *
output_2 * *
output_3 * *
output_4 * *

---

Code
print(PipeOpDebugMulti$new(100, 0))
Output

-- PipeOp <debug.multi>: not trained -------------------------------------------
Values: list()

-- Input channels:
name train predict
<char> <char> <char>
input_1 * *
input_2 * *
input_3 * *
input_4 * *
input_5 * *
[...] (95 rows omitted)

-- Output channels:
name train predict
<char> <char> <char>
output_ * *

---

Code
print(PipeOpBranch$new(c("odin", "dva", "tri")))
Output

-- PipeOp <branch>: not trained ------------------------------------------------
Values: selection=odin

-- Input channels:
name train predict
<char> <char> <char>
input * *

-- Output channels:
name train predict
<char> <char> <char>
odin * *
dva * *
tri * *

---

Code
print(PipeOpLearner$new(mlr_learners$get("classif.debug")))
Output

-- PipeOp <classif.debug>: not trained -----------------------------------------
Values: list()

-- Input channels:
name train predict
<char> <char> <char>
input TaskClassif TaskClassif

-- Output channels:
name train predict
<char> <char> <char>
output NULL PredictionClassif

8 changes: 4 additions & 4 deletions tests/testthat/test_Graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ test_that("linear graph", {

expect_graph(g)

expect_output(print(g), "Graph with 2 PipeOps.*subsample.*UNTRAINED.*pca.*UNTRAINED")


expect_output(print(g), "Graph with 2 PipeOps:")
expect_output(print(g), ".*subsample.*UNTRAINED.*pca.*UNTRAINED")

inputs = mlr_tasks$get("iris")
x = g$train(inputs)
expect_task(x[[1]])

expect_output(print(g), "Graph with 2 PipeOps.*subsample.*list.*pca.*prcomp")
expect_output(print(g), "Graph with 2 PipeOps")
expect_output(print(g), ".*subsample.*list.*pca.*prcomp")

out = g$predict(inputs)
expect_task(x[[1]])
Expand Down
23 changes: 6 additions & 17 deletions tests/testthat/test_PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ test_that("PipeOp - General functions", {
expect_false(po_1$is_trained)
expect_class(po_1$param_set, "ParamSet")
expect_list(po_1$param_set$values, names = "unique")
expect_output(print(po_1), "PipeOp:")
expect_output(print(po_1), "PipeOp")
expect_equal(po_1$packages, "mlr3pipelines")
expect_null(po_1$state)
assert_subset(po_1$tags, mlr_reflections$pipeops$valid_tags)
Expand All @@ -29,22 +29,11 @@ test_that("PipeOp - simple tests with PipeOpScale", {
})

test_that("PipeOp printer", {
expect_output(print(PipeOpNOP$new()),
"PipeOp.*<nop>.*not trained.*values.*list().*Input channels.*input \\[\\*,\\*\\]\n.*Output channels.*output \\[\\*,\\*\\]$")


expect_output(print(PipeOpDebugMulti$new(3, 4)),
"PipeOp.*<debug.multi>.*not trained.*values.*list().*Input channels.*input_1 \\[\\*,\\*\\], input_2 \\[\\*,\\*\\], input_3 \\[\\*,\\*\\]\n.*Output channels.*output_1 \\[\\*,\\*\\], output_2 \\[\\*,\\*\\], output_3 \\[\\*,\\*\\], output_4 \\[\\*,\\*\\]$")


expect_output(print(PipeOpDebugMulti$new(100, 0)),
"\\[\\.\\.\\. \\([0-9]+ lines omitted\\)\\]")

expect_output(print(PipeOpBranch$new(c("odin", "dva", "tri"))),
"Output channels.*odin \\[\\*,\\*\\], dva \\[\\*,\\*\\], tri \\[\\*,\\*\\]$")

expect_output(print(PipeOpLearner$new(mlr_learners$get("classif.debug"))),
"PipeOp.*<classif.debug>.*Input channels.*input \\[TaskClassif,TaskClassif\\]\nOutput channels.*output \\[NULL,PredictionClassif\\]$")
expect_snapshot(print(PipeOpNOP$new()))
expect_snapshot(print(PipeOpDebugMulti$new(3, 4)))
expect_snapshot(print(PipeOpDebugMulti$new(100, 0)))
expect_snapshot(print(PipeOpBranch$new(c("odin", "dva", "tri"))))
expect_snapshot(print(PipeOpLearner$new(mlr_learners$get("classif.debug"))))
})

test_that("Prevent creation of PipeOps with no channels", {
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_multiplicities.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ test_that("Multiplicity class and methods", {
expect_multiplicity(assert_multiplicity(nmp))
expect_multiplicity(assert_multiplicity(nmp, check_nesting = TRUE))
expect_error(assert_multiplicity(as.Multiplicity(list(0, Multiplicity(0))), .var.name = "y", check_nesting = TRUE), regexp = "Inconsistent multiplicity nesting level")
expect_output(print(mp), regexp = "Multiplicity:")
expect_message(print(mp), regexp = "Multiplicity:")
expect_output(print(Multiplicity()), regexp = "Empty Multiplicity.")
})

Expand Down
Loading