Skip to content

Commit 85a19a5

Browse files
AnnaNzrvsebffischerbe-marc
authored
feat: classif stepPlr learner from old mlr2 (#448)
Thank you for contributing a learner to the mlr3 ecosystem. Please make sure that: - [x] The added learner(s) are sufficiently tested - [x] All the CI tests are passing (including the CodeFactor) - [x] You ran `devtools::document()` - [x] You updated the `NEWS.md` field to include the addition of the learner - [x] You did not modify anything **not** related to the new learner - [x] You are listed as a contributor in the `DESCRIPTION` of the R package --------- Co-authored-by: Sebastian Fischer <[email protected]> Co-authored-by: be-marc <[email protected]>
1 parent 78f4750 commit 85a19a5

File tree

8 files changed

+276
-1
lines changed

8 files changed

+276
-1
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ Suggests:
125125
sm,
126126
sparsediscrim,
127127
stats,
128+
stepPlr,
128129
survival,
129130
survivalmodels (>= 0.1.19),
130131
survivalsvm (>= 0.0.5),

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ export(LearnerClassifSMO)
7070
export(LearnerClassifSda)
7171
export(LearnerClassifSdlda)
7272
export(LearnerClassifSimpleLogistic)
73+
export(LearnerClassifStepPlr)
7374
export(LearnerClassifTabPFN)
7475
export(LearnerClassifVotedPerceptron)
7576
export(LearnerCompRisksRandomForestSRC)

NEWS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
- `LearnerClassifAdaBoosting`
1313
- `Learner{Classif,Regr}Evtree`
1414
- `LearnerClassifKnn`
15+
- `LearnerClassifStepPlr`
1516
- `LearnerClassifMda`
1617
- `LearnerClassifRferns`
1718
- `LearnerClassifNeuralnet`
@@ -49,7 +50,6 @@
4950
This means the auto tests will be stochastic like they should be
5051
* The CI now checks that RCMD-check passes when suggested packages are not available.
5152

52-
5353
# mlr3extralearners 1.1.0
5454

5555
New Features:

R/bibentries.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,19 @@ bibentries = c( # nolint start
812812
booktitle = "International Conference on Learning Representations 2023",
813813
year = "2023"
814814
),
815+
park2008plr = bibentry("article",
816+
title = "Penalized logistic regression for detecting gene interactions",
817+
author = "Park, Mee Young and Hastie, Trevor",
818+
journal = "Biostatistics",
819+
volume = "9",
820+
number = "1",
821+
pages = "30-50",
822+
year = "2007",
823+
month = "04",
824+
issn = "1465-4644",
825+
doi = "10.1093/biostatistics/kxm010",
826+
url = "https://doi.org/10.1093/biostatistics/kxm010"
827+
),
815828
knn2002 = bibentry("book",
816829
title = "Modern Applied Statistics with S",
817830
author = "W. N. Venables and B. D. Ripley",

R/learner_stepPlr_classif_plr.R

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#' @title Classification Logistic Regression Learner
2+
#' @author annanzrv
3+
#' @name mlr_learners_classif.stepPlr
4+
#'
5+
#' @description
6+
#' Logistic regression with a quadratic penalization on the coefficient.
7+
#' Calls [stepPlr::plr()] from \CRANpkg{stepPlr}.
8+
#'
9+
#' @templateVar id classif.stepPlr
10+
#' @template learner
11+
#'
12+
#' @references
13+
#' `r format_bib("park2008plr")`
14+
#'
15+
#' @template seealso_learner
16+
#' @template example
17+
#' @export
18+
LearnerClassifStepPlr = R6Class("LearnerClassifStepPlr",
19+
inherit = LearnerClassif,
20+
public = list(
21+
#' @description
22+
#' Creates a new instance of this [R6][R6::R6Class] class.
23+
initialize = function() {
24+
param_set = ps(
25+
cp = p_fct(default = "aic", levels = c("aic", "bic"), tags = "train"),
26+
lambda = p_dbl(default = 1e-4, lower = 0, tags = "train"),
27+
offset.coefficients = p_uty(tags = "train"),
28+
offset.subset = p_uty(tags = "train")
29+
)
30+
31+
super$initialize(
32+
id = "classif.stepPlr",
33+
packages = "stepPlr",
34+
feature_types = c("logical", "integer", "numeric"),
35+
predict_types = c("response", "prob"),
36+
param_set = param_set,
37+
properties = c("twoclass", "weights"),
38+
man = "mlr3extralearners::mlr_learners_classif.stepPlr",
39+
label = "Logistic Regression with a L2 Penalty"
40+
)
41+
}
42+
),
43+
private = list(
44+
.train = function(task) {
45+
pars = self$param_set$get_values(tags = "train")
46+
data = as.matrix(task$data(cols = task$feature_names))
47+
y = as.numeric(task$data()[[task$target_names]]) - 1
48+
pars$weights = private$.get_weights(task)
49+
invoke(
50+
stepPlr::plr,
51+
x = data,
52+
y = y,
53+
.args = pars
54+
)
55+
},
56+
.predict = function(task) {
57+
pars = self$param_set$get_values(tags = "predict")
58+
newdata = ordered_features(task, self)
59+
# Remove target column if present in newdata
60+
if (
61+
length(task$target_names) > 0 && task$target_names %in% colnames(newdata)
62+
) {
63+
newx = as.matrix(newdata[, !task$target_names, with = FALSE])
64+
} else {
65+
newx = as.matrix(newdata)
66+
}
67+
68+
type = if (self$predict_type == "prob") "response" else "class"
69+
pred = invoke(predict, self$model, newx = newx, type = type, .args = pars)
70+
71+
if (type == "class") {
72+
levels = task$class_names
73+
response = factor(pred, levels = seq_along(levels) - 1, labels = levels)
74+
list(response = response)
75+
} else {
76+
prob = pprob_to_matrix(1 - unname(pred), task)
77+
list(prob = prob)
78+
}
79+
}
80+
)
81+
)
82+
83+
.extralrns_dict$add("classif.stepPlr", LearnerClassifStepPlr)

man/mlr_learners_classif.stepPlr.Rd

Lines changed: 138 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
skip_if_not_installed("stepPlr")
2+
3+
test_that("classif.plr train", {
4+
learner = lrn("classif.stepPlr")
5+
fun = stepPlr::plr
6+
exclude = c(
7+
"x", # handled internally
8+
"y", # handled internally
9+
"weights" # set internally
10+
)
11+
12+
# note that you can also pass a list of functions in case $.train calls more than one
13+
# function, e.g. for control arguments
14+
paramtest = run_paramtest(learner, fun, exclude, tag = "train")
15+
expect_paramtest(paramtest)
16+
})
17+
18+
test_that("classif.plr predict", {
19+
learner = lrn("classif.stepPlr")
20+
fun = stepPlr::predict.plr # nolint
21+
exclude = c(
22+
"object", # handled internally
23+
"data", # handled internally
24+
"newx", # handled internally
25+
"type" # set internally
26+
)
27+
28+
paramtest = run_paramtest(learner, fun, exclude, tag = "predict")
29+
expect_paramtest(paramtest)
30+
})
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
skip_if_not_installed("stepPlr")
2+
3+
test_that("autotest", {
4+
learner = lrn("classif.stepPlr")
5+
expect_learner(learner)
6+
# note that you can skip tests using the exclude argument
7+
result = run_autotest(learner)
8+
expect_true(result, info = result$error)
9+
})

0 commit comments

Comments
 (0)