
Commit 14c430b

Merge pull request #102 from ModelOriented/code_optim
Code optimization
2 parents (4851434 + be3c755), commit 14c430b

8 files changed: +181, -113 lines


NEWS.md

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 ## Maintenance
 
 - Added explanation of sampling Kernel SHAP to help file.
-- Internal code optimizations.
+- In internal calculations, use explicit `feature_names` as dimnames (https://github.com/ModelOriented/kernelshap/issues/96)
 
 # kernelshap 0.3.7

R/exact.R

Lines changed: 46 additions & 25 deletions

@@ -1,40 +1,59 @@
-# Functions required only for handling exact cases
+# Functions required only for handling (partly) exact cases
 
 # Provides fixed input for the exact case:
 # - Z: Matrix with all 2^p-2 on-off vectors z
 # - w: Vector with row weights of Z ensuring that the distribution of sum(z) matches
 #      the SHAP kernel distribution
 # - A: Exact matrix A = Z'wZ
-input_exact <- function(p) {
-  Z <- exact_Z(p)
+input_exact <- function(p, feature_names) {
+  Z <- exact_Z(p, feature_names = feature_names)
   # Each Kernel weight(j) is divided by the number of vectors z having sum(z) = j
   w <- kernel_weights(p) / choose(p, 1:(p - 1L))
-  list(Z = Z, w = w[rowSums(Z)], A = exact_A(p))
+  list(Z = Z, w = w[rowSums(Z)], A = exact_A(p, feature_names = feature_names))
 }
 
-# Calculates exact A. Notice the difference to the off-diagnonals in the Supplement of
-# Covert and Lee (2021). Credits to David Watson for figuring out the correct formula,
-# see our discussions in https://github.com/ModelOriented/kernelshap/issues/22
-exact_A <- function(p) {
+#' Exact Matrix A
+#'
+#' Internal function that calculates exact A.
+#' Notice the difference to the off-diagnonals in the Supplement of
+#' Covert and Lee (2021). Credits to David Watson for figuring out the correct formula,
+#' see our discussions in https://github.com/ModelOriented/kernelshap/issues/22
+#'
+#' @noRd
+#' @keywords internal
+#'
+#' @param p Number of features.
+#' @param feature_names Feature names.
+#' @returns A (p x p) matrix.
+exact_A <- function(p, feature_names) {
   S <- 1:(p - 1L)
   c_pr <- S * (S - 1) / p / (p - 1)
   off_diag <- sum(kernel_weights(p) * c_pr)
-  A <- matrix(off_diag, nrow = p, ncol = p)
+  A <- matrix(
+    off_diag, nrow = p, ncol = p, dimnames = list(feature_names, feature_names)
+  )
   diag(A) <- 0.5
   A
 }
 
-# Creates (2^p-2) x p matrix with all on-off vectors z of length p
-# Instead of calculating this object, we could evaluate it for different p <= p_max
-# and store it as a list in the package.
-exact_Z <- function(p) {
+#' All on-off Vectors
+#'
+#' Internal function that creates matrix of all on-off vectors of length `p`.
+#'
+#' @noRd
+#' @keywords internal
+#'
+#' @param p Number of features.
+#' @param feature_names Feature names.
+#' @returns An integer ((2^p - 2) x p) matrix of all on-off vectors of length `p`.
+exact_Z <- function(p, feature_names) {
   Z <- as.matrix(do.call(expand.grid, replicate(p, 0:1, simplify = FALSE)))
-  dimnames(Z) <- NULL
+  colnames(Z) <- feature_names
   Z[2:(nrow(Z) - 1L), , drop = FALSE]
 }
 
 # List all length p vectors z with sum(z) in {k, p - k}
-partly_exact_Z <- function(p, k) {
+partly_exact_Z <- function(p, k, feature_names) {
   if (k < 1L) {
     stop("k must be at least 1")
   }
@@ -48,17 +67,18 @@ partly_exact_Z <- function(p, k) {
      utils::combn(seq_len(p), k, FUN = function(z) {x <- numeric(p); x[z] <- 1; x})
    )
  }
-  if (p == 2L * k) {
-    return(Z)
+  if (p != 2L * k) {
+    Z <- rbind(Z, 1 - Z)
  }
-  return(rbind(Z, 1 - Z))
+  colnames(Z) <- feature_names
+  Z
 }
 
 # Create Z, w, A for vectors z with sum(z) in {k, p-k} for k in {1, ..., deg}.
 # The total weights do not sum to one, except in the special (exact) case deg=p-deg.
 # (The remaining weight will be added via input_sampling(p, deg=deg)).
 # Note that for a given k, the weights are constant.
-input_partly_exact <- function(p, deg) {
+input_partly_exact <- function(p, deg, feature_names) {
   if (deg < 1L) {
     stop("deg must be at least 1")
   }
@@ -70,7 +90,7 @@ input_partly_exact <- function(p, deg) {
   Z <- w <- vector("list", deg)
 
   for (k in seq_len(deg)) {
-    Z[[k]] <- partly_exact_Z(p, k = k)
+    Z[[k]] <- partly_exact_Z(p, k = k, feature_names = feature_names)
     n <- nrow(Z[[k]])
     w_tot <- kw[k] * (2 - (p == 2L * k))
     w[[k]] <- rep(w_tot / n, n)
@@ -82,20 +102,21 @@ input_partly_exact <- function(p, deg) {
 }
 
 # Case p = 1 returns exact Shapley values
-case_p1 <- function(n, nms, v0, v1, X, verbose) {
+case_p1 <- function(n, feature_names, v0, v1, X, verbose) {
   txt <- "Exact Shapley values (p = 1)"
   if (verbose) {
     message(txt)
   }
-  S <- v1 - v0[rep(1L, n), , drop = FALSE]
-  SE <- matrix(numeric(n), dimnames = list(NULL, nms))
+  S <- v1 - v0[rep(1L, n), , drop = FALSE]  # (n x K)
+  SE <- matrix(numeric(n), dimnames = list(NULL, feature_names))  # (n x 1)
   if (ncol(v1) > 1L) {
     SE <- replicate(ncol(v1), SE, simplify = FALSE)
     S <- lapply(
-      asplit(S, MARGIN = 2L), function(M) as.matrix(M, dimnames = list(NULL, nms))
+      asplit(S, MARGIN = 2L), function(M)
+        as.matrix(M, dimnames = list(NULL, feature_names))
     )
   } else {
-    colnames(S) <- nms
+    colnames(S) <- feature_names
   }
   out <- list(
     S = S,
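For orientation, here is a minimal standalone sketch of what the refactored `exact_Z()` helper above now produces. The function itself is internal and not exported, so `exact_Z_sketch()` is a hypothetical stand-in that mirrors the logic shown in the diff; `p = 3` and the feature names are made up for illustration.

# Sketch of exact_Z(): all 2^p - 2 on-off vectors z of length p,
# now with feature names attached as column names.
exact_Z_sketch <- function(p, feature_names) {
  Z <- as.matrix(do.call(expand.grid, replicate(p, 0:1, simplify = FALSE)))
  colnames(Z) <- feature_names
  # Drop the all-zero and all-one rows
  Z[2:(nrow(Z) - 1L), , drop = FALSE]
}

exact_Z_sketch(3, c("x1", "x2", "x3"))
# A (6 x 3) matrix with columns x1, x2, x3; the rows are
# 100, 010, 110, 001, 101, 011.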

R/kernelshap.R

Lines changed: 10 additions & 4 deletions

@@ -226,7 +226,9 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
   # For p = 1, exact Shapley values are returned
   if (p == 1L) {
     return(
-      case_p1(n = n, nms = feature_names, v0 = v0, v1 = v1, X = X, verbose = verbose)
+      case_p1(
+        n = n, feature_names = feature_names, v0 = v0, v1 = v1, X = X, verbose = verbose
+      )
     )
   }
 
@@ -238,7 +240,11 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
 
   # Precalculations for the real Kernel SHAP
   if (exact || hybrid_degree >= 1L) {
-    precalc <- if (exact) input_exact(p) else input_partly_exact(p, hybrid_degree)
+    if (exact) {
+      precalc <- input_exact(p, feature_names = feature_names)
+    } else {
+      precalc <- input_partly_exact(p, deg = hybrid_degree, feature_names = feature_names)
+    }
     m_exact <- nrow(precalc[["Z"]])
     prop_exact <- sum(precalc[["w"]])
     precalc[["bg_X_exact"]] <- bg_X[rep(seq_len(bg_n), times = m_exact), , drop = FALSE]
@@ -317,10 +323,10 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
     warning("\nNon-convergence for ", sum(!converged), " rows.")
   }
   out <- list(
-    S = reorganize_list(lapply(res, `[[`, "beta"), nms = feature_names),
+    S = reorganize_list(lapply(res, `[[`, "beta")),
    X = X,
    baseline = as.vector(v0),
-    SE = reorganize_list(lapply(res, `[[`, "sigma"), nms = feature_names),
+    SE = reorganize_list(lapply(res, `[[`, "sigma")),
    n_iter = vapply(res, `[[`, "n_iter", FUN.VALUE = integer(1L)),
    converged = converged,
    m = m,
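As a usage illustration of how `feature_names` flows through to the returned matrices, here is a minimal sketch assuming the exported `kernelshap()` interface shown in the diff header above and the built-in `iris` data; the model and data choices are made up for the example.

library(kernelshap)

fit <- lm(Sepal.Length ~ ., data = iris[, 1:4])
X <- iris[1:20, 2:4]   # rows to explain
bg <- iris[, 2:4]      # background data

ks <- kernelshap(fit, X, bg_X = bg)  # p = 3, so the exact case is used
colnames(ks$S)
# Expected: "Sepal.Width" "Petal.Length" "Petal.Width"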

R/sampling.R

Lines changed: 6 additions & 4 deletions

@@ -3,7 +3,7 @@
 # Draw m binary vectors z of length p with sum(z) distributed according
 # to Kernel SHAP weights -> (m x p) matrix.
 # The argument S can be used to restrict the range of sum(z).
-sample_Z <- function(p, m, S = 1:(p - 1L)) {
+sample_Z <- function(p, m, feature_names, S = 1:(p - 1L)) {
   # First draw s = sum(z) according to Kernel weights (renormalized to sum 1)
   probs <- kernel_weights(p, S = S)
   N <- S[sample.int(length(S), m, replace = TRUE, prob = probs)]
@@ -22,6 +22,7 @@ sample_Z <- function(p, m, S = 1:(p - 1L)) {
   dim(out) <- c(p, m)
   ord <- order(col(out), sample.int(m * p))
   out[] <- out[ord]
+  rownames(out) <- feature_names
   t(out)
 }
 
@@ -46,17 +47,18 @@ conv_crit <- function(sig, bet) {
 #
 # If deg > 0, vectors z with sum(z) restricted to [deg+1, p-deg-1] are sampled.
 # This case is used in combination with input_partly_hybrid(). Consequently, sum(w) < 1.
-input_sampling <- function(p, m, deg, paired) {
+input_sampling <- function(p, m, deg, paired, feature_names) {
   if (p < 2L * deg + 2L) {
     stop("p must be >=2*deg + 2")
   }
   S <- (deg + 1L):(p - deg - 1L)
-  Z <- sample_Z(m = if (paired) m / 2 else m, p = p, S = S)
+  Z <- sample_Z(
+    p = p, m = if (paired) m / 2 else m, feature_names = feature_names, S = S
+  )
   if (paired) {
     Z <- rbind(Z, 1 - Z)
   }
   w_total <- if (deg == 0L) 1 else 1 - 2 * sum(kernel_weights(p)[seq_len(deg)])
   w <- w_total / m
   list(Z = Z, w = rep(w, m), A = crossprod(Z) * w)
 }
-
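For intuition on the sampling branch, here is a small self-contained sketch of the idea behind `sample_Z()`. Since `kernel_weights()` is internal and not shown in this diff, the sketch substitutes the standard Kernel SHAP kernel (p - 1) / (choose(p, s) * s * (p - s)) as a stand-in, and it places the ones in each row more simply than the column-ordering trick used in the actual code; `sample_Z_sketch()` is hypothetical and not part of the package.

# Sketch of sample_Z(): draw m on-off vectors z of length p with sum(z)
# distributed according to (renormalized) Kernel SHAP weights.
sample_Z_sketch <- function(p, m, feature_names, S = 1:(p - 1L)) {
  kw <- (p - 1) / (choose(p, S) * S * (p - S))  # stand-in for kernel_weights()
  s <- S[sample.int(length(S), m, replace = TRUE, prob = kw)]  # draw sum(z) per row
  # For each row, put the drawn number of ones at random positions
  Z <- t(vapply(s, function(k) {z <- numeric(p); z[sample.int(p, k)] <- 1; z}, numeric(p)))
  colnames(Z) <- feature_names  # feature names as dimnames, as in this PR
  Z
}

set.seed(1)
sample_Z_sketch(p = 4, m = 5, feature_names = paste0("x", 1:4))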