diff --git a/DESCRIPTION b/DESCRIPTION index 84630662..d9697736 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -60,6 +60,7 @@ Roxygen: list(markdown = TRUE, r6 = TRUE) RoxygenNote: 7.1.2 Collate: 'Condition.R' + 'ContextPV.R' 'Design.R' 'NoDefault.R' 'Param.R' diff --git a/NAMESPACE b/NAMESPACE index 6153cbcd..7df28f76 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,6 +2,7 @@ S3method(as.data.table,Param) S3method(as.data.table,ParamSet) +S3method(print,ContextPV) S3method(print,Domain) S3method(print,FullTuneToken) S3method(print,ObjectTuneToken) @@ -10,6 +11,7 @@ S3method(rd_info,ParamSet) export(CondAnyOf) export(CondEqual) export(Condition) +export(ContextPV) export(Design) export(NO_DEF) export(NoDefault) diff --git a/R/ContextPV.R b/R/ContextPV.R new file mode 100644 index 00000000..bf4cf953 --- /dev/null +++ b/R/ContextPV.R @@ -0,0 +1,45 @@ + +#' @title ParamSet Value that Depends on Context +#' +#' @description +#' Set a [`ParamSet`]`$value` slot to this. The `.fn` function will +#' be called with the respective function arguments. The function argument +#' names of `.fn` must be a subset of the [`ParamSet`]'s `$context_available` slot. +#' +#' @param .fn `function`\cr +#' Function to be executed in the context where [`ParamSet`] values +#' are retrieved. +#' @param ... any\cr +#' Variable names to make available to function. Functions are run in +#' the [`.GlobalEnv`][base::.GlobalEnv] scope and only the variables +#' named in `...` will be available. +#' @examples +#' p = ps(x = p_dbl(), y = p_dbl()) +#' p$context_available = c("a", "x") +#' +#' b = 10 +#' +#' p$values$x = ContextPV(function(a) a * b, b) +#' p$values$y = ContextPV(function(x) x^2) +#' # ContextPV uses the value of b right at this moment +#' b = 20 +#' +#' p$values$x(1) # 1 * 10 +#' +#' p$get_values(context = list(a = 10, x = 20)) # using 'a' from context: 10 * 10 +#' +#' @export +ContextPV = function(.fn, ...) { + assert_function(.fn) + set_class(crate(.fn, ...), c("ContextPV", "function")) +} + +#' @export +print.ContextPV = function(x, ...) { + y = x + cat("ContextPV ") + environment(y) = .GlobalEnv + print(unclass(y), ...) + cat("Using following environment:\n") + print(as.list(environment(x))) +} diff --git a/R/Param.R b/R/Param.R index a004c4d2..648a9600 100644 --- a/R/Param.R +++ b/R/Param.R @@ -77,6 +77,7 @@ Param = R6Class("Param", TRUE }, error = function(e) paste("tune token invalid:", conditionMessage(e)))) } + if (inherits(x, "ContextPV")) return(TRUE) ch = private$.check(x) ifelse(isTRUE(ch) || has_element(self$special_vals, x), TRUE, ch) }, diff --git a/R/ParamSet.R b/R/ParamSet.R index 103efceb..526569cb 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -45,6 +45,13 @@ ParamSet = R6Class("ParamSet", #' Default is `TRUE`, only switch this off if you know what you are doing. assert_values = TRUE, + #' @field context_available (`character`)\cr + #' Context that [`ContextPV`] values can rely on being present. When assigning a [`ContextPV`] + #' to a `$values` element, its argument names must be a subset of `$context_available`. + #' Conversely, the `context` argument of `$get_values()` must contain named elements + #' corresponding to all `$context_available` entries. + context_available = character(0), + #' @description #' Creates a new instance of this [R6][R6::R6Class] class. #' @@ -138,15 +145,20 @@ ParamSet = R6Class("ParamSet", #' Return values `with_token`, `without_token` or `only_token`? #' @param check_required (`logical(1)`)\cr #' Check if all required parameters are set? + #' @param context (`environment` | named `list`)\cr + #' Must have elements named by `$context_available`, which are then given to the + #' respective arguments of [`ContextPV`] values. #' @return Named `list()`. get_values = function(class = NULL, is_bounded = NULL, tags = NULL, - type = "with_token", check_required = TRUE) { + type = "with_token", check_required = TRUE, context = parent.frame()) { assert_choice(type, c("with_token", "without_token", "only_token")) assert_flag(check_required) values = self$values params = self$params_unid ns = names(values) + if (type != "only_token") assert_names(names(context), type = "unique", must.include = self$context_available) + if (type == "without_token") { values = discard(values, is, "TuneToken") } else if (type == "only_token") { @@ -160,7 +172,23 @@ ParamSet = R6Class("ParamSet", } } - values[intersect(names(values), self$ids(class = class, is_bounded = is_bounded, tags = tags))] + values = values[intersect(names(values), self$ids(class = class, is_bounded = is_bounded, tags = tags))] + + if (type == "only_token") return(values) + + imap(values, function(x, name) { + if (!inherits(x, "ContextPV")) return(x) + x = do.call(x, lapply(names(formals(args(x))), get, pos = context)) + checked = params[[name]]$check(x) + if (!isTRUE(checked)) { + stopf("ContextPV for %s resulted in infeasible value:\n%s", + name, checked) + } + if (!has_element(params[[name]]$special_vals, x)) { + x = params[[name]]$convert(x) + } + x + }) }, #' @description @@ -225,11 +253,18 @@ ParamSet = R6Class("ParamSet", return(sprintf("Parameter '%s' not available.%s", ns[extra], did_you_mean(extra, ids))) } - # check each parameters feasibility - for (n in ns) { - ch = params[[n]]$check(xs[[n]]) - if (test_string(ch)) { # we failed a check, return string - return(paste0(n, ": ", ch)) + # check each parameters feasibility, only necessary if we are a leaf ParamSet + if (!inherits(self, "ParamSetCollection")) { + for (n in ns) { + if (inherits(xs[[n]], "ContextPV")) { + ch = check_names(names(formals(args(xs[[n]]))), type = "unique", subset.of = self$context_available) + if (!isTRUE(ch)) ch = sprintf("Argument names of ContextPV %s", ch) + } else { + ch = params[[n]]$check(xs[[n]]) + } + if (test_string(ch)) { # we failed a check, return string + return(paste0(n, ": ", ch)) + } } } @@ -247,7 +282,9 @@ ParamSet = R6Class("ParamSet", # - if param is there, then parent must be there, then cond must be true # - if param is not there cond = deps$cond[[j]] - ok = (p1id %in% ns && p2id %in% ns && cond$test(xs[[p2id]])) || + ok = (p1id %in% ns && p2id %in% ns && + !inherits(xs[[p2id]], "ContextPV") && + cond$test(xs[[p2id]])) || (p1id %nin% ns) if (isFALSE(ok)) { message = sprintf("The parameter '%s' can only be set if the following condition is met '%s'.", @@ -256,6 +293,8 @@ ParamSet = R6Class("ParamSet", if (is.null(val)) { message = sprintf(paste("%s Instead the parameter value for '%s' is not set at all.", "Try setting '%s' to a value that satisfies the condition"), message, p2id, p2id) + } else if (inherits(val, "ContextPV")) { + message = sprintf("%s However, %s is a ContextPV. Conditions on ContextPV values default to FALSE / unmet.", message, p2id) } else { message = sprintf("%s Instead the current parameter value is: %s=%s", message, p2id, val) } @@ -584,7 +623,7 @@ ParamSet = R6Class("ParamSet", for (n in names(xs)) { p = params[[n]] x = xs[[n]] - if (inherits(x, "TuneToken")) next + if (inherits(x, "TuneToken") || inherits(x, "ContextPV")) next if (has_element(p$special_vals, x)) next xs[[n]] = p$convert(x) } diff --git a/R/Sampler.R b/R/Sampler.R index 8b7c5e34..528773cc 100644 --- a/R/Sampler.R +++ b/R/Sampler.R @@ -19,7 +19,7 @@ Sampler = R6Class("Sampler", #' Note that this object is typically constructed via derived classes, #' e.g., [Sampler1D]. initialize = function(param_set) { - assert_param_set(param_set, no_untyped = TRUE) + assert_param_set(param_set) self$param_set = param_set$clone(deep = TRUE) }, diff --git a/man/ContextPV.Rd b/man/ContextPV.Rd new file mode 100644 index 00000000..06349322 --- /dev/null +++ b/man/ContextPV.Rd @@ -0,0 +1,39 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/ContextPV.R +\name{ContextPV} +\alias{ContextPV} +\title{ParamSet Value that Depends on Context} +\usage{ +ContextPV(.fn, ...) +} +\arguments{ +\item{.fn}{\code{function}\cr +Function to be executed in the context where \code{\link{ParamSet}} values +are retrieved.} + +\item{...}{any\cr +Variable names to make available to function. Functions are run in +the \code{\link[base:environment]{.GlobalEnv}} scope and only the variables +named in \code{...} will be available.} +} +\description{ +Set a \code{\link{ParamSet}}\verb{$value} slot to this. The \code{.fn} function will +be called with the respective function arguments. The function argument +names of \code{.fn} must be a subset of the \code{\link{ParamSet}}'s \verb{$context_available} slot. +} +\examples{ +p = ps(x = p_dbl(), y = p_dbl()) +p$context_available = c("a", "x") + +b = 10 + +p$values$x = ContextPV(function(a) a * b, b) +p$values$y = ContextPV(function(x) x^2) +# ContextPV uses the value of b right at this moment +b = 20 + +p$values$x(1) # 1 * 10 + +p$get_values(context = list(a = 10, x = 20)) # using 'a' from context: 10 * 10 + +} diff --git a/man/ParamSet.Rd b/man/ParamSet.Rd index 14fca66c..ca072097 100644 --- a/man/ParamSet.Rd +++ b/man/ParamSet.Rd @@ -53,6 +53,12 @@ ps$check(list(d = 2.1, f = "a", i = 3L)) \item{\code{assert_values}}{(\code{logical(1)})\cr Should values be checked for validity during assigment to active binding \verb{$values}? Default is \code{TRUE}, only switch this off if you know what you are doing.} + +\item{\code{context_available}}{(\code{character})\cr +Context that \code{\link{ContextPV}} values can rely on being present. When assigning a \code{\link{ContextPV}} +to a \verb{$values} element, its argument names must be a subset of \verb{$context_available}. +Conversely, the \code{context} argument of \verb{$get_values()} must contain named elements +corresponding to all \verb{$context_available} entries.} } \if{html}{\out{}} } @@ -259,7 +265,8 @@ Only returns values of parameters that satisfy all conditions. is_bounded = NULL, tags = NULL, type = "with_token", - check_required = TRUE + check_required = TRUE, + context = parent.frame() )}\if{html}{\out{}} } @@ -277,6 +284,10 @@ Return values \code{with_token}, \code{without_token} or \code{only_token}?} \item{\code{check_required}}{(\code{logical(1)})\cr Check if all required parameters are set?} + +\item{\code{context}}{(\code{environment} | named \code{list})\cr +Must have elements named by \verb{$context_available}, which are then given to the +respective arguments of \code{\link{ContextPV}} values.} } \if{html}{\out{}} } diff --git a/tests/testthat/test_ContextPV.R b/tests/testthat/test_ContextPV.R new file mode 100644 index 00000000..77e7d572 --- /dev/null +++ b/tests/testthat/test_ContextPV.R @@ -0,0 +1,108 @@ + +context("ContextPV") + +test_that("ContextPV construction", { + y = 2 + fpv = ContextPV(function(x) x * y, y) + expect_equal(fpv(10), 20) + expect_output(print(fpv), "ContextPV function.*x \\* y.*\\$y.*2") +}) + +test_that("ContextPV applied in get_values()", { + p = ParamSet$new(list(ParamInt$new("x"))) + y = 2 + expect_error({p$values$x = ContextPV(function(x) x * y, y)}, "Argument names of ContextPV Must be a subset of set \\{\\}") + + p$context_available = c("a", "x") + + p$values$x = ContextPV(function(x) x * y, y) + expect_output(print(p$values$x), "x \\* y.*\\$y.*2") + + expect_error(p$get_values(), "context.*Must include the elements \\{a,x\\}") + + expect_equal(p$get_values(context = list(x = 10, a = 20)), list(x = 20)) + + x = 30 + a = 20 + expect_equal(p$get_values(), list(x = 60)) # default context is parent.frame + +}) + +test_that("ContextPV checks range", { + p = ParamSet$new(list(ParamInt$new("x", lower = 0, upper = 30))) + p$context_available = "x" + + y = 2 + p$values$x = ContextPV(function(x) x * y, y) + + expect_equal(p$get_values(context = list(x = 10)), list(x = 20)) + expect_error(p$get_values(context = list(x = 20)), " x resulted in infeasible value.*is not <= 30") + + expect_error(p$get_values(context = list(x = 10.25)), " x resulted in infeasible value.*not 'double'") + +}) + +test_that("ContextPV convert", { + + p = ParamSet$new(list(ParamDbl$new("x", lower = 0, upper = 30))) + p$context_available = "x" + + y = 2 + p$values$x = ContextPV(function(x) x * y, y) + + expect_equal(p$get_values(context = list(x = 10)), list(x = 20)) + # convert to within range + expect_equal(p$get_values(context = list(x = 15.0000000001)), list(x = 30), tolerance = 1e-100) +}) + +test_that("ContextPV may not be in variable with dependency", { + p = ParamSet$new(list( + ParamInt$new("x"), + ParamInt$new("y") + )) + p$add_dep("x", "y", CondEqual$new(0)) + y = 2 + expect_error({p$values = list(x = 1, y = 1)}, "can only be set if the following.*y = 0") + + expect_error({p$values = list(x = 1, y = 0)}, NA) + expect_error({p$values = list(y = 1)}, NA) + + fpv = ContextPV(function(x) x * y, y) + p$context_available = "x" + expect_error({p$values = list(y = fpv)}, NA) + + expect_error({p$values = list(x = 1, y = fpv)}, "y is a ContextPV") +}) + +test_that("ContextPV in Tuning PS", { + p = ParamSet$new(list( + ParamInt$new("x"), + ParamInt$new("y") + )) + p$context_available = "scale" + p2 = p$search_space(list( + x = to_tune(0, 10), + y = to_tune(p_dbl(0, 1, + trafo = function(x) ContextPV(function(scale) scale * x, x))) + )) + + paramval = generate_design_grid(p2, 3)$transpose()[[5]] # 5, 0.5 + p$values = paramval + expect_equal(p$get_values(context = list(scale = 100)), list(x = 5, y = 50)) + expect_equal(p$get_values(context = list(scale = 10)), list(x = 5, y = 5)) + expect_error(p$get_values(context = list(scale = 1)), "infeasible value.*not 'double'") + +}) + + +test_that("ContextPV with PSC", { + p = ps(x = p_dbl()) + p$context_available = "x" + p$set_id = "n" + psc = ParamSetCollection$new(list(p)) + expect_error({psc$values$n.x = ContextPV(function(y) x * 2)}, "Must be a subset of set \\{x\\}") + psc$values$n.x = ContextPV(function(x) x * 2) + + expect_equal(p$get_values(context = list(x = 10)), list(x = 20)) + +})