Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ Imports:
mlr3 (>= 0.20.0),
mlr3misc (>= 0.17.0),
paradox (>= 1.0.0),
R6,
withr
R6
Suggests:
ggplot2,
glmnet,
Expand Down
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -240,4 +240,3 @@ importFrom(stats,setNames)
importFrom(utils,bibentry)
importFrom(utils,head)
importFrom(utils,tail)
importFrom(withr,with_options)
42 changes: 21 additions & 21 deletions R/Graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,13 @@ Graph = R6Class("Graph",
assert_choice(src_id, names(self$pipeops))
assert_choice(dst_id, names(self$pipeops))
if (is.null(src_channel)) {
if (length(self$pipeops[[src_id]]$output$name) > 1) {
if (length(self$pipeops[[src_id]]$output$name) > 1L) {
stopf("src_channel must not be NULL if src_id pipeop has more than one output channel.")
}
src_channel = 1L
}
if (is.null(dst_channel)) {
if (length(self$pipeops[[dst_id]]$input$name) > 1) {
if (length(self$pipeops[[dst_id]]$input$name) > 1L) {
stopf("dst_channel must not be NULL if dst_id pipeop has more than one input channel.")
}
dst_channel = 1L
Expand All @@ -225,7 +225,7 @@ Graph = R6Class("Graph",
src_channel = self$pipeops[[src_id]]$output$name[src_channel]
}
assert(
check_integerish(dst_channel, lower = 1,
check_integerish(dst_channel, lower = 1L,
upper = nrow(self$pipeops[[dst_id]]$input), any.missing = FALSE),
check_choice(dst_channel, self$pipeops[[dst_id]]$input$name)
)
Expand Down Expand Up @@ -283,7 +283,7 @@ Graph = R6Class("Graph",
df = self$edges[, list(from = src_id, to = dst_id)]
df = rbind(df, self$input[, list(from = "<INPUT>", to = op.id)])
output = self$output
if (nrow(output) > 1) {
if (nrow(output) > 1L) {
# In case we have multiple outputs, we add an output for every final node
df = rbind(df, output[, list(from = op.id, to = paste0("<OUTPUT>\n", name))])
} else {
Expand All @@ -309,7 +309,7 @@ Graph = R6Class("Graph",
if (node == "<INPUT>") {
txt = paste0("Input:<br>Name: ", self$input$name, "<br>Train: ", null_str(self$input$train), "<br>Predict: ", null_str(self$input$predict))
} else if (grepl("<OUTPUT>", node)) {
if (nrow(self$output) > 1) {
if (nrow(self$output) > 1L) {
out = self$output[self$output$name == gsub("<OUTPUT>\n", "", node), ] # Deal with multiple outputs
} else {
out = self$output # Standard case, single output
Expand Down Expand Up @@ -342,8 +342,8 @@ Graph = R6Class("Graph",
if (horizontal) {
layout = -layout[, 2:1]
}
layout[, 1] = layout[, 1] * .75
layout[, 2] = layout[, 2] * .75
layout[, 1L] = layout[, 1L] * .75
layout[, 2L] = layout[, 2L] * .75

defaultargs = list(vertex.shape = "crectangle", vertex.size = 60, vertex.size2 = 15 * 2.5, vertex.color = 0,
xlim = range(layout[, 1]) + c(-0.3, 0.3),
Expand Down Expand Up @@ -407,7 +407,7 @@ Graph = R6Class("Graph",
# print table <id>, <state>, where <state> is `class(pipeop$state)`
lines = rbindlist(map(self$pipeops[self$ids(sorted = TRUE)], function(pipeop) {
data.table(ID = pipeop$id, State = sprintf("<%s>",
map_values(class(pipeop$state)[1], "NULL", "<UNTRAINED>")))
map_values(class(pipeop$state)[1L], "NULL", "<UNTRAINED>")))
}), use.names = TRUE)
if (nrow(lines)) {
prd = self$edges[, list(prdcssors = paste(unique(src_id), collapse = ",")), by = list(ID = dst_id)]
Expand All @@ -420,9 +420,9 @@ Graph = R6Class("Graph",
outwidth = getOption("width") %??% 80 # output width we want (default 80)
colwidths = map_int(lines, function(x) max(nchar(x), na.rm = TRUE)) # original width of columns
collimit = calculate_collimit(colwidths, outwidth)
with_options(list(datatable.prettyprint.char = collimit), {
print(lines, row.names = FALSE)
})
opts = options(datatable.prettyprint.char = collimit)
on.exit(options(opts), add = TRUE)
print(lines, row.names = FALSE)
} else {
cat("Empty Graph.\n")
}
Expand All @@ -436,7 +436,7 @@ Graph = R6Class("Graph",
set_names = function(old, new) {
ids = names2(self$pipeops)
assert_subset(old, ids)
assert_character(new, any.missing = FALSE, min.chars = 1)
assert_character(new, any.missing = FALSE, min.chars = 1L)
new_ids = map_values(ids, old, new)
names(self$pipeops) = new_ids
imap(self$pipeops, function(x, nn) x$id = nn)
Expand Down Expand Up @@ -465,8 +465,8 @@ Graph = R6Class("Graph",
},

help = function(help_type = getOption("help_type")) {
parts = strsplit(self$man, split = "::", fixed = TRUE)[[1]]
match.fun("help")(parts[[2]], package = parts[[1]], help_type = help_type)
parts = strsplit(self$man, split = "::", fixed = TRUE)[[1L]]
match.fun("help")(parts[[2L]], package = parts[[1L]], help_type = help_type)
}
),

Expand Down Expand Up @@ -549,8 +549,8 @@ graph_channels = function(ids, channels, pipeops, direction) {
df$op.id = po$id
df = df[rows,
c("name", "train", "predict", "op.id", "name")]
df[[1]] = paste0(po$id, ".", df[[1]])
names(df)[5] = "channel.name"
df[[1L]] = paste0(po$id, ".", df[[1L]])
names(df)[5L] = "channel.name"
df
})

Expand Down Expand Up @@ -606,12 +606,12 @@ graph_reduce = function(self, input, fun, single_input) {
# inputs differs from the number of channels -- theoretically, there could be two varargs, one
# getting two inputs, the other none.
if (!single_input && "..." %in% graph_input$channel.name) {
if (sum("..." == graph_input$channel.name) != 1 && is.null(names(input))) {
if (sum("..." == graph_input$channel.name) != 1L && is.null(names(input))) {
stop("Ambiguous distribution of inputs to vararg channels.\nAssigning more than one input to vararg channels when there are multiple vararg inputs does not work.\nYou can try using a named input list. Vararg elements must be named '<pipeopname>....' (with four dots).")
}
# repeat the "..." as often as necessary
if (is.null(names(input))) {
repeats = ifelse(graph_input$channel.name == "...", length(input) - nrow(graph_input) + 1, 1)
repeats = ifelse(graph_input$channel.name == "...", length(input) - nrow(graph_input) + 1L, 1L)
} else {
repeats = nafill(as.numeric(table(names(input))[graph_input$name]), fill = 0)
}
Expand Down Expand Up @@ -709,7 +709,7 @@ graph_load_namespaces = function(self, info) {
NULL
}, error = function(e) {
sprintf("Error loading package %s (required by %s):\n %s",
package, str_collapse(pipeops, n = 4), e$message)
package, str_collapse(pipeops, n = 4L), e$message)
})
})
errors = discard(errors, is.null)
Expand All @@ -725,7 +725,7 @@ predict.Graph = function(object, newdata, ...) {
stop("Cannot predict, Graph has not been trained yet")
}
output = object$output
if (nrow(output) != 1) {
if (nrow(output) != 1L) {
stop("Graph has more than one output channel")
}
if (!are_types_compatible(output$predict, "Prediction")) {
Expand All @@ -751,7 +751,7 @@ predict.Graph = function(object, newdata, ...) {
)
}
result = object$predict(newdata)
assert_list(result, types = "Prediction", any.missing = FALSE, len = 1)
assert_list(result, types = "Prediction", any.missing = FALSE, len = 1L)
result = result[[1]]
if (plain) {
result = result$data$response %??% result$data$prob
Expand Down
1 change: 0 additions & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#' @importFrom R6 R6Class
#' @importFrom utils tail head
#' @importFrom digest digest
#' @importFrom withr with_options
#' @importFrom stats setNames
"_PACKAGE"

Expand Down