Skip to content

Commit 61e82e4

Browse files
authored
Merge pull request #172 from ModelOriented/split-kernel-weights-into-two
Split kernel weights into two
2 parents 99275ef + 6e6a1d5 commit 61e82e4

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

R/utils_kernelshap.R

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ solver <- function(A, b, constraint) {
108108
# to Kernel SHAP weights -> (m x p) matrix.
109109
# The argument S can be used to restrict the range of sum(z).
110110
sample_Z <- function(p, m, feature_names, S = 1:(p - 1L)) {
111-
probs <- kernel_weights(p, per_coalition_size = TRUE, S = S)
111+
probs <- kernel_weights_per_coalition_size(p, S = S)
112112
N <- S[sample.int(length(S), m, replace = TRUE, prob = probs)]
113113

114114
# Then, conditional on that number, set random positions of z to 1
@@ -159,7 +159,7 @@ input_sampling <- function(p, m, deg, feature_names) {
159159
input_exact <- function(p, feature_names) {
160160
Z <- exact_Z(p, feature_names = feature_names)
161161
Z <- Z[2L:(nrow(Z) - 1L), , drop = FALSE]
162-
kw <- kernel_weights(p, per_coalition_size = FALSE) # Kernel weights for all subsets
162+
kw <- kernel_weights(p) # Kernel weights for all subsets
163163
w <- kw[rowSums(Z)] # Corresponding weight for each row in Z
164164
w <- w / sum(w)
165165
list(Z = Z, w = w, A = crossprod(Z, w * Z))
@@ -204,7 +204,7 @@ input_partly_exact <- function(p, deg, feature_names) {
204204
stop("p must be >=2*deg")
205205
}
206206

207-
kw <- kernel_weights(p, per_coalition_size = FALSE)
207+
kw <- kernel_weights(p)
208208

209209
Z <- vector("list", deg)
210210
for (k in seq_len(deg)) {
@@ -217,16 +217,17 @@ input_partly_exact <- function(p, deg, feature_names) {
217217
list(Z = Z, w = w, A = crossprod(Z, w * Z))
218218
}
219219

220-
# Kernel weight distribution
221-
#
222-
# `per_coalition_size = TRUE` is required, e.g., when one wants to sample random masks
223-
# according to the Kernel SHAP distribution: Pick a coalition size as per
224-
# these weights, then randomly place "on" positions. `FALSE` refer to weights
225-
# if all masks has been calculated and one wants to calculate their weights based
226-
# on the number of "on" positions.
227-
kernel_weights <- function(p, per_coalition_size, S = seq_len(p - 1L)) {
228-
const <- if (per_coalition_size) 1 else choose(p, S)
229-
probs <- (p - 1) / (const * S * (p - S)) # could drop the numerator
220+
# Kernel weight distribution. Gives the weight of each coalition vector of sum k
221+
kernel_weights <- function(p) {
222+
S <- seq_len(p - 1L)
223+
probs <- 1 / (choose(p, S) * S * (p - S))
224+
return(probs / sum(probs))
225+
}
226+
227+
# Kernel weights per coalition size. Sums the kernel_weights over the number of
228+
# coalitions with same sum.
229+
kernel_weights_per_coalition_size <- function(p, S = seq_len(p - 1L)) {
230+
probs <- 1 / (S * (p - S))
230231
return(probs / sum(probs))
231232
}
232233

@@ -236,7 +237,7 @@ prop_exact <- function(p, deg) {
236237
if (deg == 0) {
237238
return(0)
238239
}
239-
w <- kernel_weights(p, per_coalition_size = TRUE)
240+
w <- kernel_weights_per_coalition_size(p)
240241
w_total <- 2 * sum(w[seq_len(deg)]) - w[deg] * (p == 2 * deg)
241242
return(w_total)
242243
}

tests/testthat/test-kernelshap-utils.R

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
test_that("sum of kernel weights is 1", {
22
for (p in 2:10) {
3-
expect_equal(sum(kernel_weights(p, per_coalition_size = FALSE)), 1.0)
4-
expect_equal(sum(kernel_weights(p, per_coalition_size = TRUE)), 1.0)
3+
expect_equal(sum(kernel_weights(p)), 1.0)
4+
expect_equal(sum(kernel_weights_per_coalition_size(p)), 1.0)
55
}
66
})
77

88
test_that("Sum of kernel weights is 1, even for subset of domain", {
9-
expect_equal(sum(kernel_weights(10L, S = 2:5, per_coalition_size = FALSE)), 1.0)
10-
expect_equal(sum(kernel_weights(10L, S = 2:5, per_coalition_size = TRUE)), 1.0)
9+
expect_equal(sum(kernel_weights_per_coalition_size(10L, S = 2:5)), 1.0)
1110
})
1211

1312
p <- 10L

0 commit comments

Comments
 (0)