Skip to content

Commit 98b413b

Browse files
committed
update progress bar, documentation
1 parent 2332be2 commit 98b413b

15 files changed

+76
-71
lines changed

NEWS.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# modelStudio 2.1.0
2-
* **DEFAULTS CHANGES**: if `new_observation = NULL` then choose `new_observation_n = 3` observations, evenly spread by the histogram bins of `y_hat`. This shall always include the observations, which ids are `which.min(y_hat)` and `which.max(y_hat)`. Additionally, improve the observation dropdown text in dashboard. [(#94)](https://github.com/ModelOriented/modelStudio/issues/94)
3-
* This version requires `DALEX v2.0.1`
2+
* **DEFAULTS CHANGES**: if `new_observation = NULL` then choose `new_observation_n = 3` observations, evenly spread by the order of `y_hat`. This shall always include the observations whose ids are `which.min(y_hat)` and `which.max(y_hat)`. Additionally, improve the observation dropdown text in the dashboard. [(#94)](https://github.com/ModelOriented/modelStudio/issues/94)
3+
* updated the progress printing
4+
* this version requires `DALEX v2.0.1`
45
* added new options to `ms_options`: `ms_subtitle`, `ms_margin_top` and `ms_margin_bottom`
56
* added new parameters to `modelStudio()`: `N_fi = 10*N` and `B_fi = B`
67
* added new `license` parameter to `modelStudio()` which allows specifying the connection for `readLines()` (e.g. `'LICENSE'`); the file contents are added to the HTML output as a comment

R/modelStudio.R

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,13 @@ modelStudio.explainer <- function(explainer,
203203
model_type <- explainer$model_info$type
204204

205205
if (!is.null(max_vars)) max_features <- max_vars
206-
if (identical(N_fi, numeric(0))) N_fi <- NULL
206+
if (is.null(N)) stop("`N` argument must be an integer")
207+
#if (identical(N_fi, numeric(0))) N_fi <- NULL
207208

208209
if (is.null(new_observation)) {
209-
if (show_info) message("`new_observation` argument is NULL.\n",
210-
"`new_observation_n` observations needed to calculate local explanations are taken from the data.\n")
210+
if (show_info) message(paste0("`new_observation` argument is NULL. ",
211+
"`new_observation_n` observations needed to ",
212+
"calculate local explanations are taken from the data.\n"))
211213
ret <- sample_new_observation(explainer, new_observation_n)
212214
new_observation <- ret[['no']]
213215
new_observation_y <- ret[['no_y']]
@@ -263,9 +265,11 @@ modelStudio.explainer <- function(explainer,
263265

264266
## later update progress bar after all explanation functions
265267
if (show_info) {
268+
increment <- ifelse(eda, 1, 0)
269+
266270
pb <- progress_bar$new(
267271
format = " Calculating :what \n Elapsed time: :elapsedfull ETA::eta ", # :percent [:bar]
268-
total = (3*B + 2 + 1)*obs_count + (2*B_fi + 3*B_fi + B_fi) + 1,
272+
total = 1 + increment + (3*B + 2 + 1)*obs_count + (2*B_fi + N/30 + N/10) + 2,
269273
show_after = 0,
270274
width = 110
271275
)
@@ -288,59 +292,65 @@ modelStudio.explainer <- function(explainer,
288292
ingredients::partial_dependence(
289293
model, data, predict_function, variable_type = "numerical", N = N,
290294
variable_splits_type=variable_splits_type),
291-
"ingredients::partial_dependence (numerical)", show_info, pb, B)
295+
"ingredients::partial_dependence (numerical)", show_info, pb, N/30)
292296
pd_c <- NULL
293297
ad_n <- calculate(
294298
ingredients::accumulated_dependence(
295299
model, data, predict_function, variable_type = "numerical", N = N,
296300
variable_splits_type=variable_splits_type),
297-
"ingredients::accumulated_dependence (numerical)", show_info, pb, 3*B)
301+
"ingredients::accumulated_dependence (numerical)", show_info, pb, N/10)
298302
ad_c <- NULL
299303
} else if (all(!which_numerical)) {
300304
pd_n <- NULL
301305
pd_c <- calculate(
302306
ingredients::partial_dependence(
303307
model, data, predict_function, variable_type = "categorical", N = N,
304308
variable_splits_type=variable_splits_type),
305-
"ingredients::partial_dependence (categorical)", show_info, pb, B)
309+
"ingredients::partial_dependence (categorical)", show_info, pb, N/30)
306310
ad_n <- NULL
307311
ad_c <- calculate(
308312
ingredients::accumulated_dependence(
309313
model, data, predict_function, variable_type = "categorical", N = N,
310314
variable_splits_type=variable_splits_type),
311-
"ingredients::accumulated_dependence (categorical)", show_info, pb, 3*B)
315+
"ingredients::accumulated_dependence (categorical)", show_info, pb, N/10)
312316
} else {
313317
pd_n <- calculate(
314318
ingredients::partial_dependence(
315319
model, data, predict_function, variable_type = "numerical", N = N,
316320
variable_splits_type=variable_splits_type),
317-
"ingredients::partial_dependence (numerical)", show_info, pb, B/2)
321+
"ingredients::partial_dependence (numerical)", show_info, pb, N/60)
318322
pd_c <- calculate(
319323
ingredients::partial_dependence(
320324
model, data, predict_function, variable_type = "categorical", N = N,
321325
variable_splits_type=variable_splits_type),
322-
"ingredients::partial_dependence (categorical)", show_info, pb, B/2)
326+
"ingredients::partial_dependence (categorical)", show_info, pb, N/60)
323327
ad_n <- calculate(
324328
ingredients::accumulated_dependence(
325329
model, data, predict_function, variable_type = "numerical", N = N,
326330
variable_splits_type=variable_splits_type),
327-
"ingredients::accumulated_dependence (numerical)", show_info, pb, 2*B)
331+
"ingredients::accumulated_dependence (numerical)", show_info, pb, 2*N/30)
328332
ad_c <- calculate(
329333
ingredients::accumulated_dependence(
330334
model, data, predict_function, variable_type = "categorical", N = N,
331335
variable_splits_type=variable_splits_type),
332-
"ingredients::accumulated_dependence (categorical)", show_info, pb, B)
336+
"ingredients::accumulated_dependence (categorical)", show_info, pb, N/30)
333337
}
334338

335339
fi_data <- prepare_feature_importance(fi, max_features, options$show_boxplot,
336340
attr(loss_function, "loss_name"), ...)
337341
pd_data <- prepare_partial_dependence(pd_n, pd_c, variables = variable_names)
338342
ad_data <- prepare_accumulated_dependence(ad_n, ad_c, variables = variable_names)
339-
mp_data <- DALEX::model_performance(explainer)$measures
343+
mp_ret <- calculate(
344+
DALEX::model_performance(explainer),
345+
"DALEX::model_performance", show_info, pb, 1)
346+
mp_data <- mp_ret$measures
340347

341348
if (eda) {
342349
#:# fd_data is used by targetVs and residualsVs plots
343-
residuals <- DALEX::model_diagnostics(explainer)$residuals
350+
md_ret <- calculate(
351+
DALEX::model_diagnostics(explainer),
352+
"DALEX::model_diagnostics", show_info, pb, 1)
353+
residuals <- md_ret$residuals
344354
fd_data <- prepare_feature_distribution(data, y, variables = variable_names,
345355
residuals = residuals)
346356
at_data <- prepare_average_target(data, y, variables = variable_names)
@@ -514,6 +524,8 @@ modelStudio.explainer <- function(explainer,
514524

515525
class(model_studio) <- c(class(model_studio), "modelStudio")
516526

527+
if (show_info) pb$tick(1, tokens = list(what = "..."))
528+
517529
model_studio
518530
}
519531

@@ -650,7 +662,7 @@ sample_new_observation <- function(explainer, new_observation_n = 3) {
650662
new_observation_n <- dim(explainer$data)[1]
651663
}
652664

653-
ids <- unique(quantile(seq_along(y_hat), seq(0, 1, length.out = new_observation_n), type = 4))
665+
ids <- unique(round(seq(1, length(y_hat), length.out = new_observation_n)))
654666
new_observation_ids <- order(y_hat)[ids]
655667

656668
list(no = explainer$data[new_observation_ids,], no_y = explainer$y[new_observation_ids])

README.md

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ install.packages("DALEXtra")
7373

7474
Make a studio for the regression `ranger` model on `apartments` data.
7575

76-
<details>
76+
<details open>
7777
<summary><strong><em>code</em></strong></summary>
7878

7979
```r
@@ -105,8 +105,7 @@ new_observation <- test[1:2,]
105105
rownames(new_observation) <- c("id1", "id2")
106106

107107
# make a studio for the model
108-
modelStudio(explainer,
109-
new_observation)
108+
modelStudio(explainer, new_observation)
110109
```
111110

112111
</details>
@@ -150,8 +149,7 @@ new_observation <- test_matrix[1:2, , drop=FALSE]
150149
rownames(new_observation) <- c("id1", "id2")
151150

152151
# make a studio for the model
153-
modelStudio(explainer,
154-
new_observation,
152+
modelStudio(explainer, new_observation,
155153
options = ms_options(margin_left = 140))
156154
```
157155

@@ -319,7 +317,7 @@ or with [`r2d3::save_d3_html()`](https://rstudio.github.io/r2d3/articles/publish
319317

320318
If you use `modelStudio`, please cite our [JOSS article](https://joss.theoj.org/papers/10.21105/joss.01798):
321319

322-
```python
320+
```
323321
@Article{modelStudio,
324322
author = {Hubert Baniecki and Przemyslaw Biecek},
325323
title = {{modelStudio}: Interactive Studio with Explanations for {ML} Predictive Models},

pkgdown/favicon/caret.html

Lines changed: 3 additions & 3 deletions
Large diffs are not rendered by default.

pkgdown/favicon/demo.html

Lines changed: 3 additions & 3 deletions
Large diffs are not rendered by default.

pkgdown/favicon/h2o.html

Lines changed: 3 additions & 3 deletions
Large diffs are not rendered by default.

pkgdown/favicon/lightgbm.html

Lines changed: 3 additions & 3 deletions
Large diffs are not rendered by default.

pkgdown/favicon/mlr.html

Lines changed: 3 additions & 3 deletions
Large diffs are not rendered by default.

pkgdown/favicon/mlr3.html

Lines changed: 3 additions & 3 deletions
Large diffs are not rendered by default.

pkgdown/favicon/parsnip.html

Lines changed: 3 additions & 3 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)