diff --git a/DESCRIPTION b/DESCRIPTION index 621fe625..69ab6679 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -26,7 +26,7 @@ Authors@R: comment = c(ORCID = "0000-0003-3269-2307")), person(given = "Martin", family = "Binder", - role = c("aut", "cre"), + role = c("aut", "cre"), email = "mlr.developer@mb706.com"), person(given = "Marc", family = "Becker", @@ -59,7 +59,8 @@ Suggests: Encoding: UTF-8 Config/testthat/edition: 3 Config/testthat/parallel: false -NeedsCompilation: no +NeedsCompilation: yes +LinkingTo: checkmate Roxygen: list(markdown = TRUE, r6 = TRUE) RoxygenNote: 7.3.1 VignetteBuilder: knitr diff --git a/NAMESPACE b/NAMESPACE index 71371801..0bfb6d71 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -105,3 +105,4 @@ importFrom(data.table,as.data.table) importFrom(methods,is) importFrom(stats,rnorm) importFrom(stats,runif) +useDynLib(paradox,c_paramset_ids) diff --git a/R/ParamSet.R b/R/ParamSet.R index 529338dd..c7e86fde 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -127,22 +127,10 @@ ParamSet = R6Class("ParamSet", #' @param any_tags (`character()`). #' Return only IDs of dimensions that have at least one of the tags given in this argument. #' @return `character()`. + #' @useDynLib paradox c_paramset_ids ids = function(class = NULL, tags = NULL, any_tags = NULL) { - assert_character(class, any.missing = FALSE, null.ok = TRUE) - assert_character(tags, any.missing = FALSE, null.ok = TRUE) - assert_character(any_tags, any.missing = FALSE, null.ok = TRUE) - - if (is.null(class) && is.null(tags) && is.null(any_tags)) { - return(private$.params$id) - } - ptbl = if (is.null(class)) private$.params else private$.params[cls %in% class, .(id)] - if (is.null(tags) && is.null(any_tags)) { - return(ptbl$id) - } - tagtbl = private$.tags[ptbl, nomatch = 0] - idpool = if (is.null(any_tags)) list() else list(tagtbl[tag %in% any_tags, id]) - idpool = c(idpool, lapply(tags, function(t) tagtbl[t, id, on = "tag", nomatch = 0])) - Reduce(intersect, idpool) + # argchecks are done in C code + .Call("c_paramset_ids", private$.params, private$.tags, class, tags, any_tags) }, #' @description diff --git a/src/init.c b/src/init.c new file mode 100644 index 00000000..3bf0e59a --- /dev/null +++ b/src/init.c @@ -0,0 +1,15 @@ +#include +#include +#include + +extern SEXP c_paramset_ids(SEXP, SEXP, SEXP, SEXP, SEXP); + +static const R_CallMethodDef CallEntries[] = { + {"c_paramset_ids", (DL_FUNC) &c_paramset_ids, 5}, + {NULL, NULL, 0} +}; + +void R_init_mypackage(DllInfo *dll) { + R_registerRoutines(dll, NULL, CallEntries, NULL, NULL); + R_useDynamicSymbols(dll, FALSE); +} diff --git a/src/paramset.c b/src/paramset.c new file mode 100644 index 00000000..b6adcb44 --- /dev/null +++ b/src/paramset.c @@ -0,0 +1,98 @@ +#include +#include + + + +// typedef for checkmate::qassert signature +typedef void (*fun_t)(SEXP x, const char *rule, const char *name); + +SEXP c_paramset_ids(SEXP s_paramtbl, SEXP s_tagtbl, SEXP s_classes, SEXP s_tags, SEXP s_anytags) { + fun_t qassert = (fun_t) R_GetCCallable("checkmate", "qassert"); + + if (!isNull(s_classes)) qassert(s_classes, "S", "class"); + if (!isNull(s_tags)) qassert(s_tags, "S", "tags"); + if (!isNull(s_anytags)) qassert(s_anytags, "S", "any_tags"); + + int s_classes_n = LENGTH(s_classes); + int s_tags_n = LENGTH(s_tags); + int s_anytags_n = LENGTH(s_anytags); + /* Rprintf("s_classes_n=%i, s_tags=%i, s_anytags_n=%i\n", s_classes_n, s_tags_n, s_anytags_n); */ + + int paramtbl_nrows = LENGTH(VECTOR_ELT(s_paramtbl, 0)); + int tagtbl_nrows = LENGTH(VECTOR_ELT(s_tagtbl, 0)); + // FIXME: i am not sure if we want to index cols by nr here... + SEXP s_paramtbl_ids = VECTOR_ELT(s_paramtbl, 0); + + SEXP s_paramtbl_classes = VECTOR_ELT(s_paramtbl, 1); + SEXP s_tagtbl_ids = VECTOR_ELT(s_tagtbl, 0); + SEXP s_tagtbl_tags = VECTOR_ELT(s_tagtbl, 1); + + // result; potentially too large. has as many els as we have params + SEXP s_ids = PROTECT(allocVector(STRSXP, paramtbl_nrows)); + int s_ids_count = 0; + + // iter thru all rows in paramtbl and check that for each param all conditions hold + for (int i = 0; i < paramtbl_nrows; i++) { + /* Rprintf("i=%i, id=%s, class=%s\n", i, id, class); */ + // check that params's class is in "s_classes" + // if s_classes is NULL or empty, we dont need to check + if (s_classes_n > 0) { + SEXP class = STRING_ELT(s_paramtbl_classes, i); + int ok_classes = 0; + for (int j = 0; j < s_classes_n; j++) { + if (class == STRING_ELT(s_classes, j)) { + /* Rprintf("class ok, j = %i\n", j); */ + ok_classes = 1; + } + } + if (!ok_classes) continue; + } + SEXP id = STRING_ELT(s_paramtbl_ids, i); + + // check that param has all tags that are in "s_tags" + // if s_tags is NULL or empty, we dont need to check + if (s_tags_n > 0) { + int ok_tags = 0; + for (int j = 0; j < s_tags_n; j++) { + SEXP tag = STRING_ELT(s_tags, j); + // FIXME: this search is super slow, we iterate the tbl again and again + for (int k = 0; k < tagtbl_nrows; k++) { + SEXP tagtbl_id = STRING_ELT(s_tagtbl_ids, k); + SEXP tagtbl_tag = STRING_ELT(s_tagtbl_tags, k); + if (id == tagtbl_id && tag == tagtbl_tag) + ok_tags++; + } + } + /* Rprintf("ok_tags=%i\n", ok_tags); */ + // we didnt find all tags, so we skip current param + if (ok_tags < s_tags_n) continue; + } + + // check that param has at least one tag from "s_anytags" + // if s_anytags is NULL or empty, we dont need to check + if (s_anytags_n > 0) { + int ok_anytags = 0; + for (int j = 0; j < s_anytags_n; j++) { + SEXP anytag = STRING_ELT(s_anytags, j); + for (int k = 0; k < tagtbl_nrows; k++) { + SEXP tagtbl_id = STRING_ELT(s_tagtbl_ids, k); + SEXP tagtbl_tag = STRING_ELT(s_tagtbl_tags, k); + if (id == tagtbl_id && anytag == tagtbl_tag) + ok_anytags = 1; + } + } + if (!ok_anytags) continue; + } + + // if we ended up here, we add param to result + SET_STRING_ELT(s_ids, s_ids_count++, mkChar(CHAR(id))); + } + + // copy result to shorter charvec of correct size + SEXP s_ids_2 = PROTECT(allocVector(STRSXP, s_ids_count)); + for (int i = 0; i < s_ids_count; i++) + SET_STRING_ELT(s_ids_2, i, STRING_ELT(s_ids, i)); + UNPROTECT(2); + return s_ids_2; +} +