Skip to content

Commit

Permalink
linter
Browse files Browse the repository at this point in the history
Signed-off-by: 1andrin <[email protected]>
  • Loading branch information
1andrin committed Nov 15, 2024
1 parent 6a5a4f2 commit 1cb92b4
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions causaltune/erupt.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ def score(
return (w * outcome).mean()

def weights(
self,
df: pd.DataFrame,
policy: Union[Callable, np.ndarray, pd.Series]
self, df: pd.DataFrame, policy: Union[Callable, np.ndarray, pd.Series]
) -> pd.Series:
W = df[self.treatment_name].astype(int)
assert all([x >= 0 for x in W.unique()]), "Treatment values must be non-negative integers"
assert all(
[x >= 0 for x in W.unique()]
), "Treatment values must be non-negative integers"

# Handle policy input
if callable(policy):
Expand All @@ -87,7 +87,9 @@ def weights(
policy = policy.values
policy = np.array(policy)
d = pd.Series(index=df.index, data=policy)
assert all([x >= 0 for x in d.unique()]), "Policy values must be non-negative integers"
assert all(
[x >= 0 for x in d.unique()]
), "Policy values must be non-negative integers"

# Get propensity scores with better handling of edge cases
if isinstance(self.propensity_model, DummyPropensity):
Expand All @@ -98,25 +100,25 @@ def weights(
except Exception:
# Fallback to safe defaults if prediction fails
p = np.full((len(df), 2), 0.5)

# Clip propensity scores to avoid division by zero or extreme weights
min_clip = max(1e-6, self.clip) # Ensure minimum clip is not too small
p = np.clip(p, min_clip, 1 - min_clip)

# Initialize weights
# Initialize weights
weight = np.zeros(len(df))

try:
# Calculate weights with safer operations
for i in W.unique():
mask = (W == i)
mask = W == i
p_i = p[:, i][mask]
# Add small constant to denominator to prevent division by zero
weight[mask] = 1 / (p_i + 1e-10)
except Exception:
# If something goes wrong, return safe weights
weight = np.ones(len(df))

# Zero out weights where policy disagrees with actual treatment
weight[d != W] = 0.0

Expand All @@ -133,12 +135,12 @@ def weights(
else:
# If all weights are zero, use uniform weights
weight = np.ones(len(df)) / len(df)

# Final check for NaNs
if np.any(np.isnan(weight)):
# Replace any remaining NaNs with uniform weights
weight = np.ones(len(df)) / len(df)

return pd.Series(index=df.index, data=weight)

def probabilistic_erupt_score(
Expand Down

0 comments on commit 1cb92b4

Please sign in to comment.