@@ -73,12 +73,12 @@ def score(
7373 return (w * outcome ).mean ()
7474
7575 def weights (
76- self ,
77- df : pd .DataFrame ,
78- policy : Union [Callable , np .ndarray , pd .Series ]
76+ self , df : pd .DataFrame , policy : Union [Callable , np .ndarray , pd .Series ]
7977 ) -> pd .Series :
8078 W = df [self .treatment_name ].astype (int )
81- assert all ([x >= 0 for x in W .unique ()]), "Treatment values must be non-negative integers"
79+ assert all (
80+ [x >= 0 for x in W .unique ()]
81+ ), "Treatment values must be non-negative integers"
8282
8383 # Handle policy input
8484 if callable (policy ):
@@ -87,7 +87,9 @@ def weights(
8787 policy = policy .values
8888 policy = np .array (policy )
8989 d = pd .Series (index = df .index , data = policy )
90- assert all ([x >= 0 for x in d .unique ()]), "Policy values must be non-negative integers"
90+ assert all (
91+ [x >= 0 for x in d .unique ()]
92+ ), "Policy values must be non-negative integers"
9193
9294 # Get propensity scores with better handling of edge cases
9395 if isinstance (self .propensity_model , DummyPropensity ):
@@ -98,25 +100,25 @@ def weights(
98100 except Exception :
99101 # Fallback to safe defaults if prediction fails
100102 p = np .full ((len (df ), 2 ), 0.5 )
101-
103+
102104 # Clip propensity scores to avoid division by zero or extreme weights
103105 min_clip = max (1e-6 , self .clip ) # Ensure minimum clip is not too small
104106 p = np .clip (p , min_clip , 1 - min_clip )
105107
106- # Initialize weights
108+ # Initialize weights
107109 weight = np .zeros (len (df ))
108-
110+
109111 try :
110112 # Calculate weights with safer operations
111113 for i in W .unique ():
112- mask = ( W == i )
114+ mask = W == i
113115 p_i = p [:, i ][mask ]
114116 # Add small constant to denominator to prevent division by zero
115117 weight [mask ] = 1 / (p_i + 1e-10 )
116118 except Exception :
117119 # If something goes wrong, return safe weights
118120 weight = np .ones (len (df ))
119-
121+
120122 # Zero out weights where policy disagrees with actual treatment
121123 weight [d != W ] = 0.0
122124
@@ -133,12 +135,12 @@ def weights(
133135 else :
134136 # If all weights are zero, use uniform weights
135137 weight = np .ones (len (df )) / len (df )
136-
138+
137139 # Final check for NaNs
138140 if np .any (np .isnan (weight )):
139141 # Replace any remaining NaNs with uniform weights
140142 weight = np .ones (len (df )) / len (df )
141-
143+
142144 return pd .Series (index = df .index , data = weight )
143145
144146 def probabilistic_erupt_score (
0 commit comments