Skip to content
2 changes: 1 addition & 1 deletion R/NoDefault.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ NoDefault = R6Class("NoDefault",

#' @export
NO_DEF = NoDefault$new() # nolint
is_nodefault = function(x) test_r6(x, "NoDefault")
is_nodefault = function(x) inherits(x, "NoDefault")
92 changes: 60 additions & 32 deletions R/Param.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,6 @@
#' @export
Param = R6Class("Param",
public = list(
#' @field id (`character(1)`)\cr
#' Identifier of the object.
id = NULL,

#' @field special_vals (`list()`)\cr
#' Arbitrary special values this parameter is allowed to take.
special_vals = NULL,

#' @field default (`any`)\cr
#' Default value.
default = NULL,

#' @field tags (`character()`)\cr
#' Arbitrary tags to group and subset parameters.
tags = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
Expand All @@ -46,10 +30,10 @@ Param = R6Class("Param",
assert_list(special_vals)
assert_character(tags, any.missing = FALSE, unique = TRUE)

self$id = id
self$special_vals = special_vals
self$default = default
self$tags = tags
private$.id = id
private$.special_vals = special_vals
private$.default = default
private$.tags = tags
if (!is_nodefault(default)) { # check that default is feasible
self$assert(default)
}
Expand All @@ -72,7 +56,7 @@ Param = R6Class("Param",
}, error = function(e) paste("tune token invalid:", conditionMessage(e))))
}
ch = private$.check(x)
ifelse(isTRUE(ch) || has_element(self$special_vals, x), TRUE, ch)
ifelse(isTRUE(ch) || has_element(private$.special_vals, x), TRUE, ch)
},

#' @description
Expand Down Expand Up @@ -100,18 +84,42 @@ Param = R6Class("Param",
#' Each parameter is named "\[id\]_rep_\[k\]" and gets the additional tag "\[id\]_rep".
#'
#' @param n (`integer(1)`).
#' @return [ParamSet].
#' @return [`ParamSet`].
rep = function(n) {
assert_count(n)
pid = self$id
join_id = paste0(pid, "_rep")
ps = replicate(n, self$clone(), simplify = FALSE)
for (i in 1:n) {
p = ps[[i]]
p$id = paste0(join_id, "_", i)
p$tags = c(p$tags, join_id)
}
ParamSet$new(ps)
taggedself = self$with_tags(c(self$param_tags, join_id))
repeatedself = structure(rep(list(taggedself), n), names = sprintf("%s_%s", join_id, seq_len(n)))
ParamSet$new(repeatedself, ignore_ids = TRUE)
},

#' @description
#' Creates a clone of this parameter with changed `$tags` slot.
#' If the tags do not change, no clone is created.
#' @param tags (`character`) tags of the clone.
#' @return [`Param`].
with_tags = function(tags) {
origtags = private$.tags
if (identical(tags, origtags)) return(self)
assert_character(tags, any.missing = FALSE, unique = TRUE)
on.exit({private$.tags = origtags})
private$.tags = tags
self$clone(deep = TRUE)
},

#' @description
#' Creates a clone of this parameter with changed `$id` slot.
#' If the `id` does not change, no clone is created.
#' @param id (`character(1)`) id of the clone.
#' @return [`Param`].
with_id = function(id) {
if (identical(id, self$id)) return(self)
assert_id(id)
origid = private$.id
on.exit({private$.id = origid})
private$.id = id
self$clone(deep = TRUE)
},

#' @description
Expand Down Expand Up @@ -159,6 +167,22 @@ Param = R6Class("Param",
),

active = list(
#' @field id (`character(1)`)\cr
#' Identifier of the object.
id = function() private$.id,

#' @field special_vals (`list()`)\cr
#' Arbitrary special values this parameter is allowed to take.
special_vals = function() private$.special_vals,

#' @field default (`any`)\cr
#' Default value.
default = function() private$.default,

#' @field param_tags (`character()`)\cr
#' Arbitrary tags to group and subset parameters.
param_tags = function() private$.tags,

#' @field class (`character(1)`)\cr
#' R6 class name. Read-only.
class = function() class(self)[[1L]],
Expand All @@ -173,12 +197,16 @@ Param = R6Class("Param",

#' @field has_default (`logical(1)`)\cr
#' Is there a default value?
has_default = function() !is_nodefault(self$default)
has_default = function() !is_nodefault(private$.default)
),

private = list(
.check = function(x) stop("abstract"),
.qunif = function(x) stop("abstract") # should be implemented by subclasses, argcheck happens in Param$qunif
.qunif = function(x) stop("abstract"), # should be implemented by subclasses, argcheck happens in Param$qunif
.special_vals = NULL,
.default = NULL,
.tags = NULL,
.id = NULL
)
)

Expand All @@ -195,6 +223,6 @@ as.data.table.Param = function(x, ...) { # nolint
special_vals = list(x$special_vals),
default = list(x$default),
storage_type = x$storage_type,
tags = list(x$tags)
tags = list(x$param_tags)
)
}
44 changes: 22 additions & 22 deletions R/ParamDbl.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,12 @@
#' ParamDbl$new("ratio", lower = 0, upper = 1, default = 0.5)
ParamDbl = R6Class("ParamDbl", inherit = Param,
public = list(
#' @template field_lower
lower = NULL,

#' @template field_upper
upper = NULL,

#' @field tolerance (`numeric(1)`)\cr
#' tolerance of values to accept beyond `$lower` and `$upper`.
#' Used both for relative and absolute tolerance.
tolerance = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id, lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character(), tolerance = sqrt(.Machine$double.eps)) {
self$lower = assert_number(lower)
self$upper = assert_number(upper)
self$tolerance = assert_number(tolerance, lower = 0)
private$.lower = assert_number(lower)
private$.upper = assert_number(upper)
private$.tolerance = assert_number(tolerance, lower = 0)
assert_true(lower <= upper)
super$initialize(id, special_vals = special_vals, default = default, tags = tags)
},
Expand All @@ -52,35 +41,46 @@ ParamDbl = R6Class("ParamDbl", inherit = Param,
#' @param x (`numeric(1)`)\cr
#' Value to convert.
convert = function(x) {
min(max(x, self$lower), self$upper)
min(max(x, private$.lower), private$.upper)
}
),

active = list(
#' @template field_lower
lower = function() private$.lower,
#' @template field_upper
upper = function() private$.upper,
#' @field tolerance (`numeric(1)`)\cr
#' tolerance of values to accept beyond `$lower` and `$upper`.
#' Used both for relative and absolute tolerance.
tolerance = function() private$.tolerance,
#' @template field_levels
levels = function() NULL,
#' @template field_nlevels
nlevels = function() Inf,
#' @template field_is_bounded
is_bounded = function() is.finite(self$lower) && is.finite(self$upper),
is_bounded = function() is.finite(private$.lower) && is.finite(private$.upper),
#' @template field_storage_type
storage_type = function() "numeric"
),

private = list(
.check = function(x) {
# Accept numbers between lower and upper bound, with tolerance self$tolerance
# Accept numbers between lower and upper bound, with tolerance `$tolerance`
# Tolerance is both absolute & relative tolerance (if either tolerance is
# undercut the value is accepted:
# Values that go beyond the bound by less than `self$tolerance` are also
# Values that go beyond the bound by less than `tolerance` are also
# accepted (absolute tolerance)
# Values that go beyond the bound by less than `abs(<bound>) * self$tolerance`
# Values that go beyond the bound by less than `abs(<bound>) * tolerance`
# are also accepted (relative tolerance)
checkNumber(x,
lower = self$lower - self$tolerance * max(1, abs(self$lower)),
upper = self$upper + self$tolerance * max(1, abs(self$upper))
lower = private$.lower - private$.tolerance * max(1, abs(private$.lower)),
upper = private$.upper + private$.tolerance * max(1, abs(private$.upper))
)
},
.qunif = function(x) x * self$upper - (x-1) * self$lower
.qunif = function(x) x * private$.upper - (x-1) * private$.lower,
.lower = NULL,
.upper = NULL,
.tolerance = NULL
)
)
17 changes: 9 additions & 8 deletions R/ParamFct.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,40 +17,41 @@
ParamFct = R6Class("ParamFct", inherit = Param,
public = list(

#' @template field_levels
levels = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @param levels (`character()`)\cr
#' Set of allowed levels.
initialize = function(id, levels, special_vals = list(), default = NO_DEF, tags = character()) {
assert_character(levels, any.missing = FALSE, unique = TRUE)
self$levels = levels
private$.levels = levels
super$initialize(id, special_vals = special_vals, default = default, tags = tags)
}
),

active = list(
#' @template field_levels
levels = function() private$.levels,
#' @template field_lower
lower = function() NA_real_,
#' @template field_upper
upper = function() NA_real_,
#' @template field_nlevels
nlevels = function() length(self$levels),
nlevels = function() length(private$.levels),
#' @template field_is_bounded
is_bounded = function() TRUE,
#' @template field_storage_type
storage_type = function() "character"
),

private = list(
.check = function(x) check_choice(x, choices = self$levels),
.check = function(x) check_choice(x, choices = private$.levels),

.qunif = function(x) {
z = floor(x * self$nlevels * (1 - 1e-16)) + 1 # make sure we dont map to upper+1
self$levels[z]
}
private$.levels[z]
},

.levels = NULL
)
)
28 changes: 14 additions & 14 deletions R/ParamInt.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,18 @@
#' ParamInt$new("count", lower = 0, upper = 10, default = 1)
ParamInt = R6Class("ParamInt", inherit = Param,
public = list(
#' @template field_lower
lower = NULL,

#' @template field_upper
upper = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id, lower = -Inf, upper = Inf, special_vals = list(), default = NO_DEF, tags = character()) {
if (isTRUE(is.infinite(lower))) {
self$lower = lower
private$.lower = lower
} else {
self$lower = assert_int(lower)
private$.lower = assert_int(lower)
}
if (isTRUE(is.infinite(upper))) {
self$upper = upper
private$.upper = upper
} else {
self$upper = assert_int(upper)
private$.upper = assert_int(upper)
}
assert_true(lower <= upper)
super$initialize(id, special_vals = special_vals, default = default, tags = tags)
Expand All @@ -53,18 +47,24 @@ ParamInt = R6Class("ParamInt", inherit = Param,
),

active = list(
#' @template field_lower
lower = function() private$.lower,
#' @template field_upper
upper = function() private$.upper,
#' @template field_levels
levels = function() NULL,
#' @template field_nlevels
nlevels = function() (self$upper - self$lower) + 1L,
nlevels = function() (private$.upper - private$.lower) + 1L,
#' @template field_is_bounded
is_bounded = function() is.finite(self$lower) && is.finite(self$upper),
is_bounded = function() is.finite(private$.lower) && is.finite(private$.upper),
#' @template field_storage_type
storage_type = function() "integer"
),

private = list(
.check = function(x) checkInt(x, lower = self$lower, upper = self$upper, tol = 1e-300),
.qunif = function(x) as.integer(floor(x * self$nlevels * (1 - 1e-16)) + self$lower) # make sure we dont map to upper+1
.check = function(x) checkInt(x, lower = private$.lower, upper = private$.upper, tol = 1e-300),
.qunif = function(x) as.integer(floor(x * self$nlevels * (1 - 1e-16)) + private$.lower), # make sure we dont map to upper+1
.lower = NULL,
.upper = NULL
)
)
Loading