-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathtest_did_had_parity.py
More file actions
502 lines (439 loc) · 20.5 KB
/
Copy pathtest_did_had_parity.py
File metadata and controls
502 lines (439 loc) · 20.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
"""Cross-language end-to-end parity tests against R `DIDHAD::did_had`.
HAD Phase 4 (PR #389): verifies that Python
``HeterogeneousAdoptionDiD.fit()`` (overall, event-study, placebo,
yatchew, trends_lin) and ``yatchew_hr_test()`` match R
``DIDHAD::did_had()`` v2.0.0 bit-exactly on shared synthetic input.
Tolerances (per Phase 4 plan):
- Point / SE / CI bounds: ``atol=1e-8`` (full pipeline through nprobust
numerical paths).
- Yatchew T-stat (closed-form ratio): ``atol=1e-10``.
- Per-horizon arrays: shape exact, values at the appropriate per-field
tolerance.
Point-estimate convention deviation (documented in REGISTRY.md):
R `DIDHAD::did_had` reports the CONVENTIONAL local-linear estimate in
its ``Estimate`` column (= ``(mean(ΔY) - tau.us) / mean(D)``) and
constructs the BIAS-CORRECTED CI separately. Python `HAD.fit` reports
the BIAS-CORRECTED estimate directly in ``att`` (= ``(mean(ΔY) - tau.bc)
/ mean(D)``). Both are valid Calonico-Cattaneo-Farrell (2018)
constructions; the CCF paper shows the BC CI has correct coverage even
when applied to the conventional point estimate.
For parity testing, ``Python att`` matches ``R (ci_lo + ci_hi) / 2``
(R's CI midpoint, which is the bias-corrected location). ``Python se``,
``conf_int_low``, ``conf_int_high`` match R's ``se``, ``ci_lo``,
``ci_hi`` directly.
Fixture generated by ``benchmarks/R/generate_did_had_golden.R``.
Guard per ``feedback_golden_file_pytest_skip``: CI isolated-install jobs
copy ``tests/`` only, not ``benchmarks/data/``, so a missing fixture
downgrades to pytest.skip rather than fail.
R-side row index → our event-time convention mapping:
R row "Effect_i" (ID = +i) → our event time e = i - 1
R row "Placebo_i" (ID = -i) →
without trends_lin: our event time e = -(i + 1)
with trends_lin: our event time e = -(i + 2)
The convention shift on the placebo side mirrors R's anchor swap
(F-1 anchor without trends_lin; F-2 anchor with trends_lin); on our
end the F-1 anchor is preserved and the math collapses to a uniform
``adjusted_dy[e] = dy_dict[e] - (e+1) × slope`` adjustment with the
``e=-2`` placebo dropped (R's "consumed" placebo).
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
import pandas as pd
import pytest
from diff_diff import HeterogeneousAdoptionDiD
from diff_diff.had_pretests import yatchew_hr_test
# Location of the golden JSON fixture: two directory levels above this test
# file, under benchmarks/data/. `_load_fixture` skips when it is missing.
FIXTURE_PATH = Path(__file__).parent.parent / "benchmarks" / "data" / "did_had_golden.json"
# Absolute tolerances for the parity assertions in this module (every
# comparison passes rtol=0, so these are the only slack allowed).
# Point estimate, standard error and CI bounds share the looser 1e-8 bound.
POINT_ATOL = 1e-8
SE_ATOL = 1e-8
CI_ATOL = 1e-8
# The Yatchew T-stat is compared at the tighter 1e-10 bound.
YATCHEW_ATOL = 1e-10
def _load_fixture():
    """Load and parse the golden JSON fixture at ``FIXTURE_PATH``.

    Returns the decoded JSON document. When the fixture file does not
    exist the calling test is skipped via ``pytest.skip`` instead of
    failing, with a message naming the R script that regenerates it.
    """
    if not FIXTURE_PATH.exists():
        pytest.skip(
            f"Golden fixture {FIXTURE_PATH} missing — regenerate via "
            f"`Rscript benchmarks/R/generate_did_had_golden.R`."
        )
    # JSON is UTF-8 by specification; pin the encoding so parsing does not
    # depend on the platform's locale default (e.g. cp1252 on Windows).
    with open(FIXTURE_PATH, encoding="utf-8") as f:
        return json.load(f)
@pytest.fixture(scope="module")
def fixture():
    """Module-scoped pytest fixture exposing the parsed golden JSON document."""
    golden = _load_fixture()
    return golden
def _panel_from_fixture(dgp_entry: Dict[str, Any]) -> pd.DataFrame:
"""Reconstruct the long-format panel DataFrame from the JSON entry.
R serializes columns as parallel arrays; rebuild a DataFrame with
int unit/time columns to match HAD.fit's contract."""
panel = dgp_entry["panel"]
return pd.DataFrame(
{
"g": np.asarray(panel["g"], dtype=np.int64),
"t": np.asarray(panel["t"], dtype=np.int64),
"y": np.asarray(panel["y"], dtype=np.float64),
"d": np.asarray(panel["d"], dtype=np.float64),
}
)
def _r_id_to_event_time(r_id: int, trends_lin: bool) -> int:
"""Map R's ID column to our event_time convention.
R reports Effect_i with ID = +i (Effect_1 is at our e=0); Placebo_i
with ID = -i (under trends_lin=False, Placebo_1 is at our e=-2;
under trends_lin=True, Placebo_1 is at our e=-3). The placebo shift
reflects R's F-2 anchor swap under trends_lin."""
if r_id > 0:
return r_id - 1
placebo_lag = -r_id # 1, 2, ...
return -(placebo_lag + (2 if trends_lin else 1))
def _python_fit(
    panel: pd.DataFrame,
    effects: int,
    placebo: int,
    trends_lin: bool,
) -> Any:
    """Fit ``HeterogeneousAdoptionDiD`` on ``panel`` with options mirroring R.

    The estimator is always built with ``design="continuous_at_zero"``
    (Design 1', d_lower=0) to match R's ``did_had``, which evaluates the
    local linear at ``d=0`` regardless of the dose distribution; letting
    the design be auto-detected could select ``continuous_near_d_lower``
    and yield different point estimates.

    Two routes:

    - ``effects == 1 and placebo == 0 and not trends_lin`` (the
      ``overall_e1`` combo): the panel is cut down to the two periods
      F-1 and F, where F is the earliest period with a positive dose,
      and fitted with ``aggregate="overall"`` so the genuine two-period
      code path is exercised (PR #392 R1 P2 fix).
    - anything else: the full panel is fitted with
      ``aggregate="event_study"`` and the given ``trends_lin``.

    Returns whatever ``HeterogeneousAdoptionDiD.fit`` returns.
    """
    estimator = HeterogeneousAdoptionDiD(design="continuous_at_zero")
    column_kwargs = {
        "outcome_col": "y",
        "dose_col": "d",
        "time_col": "t",
        "unit_col": "g",
    }

    wants_event_study = trends_lin or effects != 1 or placebo != 0
    if wants_event_study:
        return estimator.fit(
            panel,
            aggregate="event_study",
            trends_lin=trends_lin,
            **column_kwargs,
        )

    # Keep only the periods R's effects=1 case consumes for
    # Effect_1 = Y[F] - Y[F-1].
    first_treated = int(panel.loc[panel["d"] > 0, "t"].min())
    keep_periods = [first_treated - 1, first_treated]
    two_period = panel[panel["t"].isin(keep_periods)].copy()
    return estimator.fit(two_period, aggregate="overall", **column_kwargs)
def _as_list(x: Any) -> list:
"""jsonlite::write_json with auto_unbox=TRUE collapses single-element
vectors to scalars; rewrap them as 1-element lists so iteration is
uniform across single-horizon and multi-horizon results."""
if isinstance(x, (list, tuple)):
return list(x)
return [x]
def _zip_r_python(
    r_result: Dict[str, Any], py_result: Any, trends_lin: bool
) -> List[Tuple[int, int, str]]:
    """Pair each R result row with the matching Python event-time position.

    Args:
        r_result: R result block from the fixture; its ``event_id`` and
            ``rownames`` entries are read (scalars are wrapped to lists).
        py_result: Python result object exposing ``event_times`` with a
            ``tolist()`` method.
        trends_lin: forwarded to ``_r_id_to_event_time`` to pick the
            placebo anchor convention.

    Returns:
        ``(r_row_idx, py_event_idx, r_rowname)`` tuples in R row order.

    Raises:
        AssertionError: if ``event_id`` and ``rownames`` differ in length,
            if an R row maps to an event time Python did not emit, or if
            the set of R-mapped event times is not exactly equal to the
            set of Python ``event_times`` (PR #392 R6 P3: set equality,
            so a horizon that is missing from or extra in Python both
            fail).
    """
    py_event_times = py_result.event_times.tolist()
    py_idx_by_event_time = {int(e): i for i, e in enumerate(py_event_times)}
    r_event_ids = _as_list(r_result["event_id"])
    r_rownames = _as_list(r_result["rownames"])
    # zip() stops at the shorter input, so mismatched fixture arrays would
    # silently drop rows from the comparison; reject them up front.
    assert len(r_event_ids) == len(r_rownames), (
        f"R fixture event_id has {len(r_event_ids)} entries but rownames "
        f"has {len(r_rownames)}; malformed fixture?"
    )
    pairs = []
    expected_event_times = []
    for i, (r_id, rowname) in enumerate(zip(r_event_ids, r_rownames)):
        e = _r_id_to_event_time(int(r_id), trends_lin)
        expected_event_times.append(e)
        if e not in py_idx_by_event_time:
            raise AssertionError(
                f"R row {rowname!r} (ID={r_id}) maps to our e={e}, but "
                f"Python event_times = {py_event_times}. Mapping bug?"
            )
        pairs.append((i, py_idx_by_event_time[e], rowname))
    # PR #392 R6 P3: exact set equality between R-mapped horizons and
    # Python's event_times (was: subset inclusion). Catches
    # horizon-selection regressions in BOTH directions:
    #   - missing_in_python: Python silently dropped a horizon R requested
    #   - extra_in_python: Python emitted an extra horizon R did not
    #     request (e.g. effects/placebo cap drift in our event_study path)
    expected_set = set(expected_event_times)
    py_set = set(py_event_times)
    missing_in_python = expected_set - py_set
    extra_in_python = py_set - expected_set
    assert not missing_in_python and not extra_in_python, (
        f"event_times set-equality mismatch: "
        f"R-mapped {sorted(expected_set)}; Python emitted "
        f"{sorted(py_event_times)}; missing in Python: "
        f"{sorted(missing_in_python)}; extra in Python: "
        f"{sorted(extra_in_python)}."
    )
    return pairs
# ---------------------------------------------------------------------------
# DGPs and combos enumerated from the fixture metadata.
# ---------------------------------------------------------------------------
# Keys looked up under fixture["fixtures"], one per DGP.
DGP_NAMES = ["uniform_G200_F4_T5", "beta22_G200_F4_T5", "boundary_G200_F4_T5"]
# Combos that produce a HAD point/SE/CI block (excludes pure-yatchew rows
# which are validated separately).
# Each tuple is (combo_name, effects, placebo, trends_lin), matching the
# parametrize signature of the tests; combo_name is the key under
# fixture["fixtures"][dgp]["combos"].
POINT_COMBOS = [
    ("overall_e1", 1, 0, False),
    ("event_e2_p2", 2, 2, False),
    ("event_e2_p2_yatchew", 2, 2, False),
    ("event_e2_p2_trendslin", 2, 2, True),
    ("event_e2_p2_yatchew_trendslin", 2, 2, True),
]
# Combos whose R result is expected to carry Yatchew statistics; same
# (combo_name, effects, placebo, trends_lin) tuple layout as above.
YATCHEW_COMBOS = [
    ("event_e2_p2_yatchew", 2, 2, False),
    ("event_e2_p2_yatchew_trendslin", 2, 2, True),
]
@pytest.fixture(scope="module")
def panels(fixture):
    """Map each name in ``DGP_NAMES`` to its panel DataFrame from the JSON fixture."""
    per_dgp = fixture["fixtures"]
    frames = {}
    for dgp_name in DGP_NAMES:
        frames[dgp_name] = _panel_from_fixture(per_dgp[dgp_name])
    return frames
# ---------------------------------------------------------------------------
# Test classes — separated by mode for clean per-field shape handling.
# ---------------------------------------------------------------------------
class TestPointSEParity:
    """Point estimate / SE / CI parity vs R DIDHAD across all method combos."""

    @pytest.mark.parametrize(
        "dgp_name,combo_name,effects,placebo,trends_lin",
        [(dgp, name, eff, pla, tl) for dgp in DGP_NAMES for name, eff, pla, tl in POINT_COMBOS],
    )
    def test_point_se_ci_parity(
        self, fixture, panels, dgp_name, combo_name, effects, placebo, trends_lin
    ):
        """Compare Python att / se / CI bounds with the R fixture, row by row.

        R's ``Estimate`` column is the conventional (non-bias-corrected)
        estimate, so Python's bias-corrected ``att`` is compared against
        the midpoint of R's CI instead; ``se`` and both CI bounds are
        compared directly. All comparisons use ``rtol=0``.
        """
        r_result = fixture["fixtures"][dgp_name]["combos"][combo_name]["result"]
        py = _python_fit(panels[dgp_name], effects, placebo, trends_lin)

        r_se_list = _as_list(r_result["se"])
        r_ci_lo_list = _as_list(r_result["ci_lo"])
        r_ci_hi_list = _as_list(r_result["ci_hi"])

        # Normalise both result shapes into rows of
        # (r_idx, rowname, att, se, ci_lo, ci_hi).
        if combo_name == "overall_e1":
            # Scalar results from the aggregate="overall" path; R's
            # overall_e1 has exactly one row (Effect_1 → ID=1 → e=0).
            rows = [
                (
                    0,
                    "Effect_1",
                    float(py.att),
                    float(py.se),
                    float(py.conf_int[0]),
                    float(py.conf_int[1]),
                )
            ]
        else:
            # Array-valued event-study results, indexed by event-time position.
            rows = [
                (
                    r_idx,
                    rowname,
                    float(py.att[py_idx]),
                    float(py.se[py_idx]),
                    float(py.conf_int_low[py_idx]),
                    float(py.conf_int_high[py_idx]),
                )
                for r_idx, py_idx, rowname in _zip_r_python(r_result, py, trends_lin)
            ]

        for r_idx, rowname, py_att, py_se, py_ci_lo, py_ci_hi in rows:
            r_se = r_se_list[r_idx]
            r_ci_lo = r_ci_lo_list[r_idx]
            r_ci_hi = r_ci_hi_list[r_idx]
            # Bias-corrected location on the R side = CI midpoint.
            r_att_bc = 0.5 * (r_ci_lo + r_ci_hi)
            tag = f"{dgp_name}/{combo_name}/{rowname}"
            checks = (
                (
                    py_att,
                    r_att_bc,
                    POINT_ATOL,
                    f"{tag}: bias-corrected estimate mismatch (Python att vs R CI midpoint)",
                ),
                (py_se, r_se, SE_ATOL, f"{tag}: SE mismatch"),
                (py_ci_lo, r_ci_lo, CI_ATOL, f"{tag}: ci_lo mismatch"),
                (py_ci_hi, r_ci_hi, CI_ATOL, f"{tag}: ci_hi mismatch"),
            )
            for actual, desired, tolerance, message in checks:
                np.testing.assert_allclose(
                    actual,
                    desired,
                    atol=tolerance,
                    rtol=0,
                    err_msg=message,
                )
class TestYatchewParity:
    """Yatchew T-stat parity vs R `DIDHAD::did_had(yatchew=TRUE)`.

    R computes the trends-adjusted Effect_i / Placebo_i internally and
    runs Yatchew on each horizon. Python parity reproduces the same
    trends-adjusted dy values via the HAD pipeline and runs
    `yatchew_hr_test` on each horizon.

    Convention deviation (documented in REGISTRY.md):
    Our `yatchew_hr_test` follows paper Appendix E literally with
    the (1/G) population-variance convention:
        sigma2_lin = sum(eps^2) / G
        sigma2_diff = sum(diff_dy^2) / (2G)
    R's `YatchewTest::yatchew_test` uses base R's `var()` (1/(N-1)
    sample-variance) and `mean()` conventions, which scale the
    numerator by N/(N-1) relative to ours. Both converge to the
    same asymptotic distribution; for finite samples they differ
    by exactly N/(N-1).

    Parity is asserted at atol=1e-10 AFTER applying the convention
    shift: ``R_T = py_T_hr × G / (G-1)``. This is bit-exact parity
    against R's reported T-stat under the documented convention
    transformation.

    Placebo rows use a DIFFERENT null in R: per the DIDHAD README,
    the yatchew flag on placebo tests "the null being tested is that
    groups' F-1 to F-1-ℓ outcome evolution is mean independent of
    their F-1+ℓ treatment". R's `YatchewTest::yatchew_test(order=0)`
    fits Y ~ 1 (intercept only) instead of Y ~ D. Effect rows
    (post-treatment) use R's `order=1` (linearity null) which matches
    our default `yatchew_hr_test(null="linearity")`. Placebo rows are
    routed through `yatchew_hr_test(null="mean_independence")` (added
    post-PR #392), which mirrors R's `order=0`. Parity holds at the
    same `atol=1e-10` after the documented N/(N-1) convention shift
    on both modes."""

    @pytest.mark.parametrize(
        "dgp_name,combo_name,effects,placebo,trends_lin",
        [(dgp, name, eff, pla, tl) for dgp in DGP_NAMES for name, eff, pla, tl in YATCHEW_COMBOS],
    )
    def test_yatchew_t_stat_parity(
        self, fixture, panels, dgp_name, combo_name, effects, placebo, trends_lin
    ):
        """Check the per-horizon Yatchew T-stat against R's ``yatchew_t``."""
        r_combo = fixture["fixtures"][dgp_name]["combos"][combo_name]
        r_result = r_combo["result"]
        # This guard has to run before the first r_result["yatchew_t"]
        # access; otherwise a missing key raises KeyError and the
        # explanatory failure message is never shown.
        if "yatchew_t" not in r_result:
            pytest.fail(
                f"{combo_name} expected to have yatchew_t in fixture; "
                f"check generate_did_had_golden.R"
            )
        r_yatchew_t = _as_list(r_result["yatchew_t"])
        # PR #392 R5 P3: assert R's reported (effects + placebo) row
        # count matches the parametrize spec — catches future fixture
        # drift where R's effects/placebo args don't actually drive
        # the row count we expect.
        n_yatchew_rows = len(r_yatchew_t)
        # Under trends_lin, R drops one placebo (consumed). Otherwise
        # rows = effects + placebo (the auto-truncation cap from R is
        # capped at the panel's max via did_het_adoption_main).
        expected_rows = effects + placebo - (1 if trends_lin else 0)
        assert n_yatchew_rows == expected_rows, (
            f"R fixture row count for {combo_name} = {n_yatchew_rows}, "
            f"expected effects+placebo{'-1' if trends_lin else ''} = "
            f"{expected_rows}; fixture/combo spec drift?"
        )
        # Reproduce trends-adjusted per-event-time dy in Python by
        # extracting the dy_dict that HAD.fit consumes internally. We
        # rebuild via the same _aggregate_multi_period_first_differences
        # path + apply trends_lin detrending.
        from diff_diff.had import _aggregate_multi_period_first_differences

        panel = panels[dgp_name]
        # F = earliest period with a positive dose.
        F = int(panel.loc[panel["d"] > 0, "t"].min())
        all_periods = sorted(panel["t"].unique().tolist())
        t_pre_list = [t for t in all_periods if t < F]
        t_post_list = [t for t in all_periods if t >= F]
        d_arr, dy_dict, _, _, _ = _aggregate_multi_period_first_differences(
            panel,
            "y",
            "d",
            "t",
            "g",
            F,
            t_pre_list,
            t_post_list,
            None,
        )
        if trends_lin:
            # Uniform detrending adjusted_dy[e] = dy[e] - (e+1) × slope,
            # with the e=-2 placebo then dropped (R's "consumed" placebo).
            slope = -dy_dict[-2]
            dy_dict = {e: dy - (e + 1) * slope for e, dy in dy_dict.items()}
            del dy_dict[-2]
        # Build a (r_row_idx → our event_time) map and run Yatchew per
        # horizon. R reports yatchew_test rows in the SAME ORDER as
        # resmat (Effect_1, Effect_2, ..., Placebo_1, ...).
        r_yatchew_n = _as_list(r_result["yatchew_n"])
        r_event_ids = _as_list(r_result["event_id"])
        n_compared = 0
        for r_idx, r_id in enumerate(r_event_ids):
            e = _r_id_to_event_time(int(r_id), trends_lin)
            if e not in dy_dict:
                continue
            # Effect rows (R ID > 0): linearity null (Y ~ 1 + D), R's
            # YatchewTest::yatchew_test(order=1).
            # Placebo rows (R ID < 0): mean-independence null (Y ~ 1),
            # R's YatchewTest::yatchew_test(order=0). Both modes share
            # the same N/(N-1) convention shift downstream.
            null_mode = "mean_independence" if int(r_id) < 0 else "linearity"
            dy_e = dy_dict[e]
            yatchew_res = yatchew_hr_test(d_arr, dy_e, null=null_mode)
            # Apply documented convention shift: R's T = our T × G/(G-1)
            # (sample-variance vs population-variance denominators).
            G_horizon = int(r_yatchew_n[r_idx])
            py_t_in_r_convention = float(yatchew_res.t_stat_hr) * G_horizon / (G_horizon - 1)
            np.testing.assert_allclose(
                py_t_in_r_convention,
                float(r_yatchew_t[r_idx]),
                atol=YATCHEW_ATOL,
                rtol=0,
                err_msg=(
                    f"{dgp_name}/{combo_name}/Yatchew row {r_idx} "
                    f"(R ID={r_id}, our e={e}, G={G_horizon}, "
                    f"null={null_mode!r}): T_hr mismatch "
                    f"after N/(N-1) convention shift"
                ),
            )
            n_compared += 1
        # Rows whose mapped event time is absent from dy_dict are skipped
        # above; make sure the test cannot pass without comparing anything.
        assert n_compared > 0, (
            f"{dgp_name}/{combo_name}: no R Yatchew row mapped to a Python "
            f"horizon (R IDs={r_event_ids}, Python event times="
            f"{sorted(dy_dict)}); nothing was compared."
        )
class TestFixtureMetadata:
    """Sanity checks on the fixture itself."""

    def test_metadata_versions_match(self, fixture):
        """Require the fixture metadata to list the exact pinned R package versions.

        PR #392 R4 P3: the pins are exact (not ``>=``) so regenerating
        the goldens cannot silently re-anchor them to a newer CRAN
        release while changelog / registry still cite the old version.
        When re-anchoring on purpose, bump the pins both here and in
        ``benchmarks/R/generate_did_had_golden.R``.
        """
        meta = fixture["metadata"]

        didhad_matches = meta["didhad_version"] == "2.0.0"
        assert didhad_matches, (
            f"Fixture was generated against DIDHAD={meta['didhad_version']!r}; "
            f"the parity test pins exactly 2.0.0. Regenerate after bumping "
            f"the pin in both the generator and this test."
        )

        yatchewtest_matches = meta["yatchewtest_version"] == "1.1.1"
        assert yatchewtest_matches, (
            f"Fixture was generated against YatchewTest="
            f"{meta['yatchewtest_version']!r}; the parity test pins exactly "
            f"1.1.1. Regenerate after bumping the pin."
        )

        # PR #392 R5 P3: nprobust is on the parity contract path
        # (DIDHAD's local-linear bandwidth + bias-correction calls go
        # through it), so pin it exactly too. Bump in lockstep with
        # the generator's stopifnot guards.
        nprobust_matches = meta["nprobust_version"] == "0.5.0"
        assert nprobust_matches, (
            f"Fixture was generated against nprobust="
            f"{meta['nprobust_version']!r}; the parity test pins exactly "
            f"0.5.0. Regenerate after bumping the pin."
        )

    def test_metadata_n_dgps(self, fixture):
        """The metadata DGP count must equal ``len(DGP_NAMES)``, which must be 3."""
        declared = fixture["metadata"]["n_dgps"]
        listed = len(DGP_NAMES)
        assert declared == listed
        assert listed == 3

    def test_all_dgps_present(self, fixture):
        """Every name in ``DGP_NAMES`` must be a key of ``fixture["fixtures"]``."""
        available = fixture["fixtures"]
        for name in DGP_NAMES:
            assert name in available, f"missing DGP {name!r}"