Skip to content
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.1.2
Collate:
'Condition.R'
'ContextPV.R'
'Design.R'
'NoDefault.R'
'Param.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

S3method(as.data.table,Param)
S3method(as.data.table,ParamSet)
S3method(print,ContextPV)
S3method(print,Domain)
S3method(print,FullTuneToken)
S3method(print,ObjectTuneToken)
Expand All @@ -10,6 +11,7 @@ S3method(rd_info,ParamSet)
export(CondAnyOf)
export(CondEqual)
export(Condition)
export(ContextPV)
export(Design)
export(NO_DEF)
export(NoDefault)
Expand Down
45 changes: 45 additions & 0 deletions R/ContextPV.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@

#' @title ParamSet Value that Depends on Context
#'
#' @description
#' Set a [`ParamSet`]`$value` slot to this. The `.fn` function will
#' be called with the respective function arguments. The function argument
#' names of `.fn` must be a subset of the [`ParamSet`]'s `$context_available` slot.
#'
#' @param .fn `function`\cr
#' Function to be executed in the context where [`ParamSet`] values
#' are retrieved.
#' @param ... any\cr
#' Variable names to make available to function. Functions are run in
#' the [`.GlobalEnv`][base::.GlobalEnv] scope and only the variables
#' named in `...` will be available.
#' @examples
#' p = ps(x = p_dbl(), y = p_dbl())
#' p$context_available = c("a", "x")
#'
#' b = 10
#'
#' p$values$x = ContextPV(function(a) a * b, b)
#' p$values$y = ContextPV(function(x) x^2)
#' # ContextPV uses the value of b right at this moment
#' b = 20
#'
#' p$values$x(1) # 1 * 10
#'
#' p$get_values(context = list(a = 10, x = 20)) # using 'a' from context: 10 * 10
#'
#' @export
ContextPV = function(.fn, ...) {
assert_function(.fn)
set_class(crate(.fn, ...), c("ContextPV", "function"))
}

#' @export
print.ContextPV = function(x, ...) {
y = x
cat("ContextPV ")
environment(y) = .GlobalEnv
print(unclass(y), ...)
cat("Using following environment:\n")
print(as.list(environment(x)))
}
1 change: 1 addition & 0 deletions R/Param.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ Param = R6Class("Param",
TRUE
}, error = function(e) paste("tune token invalid:", conditionMessage(e))))
}
if (inherits(x, "ContextPV")) return(TRUE)
ch = private$.check(x)
ifelse(isTRUE(ch) || has_element(self$special_vals, x), TRUE, ch)
},
Expand Down
57 changes: 48 additions & 9 deletions R/ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ ParamSet = R6Class("ParamSet",
#' Default is `TRUE`, only switch this off if you know what you are doing.
assert_values = TRUE,

#' @field context_available (`character`)\cr
#' Context that [`ContextPV`] values can rely on being present. When assigning a [`ContextPV`]
#' to a `$values` element, its argument names must be a subset of `$context_available`.
#' Conversely, the `context` argument of `$get_values()` must contain named elements
#' corresponding to all `$context_available` entries.
context_available = character(0),

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
Expand Down Expand Up @@ -138,15 +145,20 @@ ParamSet = R6Class("ParamSet",
#' Return values `with_token`, `without_token` or `only_token`?
#' @param check_required (`logical(1)`)\cr
#' Check if all required parameters are set?
#' @param context (`environment` | named `list`)\cr
#' Must have elements named by `$context_available`, which are then given to the
#' respective arguments of [`ContextPV`] values.
#' @return Named `list()`.
get_values = function(class = NULL, is_bounded = NULL, tags = NULL,
type = "with_token", check_required = TRUE) {
type = "with_token", check_required = TRUE, context = parent.frame()) {
assert_choice(type, c("with_token", "without_token", "only_token"))
assert_flag(check_required)
values = self$values
params = self$params_unid
ns = names(values)

if (type != "only_token") assert_names(names(context), type = "unique", must.include = self$context_available)

if (type == "without_token") {
values = discard(values, is, "TuneToken")
} else if (type == "only_token") {
Expand All @@ -160,7 +172,23 @@ ParamSet = R6Class("ParamSet",
}
}

values[intersect(names(values), self$ids(class = class, is_bounded = is_bounded, tags = tags))]
values = values[intersect(names(values), self$ids(class = class, is_bounded = is_bounded, tags = tags))]

if (type == "only_token") return(values)

imap(values, function(x, name) {
if (!inherits(x, "ContextPV")) return(x)
x = do.call(x, lapply(names(formals(args(x))), get, pos = context))
checked = params[[name]]$check(x)
if (!isTRUE(checked)) {
stopf("ContextPV for %s resulted in infeasible value:\n%s",
name, checked)
}
if (!has_element(params[[name]]$special_vals, x)) {
x = params[[name]]$convert(x)
}
x
})
},

#' @description
Expand Down Expand Up @@ -225,11 +253,18 @@ ParamSet = R6Class("ParamSet",
return(sprintf("Parameter '%s' not available.%s", ns[extra], did_you_mean(extra, ids)))
}

# check each parameters feasibility
for (n in ns) {
ch = params[[n]]$check(xs[[n]])
if (test_string(ch)) { # we failed a check, return string
return(paste0(n, ": ", ch))
# check each parameters feasibility, only necessary if we are a leaf ParamSet
if (!inherits(self, "ParamSetCollection")) {
for (n in ns) {
if (inherits(xs[[n]], "ContextPV")) {
ch = check_names(names(formals(args(xs[[n]]))), type = "unique", subset.of = self$context_available)
if (!isTRUE(ch)) ch = sprintf("Argument names of ContextPV %s", ch)
} else {
ch = params[[n]]$check(xs[[n]])
}
if (test_string(ch)) { # we failed a check, return string
return(paste0(n, ": ", ch))
}
}
}

Expand All @@ -247,7 +282,9 @@ ParamSet = R6Class("ParamSet",
# - if param is there, then parent must be there, then cond must be true
# - if param is not there
cond = deps$cond[[j]]
ok = (p1id %in% ns && p2id %in% ns && cond$test(xs[[p2id]])) ||
ok = (p1id %in% ns && p2id %in% ns &&
!inherits(xs[[p2id]], "ContextPV") &&
cond$test(xs[[p2id]])) ||
(p1id %nin% ns)
if (isFALSE(ok)) {
message = sprintf("The parameter '%s' can only be set if the following condition is met '%s'.",
Expand All @@ -256,6 +293,8 @@ ParamSet = R6Class("ParamSet",
if (is.null(val)) {
message = sprintf(paste("%s Instead the parameter value for '%s' is not set at all.",
"Try setting '%s' to a value that satisfies the condition"), message, p2id, p2id)
} else if (inherits(val, "ContextPV")) {
message = sprintf("%s However, %s is a ContextPV. Conditions on ContextPV values default to FALSE / unmet.", message, p2id)
} else {
message = sprintf("%s Instead the current parameter value is: %s=%s", message, p2id, val)
}
Expand Down Expand Up @@ -584,7 +623,7 @@ ParamSet = R6Class("ParamSet",
for (n in names(xs)) {
p = params[[n]]
x = xs[[n]]
if (inherits(x, "TuneToken")) next
if (inherits(x, "TuneToken") || inherits(x, "ContextPV")) next
if (has_element(p$special_vals, x)) next
xs[[n]] = p$convert(x)
}
Expand Down
2 changes: 1 addition & 1 deletion R/Sampler.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Sampler = R6Class("Sampler",
#' Note that this object is typically constructed via derived classes,
#' e.g., [Sampler1D].
initialize = function(param_set) {
assert_param_set(param_set, no_untyped = TRUE)
assert_param_set(param_set)
self$param_set = param_set$clone(deep = TRUE)
},

Expand Down
39 changes: 39 additions & 0 deletions man/ContextPV.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 12 additions & 1 deletion man/ParamSet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

108 changes: 108 additions & 0 deletions tests/testthat/test_ContextPV.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@

context("ContextPV")

test_that("ContextPV construction", {
y = 2
fpv = ContextPV(function(x) x * y, y)
expect_equal(fpv(10), 20)
expect_output(print(fpv), "ContextPV function.*x \\* y.*\\$y.*2")
})

test_that("ContextPV applied in get_values()", {
p = ParamSet$new(list(ParamInt$new("x")))
y = 2
expect_error({p$values$x = ContextPV(function(x) x * y, y)}, "Argument names of ContextPV Must be a subset of set \\{\\}")

p$context_available = c("a", "x")

p$values$x = ContextPV(function(x) x * y, y)
expect_output(print(p$values$x), "x \\* y.*\\$y.*2")

expect_error(p$get_values(), "context.*Must include the elements \\{a,x\\}")

expect_equal(p$get_values(context = list(x = 10, a = 20)), list(x = 20))

x = 30
a = 20
expect_equal(p$get_values(), list(x = 60)) # default context is parent.frame

})

test_that("ContextPV checks range", {
p = ParamSet$new(list(ParamInt$new("x", lower = 0, upper = 30)))
p$context_available = "x"

y = 2
p$values$x = ContextPV(function(x) x * y, y)

expect_equal(p$get_values(context = list(x = 10)), list(x = 20))
expect_error(p$get_values(context = list(x = 20)), " x resulted in infeasible value.*is not <= 30")

expect_error(p$get_values(context = list(x = 10.25)), " x resulted in infeasible value.*not 'double'")

})

test_that("ContextPV convert", {

p = ParamSet$new(list(ParamDbl$new("x", lower = 0, upper = 30)))
p$context_available = "x"

y = 2
p$values$x = ContextPV(function(x) x * y, y)

expect_equal(p$get_values(context = list(x = 10)), list(x = 20))
# convert to within range
expect_equal(p$get_values(context = list(x = 15.0000000001)), list(x = 30), tolerance = 1e-100)
})

test_that("ContextPV may not be in variable with dependency", {
p = ParamSet$new(list(
ParamInt$new("x"),
ParamInt$new("y")
))
p$add_dep("x", "y", CondEqual$new(0))
y = 2
expect_error({p$values = list(x = 1, y = 1)}, "can only be set if the following.*y = 0")

expect_error({p$values = list(x = 1, y = 0)}, NA)
expect_error({p$values = list(y = 1)}, NA)

fpv = ContextPV(function(x) x * y, y)
p$context_available = "x"
expect_error({p$values = list(y = fpv)}, NA)

expect_error({p$values = list(x = 1, y = fpv)}, "y is a ContextPV")
})

test_that("ContextPV in Tuning PS", {
p = ParamSet$new(list(
ParamInt$new("x"),
ParamInt$new("y")
))
p$context_available = "scale"
p2 = p$search_space(list(
x = to_tune(0, 10),
y = to_tune(p_dbl(0, 1,
trafo = function(x) ContextPV(function(scale) scale * x, x)))
))

paramval = generate_design_grid(p2, 3)$transpose()[[5]] # 5, 0.5
p$values = paramval
expect_equal(p$get_values(context = list(scale = 100)), list(x = 5, y = 50))
expect_equal(p$get_values(context = list(scale = 10)), list(x = 5, y = 5))
expect_error(p$get_values(context = list(scale = 1)), "infeasible value.*not 'double'")

})


test_that("ContextPV with PSC", {
p = ps(x = p_dbl())
p$context_available = "x"
p$set_id = "n"
psc = ParamSetCollection$new(list(p))
expect_error({psc$values$n.x = ContextPV(function(y) x * 2)}, "Must be a subset of set \\{x\\}")
psc$values$n.x = ContextPV(function(x) x * 2)

expect_equal(p$get_values(context = list(x = 10)), list(x = 20))

})