@@ -108,7 +108,7 @@ solver <- function(A, b, constraint) {
108
108
# to Kernel SHAP weights -> (m x p) matrix.
109
109
# The argument S can be used to restrict the range of sum(z).
110
110
sample_Z <- function (p , m , feature_names , S = 1 : (p - 1L )) {
111
- probs <- kernel_weights( p , per_coalition_size = TRUE , S = S )
111
+ probs <- kernel_weights_per_coalition_size( p , S = S )
112
112
N <- S [sample.int(length(S ), m , replace = TRUE , prob = probs )]
113
113
114
114
# Then, conditional on that number, set random positions of z to 1
@@ -159,7 +159,7 @@ input_sampling <- function(p, m, deg, feature_names) {
159
159
input_exact <- function (p , feature_names ) {
160
160
Z <- exact_Z(p , feature_names = feature_names )
161
161
Z <- Z [2L : (nrow(Z ) - 1L ), , drop = FALSE ]
162
- kw <- kernel_weights(p , per_coalition_size = FALSE ) # Kernel weights for all subsets
162
+ kw <- kernel_weights(p ) # Kernel weights for all subsets
163
163
w <- kw [rowSums(Z )] # Corresponding weight for each row in Z
164
164
w <- w / sum(w )
165
165
list (Z = Z , w = w , A = crossprod(Z , w * Z ))
@@ -204,7 +204,7 @@ input_partly_exact <- function(p, deg, feature_names) {
204
204
stop(" p must be >=2*deg" )
205
205
}
206
206
207
- kw <- kernel_weights(p , per_coalition_size = FALSE )
207
+ kw <- kernel_weights(p )
208
208
209
209
Z <- vector(" list" , deg )
210
210
for (k in seq_len(deg )) {
@@ -217,16 +217,17 @@ input_partly_exact <- function(p, deg, feature_names) {
217
217
list (Z = Z , w = w , A = crossprod(Z , w * Z ))
218
218
}
219
219
220
- # Kernel weight distribution
221
- #
222
- # `per_coalition_size = TRUE` is required, e.g., when one wants to sample random masks
223
- # according to the Kernel SHAP distribution: Pick a coalition size as per
224
- # these weights, then randomly place "on" positions. `FALSE` refer to weights
225
- # if all masks has been calculated and one wants to calculate their weights based
226
- # on the number of "on" positions.
227
- kernel_weights <- function (p , per_coalition_size , S = seq_len(p - 1L )) {
228
- const <- if (per_coalition_size ) 1 else choose(p , S )
229
- probs <- (p - 1 ) / (const * S * (p - S )) # could drop the numerator
220
+ # Kernel weight distribution. Gives the weight of each coalition vector of sum k
221
+ kernel_weights <- function (p ) {
222
+ S <- seq_len(p - 1L )
223
+ probs <- 1 / (choose(p , S ) * S * (p - S ))
224
+ return (probs / sum(probs ))
225
+ }
226
+
227
+ # Kernel weights per coalition size. Sums the kernel_weights over the number of
228
+ # coalitions with same sum.
229
+ kernel_weights_per_coalition_size <- function (p , S = seq_len(p - 1L )) {
230
+ probs <- 1 / (S * (p - S ))
230
231
return (probs / sum(probs ))
231
232
}
232
233
@@ -236,7 +237,7 @@ prop_exact <- function(p, deg) {
236
237
if (deg == 0 ) {
237
238
return (0 )
238
239
}
239
- w <- kernel_weights( p , per_coalition_size = TRUE )
240
+ w <- kernel_weights_per_coalition_size( p )
240
241
w_total <- 2 * sum(w [seq_len(deg )]) - w [deg ] * (p == 2 * deg )
241
242
return (w_total )
242
243
}
0 commit comments