Skip to content

Commit df48e60

Browse files
AnnaNzrvsebffischerbe-marc
authored
feat: adabag learner from old mlr2 (#440)
Thank you for contributing a learner to the mlr3 ecosystem. Please make sure that: - [x] The added learner(s) are sufficiently tested - [x] All the CI tests are passing (including the CodeFactor) - [x] You ran `devtools::document()` - [x] You updated the `NEWS.md` field to include the addition of the learner - [x] You did not modify anything **not** related to the new learner - [x] You are listed as a contributor in the `DESCRIPTION` of the R package --------- Co-authored-by: Sebastian Fischer <[email protected]> Co-authored-by: be-marc <[email protected]>
1 parent 85a19a5 commit df48e60

8 files changed

+360
-1
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ Imports:
4747
Suggests:
4848
abess,
4949
actuar,
50+
adabag,
5051
ada,
5152
aorsf (>= 0.1.5),
5253
apcluster,

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ S3method(unmarshal_model,tabpfn_model_marshaled)
1010
S3method(unmarshal_model,xgboost_cox_model_marshaled)
1111
export(LearnerClassifAbess)
1212
export(LearnerClassifAdaBoostM1)
13+
export(LearnerClassifAdabag)
1314
export(LearnerClassifAdaBoosting)
1415
export(LearnerClassifBart)
1516
export(LearnerClassifBayesNet)

NEWS.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
## New Features
44

55
* New Learners:
6-
6+
- `LearnerCompRisksRandomForestSRC`
7+
- `LearnerSurvBlockForest`
78
- `Learner{Classif,Regr,Surv}BlockForest`
89
- `Learner{Classif,Regr}ExhaustiveSearch`
910
- `LearnerClassifFastai`
1011
- `Learner{Classif,Regr}Penalized`
1112
- `Learner{Classif,Regr}Bst`
13+
- `LearnerClassifAdabag`
1214
- `LearnerClassifAdaBoosting`
1315
- `Learner{Classif,Regr}Evtree`
1416
- `LearnerClassifKnn`
@@ -21,11 +23,16 @@
2123
* Add new `control_custom_fun` parameter in `surv.aorsf`
2224
* New function `learner_is_runnable()` to check whether the
2325
required packages to train a learner are available.
26+
* Added `selected_features` property to RandomForestSRC learners (prediction doesn't work if `vars.used = 'all.trees'`)
2427

2528
## Bug fixes
2629

2730
* Tests are now skipped when a suggested package is not available.
2831
This will make local development much more convenient.
32+
* Removed parameters from RandomForestSRC learners that weren't used + optimized tests
33+
* Removed `discrete` parameter from `surv.parametric`, so that it is impossible to return `distr6::VectorDistribution` survival predictions (softly deprecated in `[email protected]`)
34+
35+
2936

3037
## Breaking Changes
3138

R/bibentries.R

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,17 @@ bibentries = c( # nolint start
812812
booktitle = "International Conference on Learning Representations 2023",
813813
year = "2023"
814814
),
815+
adabag2013 = bibentry("article",
816+
title = "adabag: An R Package for Classification with Boosting and Bagging",
817+
volume = "54",
818+
url = "https://www.jstatsoft.org/index.php/jss/article/view/v054i02",
819+
doi = "10.18637/jss.v054.i02",
820+
number = "2",
821+
journal = "Journal of Statistical Software",
822+
author = "Alfaro, Esteban and Gamez, Matias and Garc\xc3\xada, Noelia",
823+
year = "2013",
824+
pages = "1-35"
825+
),
815826
park2008plr = bibentry("article",
816827
title = "Penalized logistic regression for detecting gene interactions",
817828
author = "Park, Mee Young and Hastie, Trevor",

R/learner_adabag_classif_adabag.R

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#' @title Classification Boosting Learner
#' @author annanzrv
#' @name mlr_learners_classif.adabag
#'
#' @description
#' Classification boosting algorithm.
#' Calls [adabag::boosting()] from \CRANpkg{adabag}.
#'
#' @section Initial parameter values:
#' - `xval`:
#'   * Actual default: 10L
#'   * Initial value: 0L
#'   * Reason for change: Set to 0 for speed.
#'
#' @references
#' `r format_bib("adabag2013")`
#'
#' @templateVar id classif.adabag
#' @template learner
#'
#' @template seealso_learner
#' @template example
#' @export
LearnerClassifAdabag = R6Class("LearnerClassifAdabag",
  inherit = LearnerClassif,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      param_set = ps(
        boos = p_lgl(default = TRUE, tags = "train"),
        coeflearn = p_fct(default = "Breiman", levels = c("Breiman", "Freund", "Zhu"), tags = "train"),
        cp = p_dbl(default = 0.01, lower = 0, upper = 1, tags = "train"),
        maxcompete = p_int(default = 4L, lower = 0L, tags = "train"),
        maxdepth = p_int(default = 30L, lower = 1L, upper = 30L, tags = "train"),
        maxsurrogate = p_int(default = 5L, lower = 0L, tags = "train"),
        mfinal = p_int(default = 100L, lower = 1L, tags = "train"),
        minbucket = p_int(lower = 1L, tags = "train"),
        minsplit = p_int(default = 20L, lower = 1L, tags = "train"),
        newmfinal = p_int(tags = "predict"),
        surrogatestyle = p_int(default = 0L, lower = 0L, upper = 1L, tags = "train"),
        usesurrogate = p_int(default = 2L, lower = 0L, upper = 2L, tags = "train"),
        # `default` records the upstream default (10, see "Initial parameter
        # values" above); the value this learner actually starts with is set
        # separately below.
        xval = p_int(default = 10L, lower = 0L, tags = "train")
      )
      # Initial value deviates from the upstream default: set to 0 for speed.
      param_set$values = list(xval = 0L)

      super$initialize(
        id = "classif.adabag",
        packages = c("adabag", "rpart"),
        feature_types = c("integer", "numeric", "factor"),
        predict_types = c("response", "prob"),
        param_set = param_set,
        properties = c("importance", "missings", "multiclass", "twoclass"),
        man = "mlr3extralearners::mlr_learners_classif.adabag",
        label = "Adabag Boosting"
      )
    },

    #' @description
    #' The importance scores are extracted from the model.
    #' @return Named `numeric()`.
    importance = function() {
      if (is.null(self$model)) {
        stopf("No model stored")
      }
      sort(self$model$importance, decreasing = TRUE)
    }
  ),

  private = list(
    # Fits the boosting model on `task`.
    # Train-tagged hyperparameters that are arguments of
    # `rpart::rpart.control()` are bundled into a control object and passed as
    # `control`; all remaining ones go to `adabag::boosting()` directly.
    # Returns the object produced by `adabag::boosting()`.
    .train = function(task) {
      pars = self$param_set$get_values(tags = "train")

      # Split the hyperparameters by whether rpart.control() accepts them.
      args_ctrl = formalArgs(rpart::rpart.control)
      is_ctrl = names(pars) %in% args_ctrl
      ctrl = invoke(rpart::rpart.control, .args = pars[is_ctrl])

      pars = pars[!is_ctrl]
      pars$control = ctrl

      invoke(adabag::boosting,
        formula = task$formula(),
        data = task$data(),
        .args = pars
      )
    },

    # Predicts on `task` with the stored model.
    # Returns `list(prob = <matrix>)` when the predict type is "prob",
    # otherwise `list(response = <factor>)`.
    .predict = function(task) {
      pars = self$param_set$get_values(tags = "predict")

      # Features in the same order as during training.
      newdata = ordered_features(task, self)

      # The prediction data gets a placeholder response column.
      # NOTE(review): whether the installed adabag version still needs this
      # column should be confirmed.
      # The column is named after the task's target (the response of the
      # training formula) rather than a hard-coded "target", which would
      # silently overwrite a feature that happens to carry that name.
      # All entries are NA; only the factor levels carry information.
      newdata[, task$target_names] = factor(
        rep(NA_character_, nrow(newdata)),
        levels = task$class_names
      )

      pred = invoke(predict, self$model, newdata = newdata, .args = pars)

      if (self$predict_type == "prob") {
        # Label the probability columns with the task's class names.
        # NOTE(review): assumes adabag returns the columns in the order of
        # task$class_names - confirm.
        prob = mlr3misc::set_col_names(pred$prob, task$class_names)
        list(prob = prob)
      } else {
        # Response factor carrying the task's class levels.
        response = factor(pred$class, levels = task$class_names)
        list(response = response)
      }
    }
  )
)
128+
129+
# Add the learner class to `.extralrns_dict` under the key "classif.adabag".
.extralrns_dict$add("classif.adabag", LearnerClassifAdabag)

man/mlr_learners_classif.adabag.Rd

Lines changed: 173 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# The whole file is skipped when adabag is not installed.
skip_if_not_installed("adabag")

test_that("autotest", {
  adabag_learner = lrn("classif.adabag")
  expect_learner(adabag_learner)

  # Individual autotest cases can be left out via the `exclude` argument.
  autotest_result = run_autotest(adabag_learner, exclude = "utf8_feature_names")
  expect_true(autotest_result, info = autotest_result$error)
})

0 commit comments

Comments
 (0)