Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Authors@R:
comment = c(ORCID = "0000-0003-3269-2307")),
person(given = "Martin",
family = "Binder",
role = c("aut", "cre"),
role = c("aut", "cre"),
email = "[email protected]"),
person(given = "Marc",
family = "Becker",
Expand Down Expand Up @@ -59,7 +59,8 @@ Suggests:
Encoding: UTF-8
Config/testthat/edition: 3
Config/testthat/parallel: false
NeedsCompilation: no
NeedsCompilation: yes
LinkingTo: checkmate
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.3.1
VignetteBuilder: knitr
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,4 @@ importFrom(data.table,as.data.table)
importFrom(methods,is)
importFrom(stats,rnorm)
importFrom(stats,runif)
useDynLib(paradox,c_paramset_ids)
18 changes: 3 additions & 15 deletions R/ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,22 +127,10 @@ ParamSet = R6Class("ParamSet",
#' @param any_tags (`character()`).
#' Return only IDs of dimensions that have at least one of the tags given in this argument.
#' @return `character()`.
#' @useDynLib paradox c_paramset_ids
ids = function(class = NULL, tags = NULL, any_tags = NULL) {
assert_character(class, any.missing = FALSE, null.ok = TRUE)
assert_character(tags, any.missing = FALSE, null.ok = TRUE)
assert_character(any_tags, any.missing = FALSE, null.ok = TRUE)

if (is.null(class) && is.null(tags) && is.null(any_tags)) {
return(private$.params$id)
}
ptbl = if (is.null(class)) private$.params else private$.params[cls %in% class, .(id)]
if (is.null(tags) && is.null(any_tags)) {
return(ptbl$id)
}
tagtbl = private$.tags[ptbl, nomatch = 0]
idpool = if (is.null(any_tags)) list() else list(tagtbl[tag %in% any_tags, id])
idpool = c(idpool, lapply(tags, function(t) tagtbl[t, id, on = "tag", nomatch = 0]))
Reduce(intersect, idpool)
# argchecks are done in C code
.Call("c_paramset_ids", private$.params, private$.tags, class, tags, any_tags)
},

#' @description
Expand Down
15 changes: 15 additions & 0 deletions src/init.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include <R.h>
#include <Rinternals.h>
#include <R_ext/Rdynload.h>

extern SEXP c_paramset_ids(SEXP, SEXP, SEXP, SEXP, SEXP);

static const R_CallMethodDef CallEntries[] = {
{"c_paramset_ids", (DL_FUNC) &c_paramset_ids, 5},
{NULL, NULL, 0}
};

void R_init_mypackage(DllInfo *dll) {
R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
R_useDynamicSymbols(dll, FALSE);
}
98 changes: 98 additions & 0 deletions src/paramset.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#include <R.h>
#include <Rinternals.h>



// typedef for checkmate::qassert signature
typedef void (*fun_t)(SEXP x, const char *rule, const char *name);

SEXP c_paramset_ids(SEXP s_paramtbl, SEXP s_tagtbl, SEXP s_classes, SEXP s_tags, SEXP s_anytags) {
fun_t qassert = (fun_t) R_GetCCallable("checkmate", "qassert");

if (!isNull(s_classes)) qassert(s_classes, "S", "class");
if (!isNull(s_tags)) qassert(s_tags, "S", "tags");
if (!isNull(s_anytags)) qassert(s_anytags, "S", "any_tags");

int s_classes_n = LENGTH(s_classes);
int s_tags_n = LENGTH(s_tags);
int s_anytags_n = LENGTH(s_anytags);
/* Rprintf("s_classes_n=%i, s_tags=%i, s_anytags_n=%i\n", s_classes_n, s_tags_n, s_anytags_n); */

int paramtbl_nrows = LENGTH(VECTOR_ELT(s_paramtbl, 0));
int tagtbl_nrows = LENGTH(VECTOR_ELT(s_tagtbl, 0));
// FIXME: i am not sure if we want to index cols by nr here...
SEXP s_paramtbl_ids = VECTOR_ELT(s_paramtbl, 0);

SEXP s_paramtbl_classes = VECTOR_ELT(s_paramtbl, 1);
SEXP s_tagtbl_ids = VECTOR_ELT(s_tagtbl, 0);
SEXP s_tagtbl_tags = VECTOR_ELT(s_tagtbl, 1);

// result; potentially too large. has as many els as we have params
SEXP s_ids = PROTECT(allocVector(STRSXP, paramtbl_nrows));
int s_ids_count = 0;

// iter thru all rows in paramtbl and check that for each param all conditions hold
for (int i = 0; i < paramtbl_nrows; i++) {
/* Rprintf("i=%i, id=%s, class=%s\n", i, id, class); */
// check that params's class is in "s_classes"
// if s_classes is NULL or empty, we dont need to check
if (s_classes_n > 0) {
SEXP class = STRING_ELT(s_paramtbl_classes, i);
int ok_classes = 0;
for (int j = 0; j < s_classes_n; j++) {
if (class == STRING_ELT(s_classes, j)) {
/* Rprintf("class ok, j = %i\n", j); */
ok_classes = 1;
}
}
if (!ok_classes) continue;
}
SEXP id = STRING_ELT(s_paramtbl_ids, i);

// check that param has all tags that are in "s_tags"
// if s_tags is NULL or empty, we dont need to check
if (s_tags_n > 0) {
int ok_tags = 0;
for (int j = 0; j < s_tags_n; j++) {
SEXP tag = STRING_ELT(s_tags, j);
// FIXME: this search is super slow, we iterate the tbl again and again
for (int k = 0; k < tagtbl_nrows; k++) {
SEXP tagtbl_id = STRING_ELT(s_tagtbl_ids, k);
SEXP tagtbl_tag = STRING_ELT(s_tagtbl_tags, k);
if (id == tagtbl_id && tag == tagtbl_tag)
ok_tags++;
}
}
/* Rprintf("ok_tags=%i\n", ok_tags); */
// we didnt find all tags, so we skip current param
if (ok_tags < s_tags_n) continue;
}

// check that param has at least one tag from "s_anytags"
// if s_anytags is NULL or empty, we dont need to check
if (s_anytags_n > 0) {
int ok_anytags = 0;
for (int j = 0; j < s_anytags_n; j++) {
SEXP anytag = STRING_ELT(s_anytags, j);
for (int k = 0; k < tagtbl_nrows; k++) {
SEXP tagtbl_id = STRING_ELT(s_tagtbl_ids, k);
SEXP tagtbl_tag = STRING_ELT(s_tagtbl_tags, k);
if (id == tagtbl_id && anytag == tagtbl_tag)
ok_anytags = 1;
}
}
if (!ok_anytags) continue;
}

// if we ended up here, we add param to result
SET_STRING_ELT(s_ids, s_ids_count++, mkChar(CHAR(id)));
}

// copy result to shorter charvec of correct size
SEXP s_ids_2 = PROTECT(allocVector(STRSXP, s_ids_count));
for (int i = 0; i < s_ids_count; i++)
SET_STRING_ELT(s_ids_2, i, STRING_ELT(s_ids, i));
UNPROTECT(2);
return s_ids_2;
}