Skip to content

Commit d879c88

Browse files
igerberclaude
andcommitted
Address local AI review P2/P3 findings
- P2 (wooldridge): Extract shared `_warn_and_fill_nan_cohort(df, cohort, stacklevel)` helper used by both `_filter_sample` and `fit()`. Removes the copy-paste warning block that was flagged as a future drift risk. - P2 (tests): Add `test_inf_first_treat_warning_counts_rows_not_units` on a 4-unit x 3-period panel. 2 units carry inf across all 3 periods (6 inf rows, 2 inf units) — the warning must report 6, not 2, because `.replace(inf, 0)` is row-level. - P3 (utils wording): The `_compute_outcome_changes` excess-drop warning said "gaps or NaN outcomes" but the code actually counts all NaN first-differences. Rephrased to "additional NaN first-differences (e.g. NaN outcomes or unit-period gaps upstream)" so the message doesn't over-claim what the helper can detect. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a5f38e1 commit d879c88

4 files changed

Lines changed: 56 additions & 30 deletions

File tree

diff_diff/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -936,8 +936,9 @@ def _compute_outcome_changes(
936936
warnings.warn(
937937
f"check_parallel_trends dropped {n_dropped} row(s) with NaN "
938938
f"first-differences; {n_units_observed} are the expected "
939-
f"first-period-per-unit drops, and {n_unexpected_drops} came "
940-
f"from gaps or NaN outcomes. Parallel-trend statistics are "
939+
f"first-period-per-unit drops, and {n_unexpected_drops} are "
940+
f"additional NaN first-differences (e.g. NaN outcomes or "
941+
f"unit-period gaps upstream). Parallel-trend statistics are "
941942
f"computed on the remaining rows.",
942943
UserWarning,
943944
stacklevel=3,

diff_diff/wooldridge.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,26 @@ def _resolve_survey_for_wooldridge(survey_design, sample, cluster_ids, cluster_n
113113
return resolved, survey_weights, survey_weight_type, survey_metadata, df_inf
114114

115115

116+
def _warn_and_fill_nan_cohort(df: pd.DataFrame, cohort: str, stacklevel: int) -> pd.DataFrame:
117+
"""Fill NaN cohort with 0 (never-treated) and warn with the row count.
118+
119+
Used by both `_filter_sample` (pre-fit) and `WooldridgeDiD.fit()` so the
120+
silent recategorization is surfaced on whichever entry path the caller
121+
hits first. See REGISTRY.md §WooldridgeDiD (axis-E silent coercion).
122+
"""
123+
n_nan_cohort = int(df[cohort].isna().sum())
124+
if n_nan_cohort > 0:
125+
warnings.warn(
126+
f"{n_nan_cohort} row(s) have NaN cohort values; filling with 0 "
127+
f"and treating the corresponding units as never-treated. Pass "
128+
f"an explicit never-treated marker (0) if this is not intended.",
129+
UserWarning,
130+
stacklevel=stacklevel,
131+
)
132+
df[cohort] = df[cohort].fillna(0)
133+
return df
134+
135+
116136
def _filter_sample(
117137
data: pd.DataFrame,
118138
unit: str,
@@ -129,20 +149,7 @@ def _filter_sample(
129149
(see _build_interaction_matrix).
130150
"""
131151
df = data.copy()
132-
# Normalise never-treated: fill NaN cohort with 0. Report the row count so
133-
# callers can see how many rows were recategorized as never-treated — a
134-
# silent recategorization here would quietly move units between the
135-
# treated and control sides of the estimator (axis-E silent coercion).
136-
n_nan_cohort = int(df[cohort].isna().sum())
137-
if n_nan_cohort > 0:
138-
warnings.warn(
139-
f"{n_nan_cohort} row(s) have NaN cohort values; filling with 0 "
140-
f"and treating the corresponding units as never-treated. Pass "
141-
f"an explicit never-treated marker (0) if this is not intended.",
142-
UserWarning,
143-
stacklevel=3,
144-
)
145-
df[cohort] = df[cohort].fillna(0)
152+
df = _warn_and_fill_nan_cohort(df, cohort, stacklevel=3)
146153

147154
treated_mask = df[cohort] > 0
148155

@@ -409,19 +416,7 @@ def fit(
409416
``NotImplementedError``.
410417
"""
411418
df = data.copy()
412-
# See `_filter_sample` for the analogous warning; fit() does its own
413-
# fillna earlier in the pipeline so we warn here too to cover the
414-
# direct-fit path.
415-
n_nan_cohort = int(df[cohort].isna().sum())
416-
if n_nan_cohort > 0:
417-
warnings.warn(
418-
f"{n_nan_cohort} row(s) have NaN cohort values; filling with 0 "
419-
f"and treating the corresponding units as never-treated. Pass "
420-
f"an explicit never-treated marker (0) if this is not intended.",
421-
UserWarning,
422-
stacklevel=2,
423-
)
424-
df[cohort] = df[cohort].fillna(0)
419+
df = _warn_and_fill_nan_cohort(df, cohort, stacklevel=2)
425420

426421
# 0a. Validate cohort is time-invariant within unit
427422
cohort_per_unit = df.groupby(unit)[cohort].nunique()

tests/test_continuous_did.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,36 @@ def test_no_inf_first_treat_no_warning(self):
675675
inf_warnings = [x for x in w if "inf in 'first_treat'" in str(x.message)]
676676
assert inf_warnings == []
677677

678+
def test_inf_first_treat_warning_counts_rows_not_units(self):
679+
"""The warning counts affected rows (not units). On a panel with
680+
multiple periods per unit, each inf row must count separately so the
681+
message surface matches the per-row semantics of `.replace(inf, 0)`."""
682+
# Build a 4-unit, 3-period panel (12 rows). 2 units have inf across
683+
# all 3 periods → 6 inf rows, 2 units, so row-count != unit-count.
684+
rows = []
685+
for unit in range(4):
686+
ft = np.inf if unit < 2 else 2.0
687+
dose = 0.0 if unit < 2 else 1.0
688+
for t in range(1, 4):
689+
rows.append({
690+
"unit": unit, "period": t, "outcome": float(unit + t),
691+
"first_treat": ft, "dose": dose,
692+
})
693+
data = pd.DataFrame(rows)
694+
est = ContinuousDiD()
695+
696+
with pytest.warns(
697+
UserWarning,
698+
match=r"6 row\(s\) have inf in 'first_treat'",
699+
):
700+
try:
701+
est.fit(data, "outcome", "unit", "period", "first_treat", "dose")
702+
except Exception:
703+
# Downstream validation may reject this minimal panel (too few
704+
# treated for OLS). We only care that the inf-row warning fires
705+
# with the correct row count.
706+
pass
707+
678708
def test_custom_dvals(self):
679709
data = generate_continuous_did_data(n_units=100, n_periods=3, seed=42)
680710
custom_grid = np.array([1.0, 2.0, 3.0])

tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,7 @@ def test_warns_on_nan_outcomes_with_excess_drop_count(self):
871871

872872
with pytest.warns(
873873
UserWarning,
874-
match=r"check_parallel_trends dropped \d+ row\(s\).*first-period-per-unit",
874+
match=r"check_parallel_trends dropped \d+ row\(s\).*additional NaN first-differences",
875875
):
876876
_compute_outcome_changes(
877877
df, outcome="outcome", time="period",

0 commit comments

Comments
 (0)