Skip to content

Commit e35047e

Browse files
committed
linter
Signed-off-by: 1andrin <[email protected]>
1 parent 1541247 commit e35047e

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

causaltune/erupt.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,12 @@ def score(
7373
return (w * outcome).mean()
7474

7575
def weights(
76-
self,
77-
df: pd.DataFrame,
78-
policy: Union[Callable, np.ndarray, pd.Series]
76+
self, df: pd.DataFrame, policy: Union[Callable, np.ndarray, pd.Series]
7977
) -> pd.Series:
8078
W = df[self.treatment_name].astype(int)
81-
assert all([x >= 0 for x in W.unique()]), "Treatment values must be non-negative integers"
79+
assert all(
80+
[x >= 0 for x in W.unique()]
81+
), "Treatment values must be non-negative integers"
8282

8383
# Handle policy input
8484
if callable(policy):
@@ -87,7 +87,9 @@ def weights(
8787
policy = policy.values
8888
policy = np.array(policy)
8989
d = pd.Series(index=df.index, data=policy)
90-
assert all([x >= 0 for x in d.unique()]), "Policy values must be non-negative integers"
90+
assert all(
91+
[x >= 0 for x in d.unique()]
92+
), "Policy values must be non-negative integers"
9193

9294
# Get propensity scores with better handling of edge cases
9395
if isinstance(self.propensity_model, DummyPropensity):
@@ -98,25 +100,25 @@ def weights(
98100
except Exception:
99101
# Fallback to safe defaults if prediction fails
100102
p = np.full((len(df), 2), 0.5)
101-
103+
102104
# Clip propensity scores to avoid division by zero or extreme weights
103105
min_clip = max(1e-6, self.clip) # Ensure minimum clip is not too small
104106
p = np.clip(p, min_clip, 1 - min_clip)
105107

106-
# Initialize weights
108+
# Initialize weights
107109
weight = np.zeros(len(df))
108-
110+
109111
try:
110112
# Calculate weights with safer operations
111113
for i in W.unique():
112-
mask = (W == i)
114+
mask = W == i
113115
p_i = p[:, i][mask]
114116
# Add small constant to denominator to prevent division by zero
115117
weight[mask] = 1 / (p_i + 1e-10)
116118
except Exception:
117119
# If something goes wrong, return safe weights
118120
weight = np.ones(len(df))
119-
121+
120122
# Zero out weights where policy disagrees with actual treatment
121123
weight[d != W] = 0.0
122124

@@ -133,12 +135,12 @@ def weights(
133135
else:
134136
# If all weights are zero, use uniform weights
135137
weight = np.ones(len(df)) / len(df)
136-
138+
137139
# Final check for NaNs
138140
if np.any(np.isnan(weight)):
139141
# Replace any remaining NaNs with uniform weights
140142
weight = np.ones(len(df)) / len(df)
141-
143+
142144
return pd.Series(index=df.index, data=weight)
143145

144146
def probabilistic_erupt_score(

0 commit comments

Comments
 (0)