forked from igerber/diff-diff
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
2892 lines (2527 loc) · 106 KB
/
Copy pathutils.py
File metadata and controls
2892 lines (2527 loc) · 106 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Utility functions for difference-in-differences estimation.
"""
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional, Tuple
import numpy as np
import pandas as pd
from scipy import stats
# Import Rust backend if available (from _backend to avoid circular imports)
from diff_diff._backend import (
HAS_RUST_BACKEND,
_rust_project_simplex,
_rust_sdid_unit_weights,
_rust_compute_time_weights,
_rust_compute_noise_level,
_rust_sc_weight_fw,
_rust_sc_weight_fw_with_convergence,
_rust_sc_weight_fw_weighted,
_rust_sc_weight_fw_weighted_with_convergence,
)
from diff_diff.linalg import compute_robust_vcov as _compute_robust_vcov_linalg
from diff_diff.linalg import solve_ols as _solve_ols_linalg
# Numerical constants for optimization algorithms.
# NOTE(review): none of these three names is referenced in the portion of the
# module shown alongside them; their consumers are presumably the weight
# solvers defined further down the file -- confirm before changing a value.
_OPTIMIZATION_MAX_ITER = 1000  # Maximum iterations for weight optimization
_OPTIMIZATION_TOL = 1e-8  # Convergence tolerance for optimization
_NUMERICAL_EPS = 1e-10  # Small constant to prevent division by zero
# Memo table of two-sided critical values, keyed by ``(alpha, df)``; it spares
# repeated scipy quantile evaluations for the same inputs.
_critical_value_cache: Dict[Tuple[float, Optional[int]], float] = {}


def _get_critical_value(alpha: float, df: Optional[int] = None) -> float:
    """
    Return the two-sided critical value for ``alpha``, memoized per ``(alpha, df)``.

    Parameters
    ----------
    alpha : float
        Significance level; the quantile evaluated is ``1 - alpha / 2``.
    df : int, optional
        Degrees of freedom of the Student-t reference distribution. ``None``
        selects the standard normal distribution instead.

    Returns
    -------
    float
        The ``1 - alpha / 2`` quantile of the reference distribution.
    """
    cache_key = (alpha, df)
    try:
        return _critical_value_cache[cache_key]
    except KeyError:
        pass  # first request for this pair: compute it and remember it below

    upper_quantile = 1 - alpha / 2
    if df is None:
        critical = stats.norm.ppf(upper_quantile)
    else:
        critical = stats.t.ppf(upper_quantile, df)
    _critical_value_cache[cache_key] = float(critical)
    return _critical_value_cache[cache_key]
def validate_binary(arr: np.ndarray, name: str) -> None:
    """
    Check that every non-NaN entry of ``arr`` is either 0 or 1.

    NaN entries are ignored; an array with no non-NaN entries passes.

    Parameters
    ----------
    arr : np.ndarray
        Array to validate.
    name : str
        Variable name quoted in the error message.

    Raises
    ------
    ValueError
        If any non-NaN value other than 0 or 1 is present.
    """
    non_missing = arr[~np.isnan(arr)]
    unique_values = np.unique(non_missing)
    allowed_mask = np.isin(unique_values, [0, 1])
    if np.all(allowed_mask):
        return
    raise ValueError(f"{name} must be binary (0 or 1). Found values: {unique_values}")
def validate_covariate_names(
    covariates: Optional[List[str]],
    reserved_names: Iterable[str],
    *,
    estimator: str = "estimator",
) -> None:
    """
    Reject covariate names that clash with reserved structural term names or
    that are repeated within ``covariates``.

    Fitted coefficients are stored in a ``name -> value`` dict built by zipping
    a variable-name list (structural term names followed by the user covariate
    names, verbatim) with the coefficient vector. Because a dict keeps only the
    last value written for a key, a covariate named like a structural term
    (the intercept ``const``, the treatment/time indicators, the interaction
    term, period dummies, fixed-effect dummies, or an internal working column)
    would silently replace that term's coefficient. A name listed twice in
    ``covariates`` collapses to one entry in the same way.

    Matching is case-sensitive, exactly like column names and dict keys, so
    ``Const`` does not clash with ``const`` and is accepted.

    Parameters
    ----------
    covariates : list of str or None
        User-supplied covariate column names. ``None`` or empty is a no-op.
    reserved_names : iterable of str
        Reserved structural term names this estimator builds
        (estimator-specific).
    estimator : str
        Estimator name, used as the error-message prefix.

    Raises
    ------
    ValueError
        If a covariate name equals a reserved structural name, or if
        ``covariates`` contains a name more than once. Collisions are reported
        first.
    """
    if not covariates:
        return

    reserved = set(reserved_names)
    collisions = sorted(reserved.intersection(covariates))
    if collisions:
        raise ValueError(
            f"{estimator}: covariate name(s) {collisions} collide with reserved "
            "structural term name(s). These names are used internally for the "
            "intercept, the treatment/time indicators, the interaction term, "
            "period dummies, fixed-effect dummies, or internal working columns, "
            "and a colliding covariate would silently overwrite the structural "
            "coefficient. Rename the covariate column(s). Reserved names for "
            f"this fit: {sorted(reserved)}."
        )

    # Single pass: a name already met once is a repeat.
    first_seen: set = set()
    repeated: set = set()
    for cov_name in covariates:
        if cov_name in first_seen:
            repeated.add(cov_name)
        else:
            first_seen.add(cov_name)
    if repeated:
        raise ValueError(
            f"{estimator}: duplicate covariate name(s) {sorted(repeated)} "
            "in `covariates`. Each covariate maps to one coefficient; duplicates "
            "collapse to a single entry. Remove the duplicate(s)."
        )
def validate_design_term_names(
    var_names: Iterable[str],
    *,
    estimator: str = "estimator",
) -> None:
    """
    Fail when the assembled design term-name list holds a repeated name.

    This is the backstop for :func:`validate_covariate_names`. Even once the
    user covariates have been cleared, a fixed-effect dummy name
    (``{fe}_{value}``) can still equal a structural term -- most notably a
    ``MultiPeriodDiD`` ``period_{p}`` event-study key when a non-time fixed
    effect yields matching dummy names -- or another dummy. Zipping such a
    ``var_names`` list into the result's ``coefficients`` dict would silently
    drop a coefficient, since a dict keeps only the last value per key. The
    check runs on the FINAL name list (structural terms + covariates +
    fixed-effect dummies), so it catches data-dependent collisions that cannot
    be known up front.

    Parameters
    ----------
    var_names : iterable of str
        The fully assembled design-matrix column-name list. It is consumed in
        a single pass.
    estimator : str
        Estimator name, used as the error-message prefix.

    Raises
    ------
    ValueError
        If any name appears more than once.
    """
    first_seen: set = set()
    repeated: set = set()
    for term in var_names:
        if term in first_seen:
            repeated.add(term)
        else:
            first_seen.add(term)

    if not repeated:
        return
    raise ValueError(
        f"{estimator}: the fitted design has duplicate term name(s) "
        f"{sorted(repeated)} — a covariate or fixed-effect dummy name "
        "collides with a structural term (intercept, treatment/time "
        "indicators, the interaction, or period dummies) or with another "
        "column. This would silently overwrite a coefficient in the result. "
        "Rename the offending fixed-effect category or covariate column."
    )
def fe_dummy_names(col: pd.Series, prefix: str) -> List[str]:
    """
    List the fixed-effect dummy column names that
    ``pd.get_dummies(col, prefix=prefix, drop_first=True)`` would produce,
    WITHOUT materializing the dense ``(n x G)`` dummy matrix.

    The collision guard needs these names reserved even on the within-transform
    ``TwoWayFixedEffects`` path, whose scaling contract is precisely that
    high-cardinality fixed-effect dummies are never expanded. ``pd.get_dummies``
    takes its level order from ``pd.Categorical(col).categories`` -- sorted
    unique values for a plain column, the declared category order for a
    ``Categorical`` column -- and ``drop_first=True`` discards the first level.
    The same derivation is applied here (declared order is honoured for
    ``Categorical`` input) using only the ``G`` level labels.

    Parameters
    ----------
    col : pandas.Series
        The fixed-effect / unit / time column.
    prefix : str
        Dummy-name prefix (the project uses ``fe`` for ``fixed_effects`` and
        ``_fe_{unit}`` / ``_fe_{time}`` for TWFE unit/time dummies).

    Returns
    -------
    list of str
        ``f"{prefix}_{level}"`` for every level except the first; empty when
        there are fewer than two levels.
    """
    already_categorical = isinstance(col.dtype, pd.CategoricalDtype)
    level_index = col.cat.categories if already_categorical else pd.Categorical(col).categories
    kept_levels = list(level_index)[1:]  # mirror drop_first=True
    return [f"{prefix}_{level}" for level in kept_levels]
def warn_if_not_converged(
    converged: bool,
    method_name: str,
    max_iter: int,
    tol: Optional[float] = None,
    stacklevel: int = 3,
) -> None:
    """Emit a ``UserWarning`` when an iterative solver did not converge.

    Shared helper for iterative loops that would otherwise hand back their
    current iterate without signalling that ``max_iter`` was exhausted.

    Parameters
    ----------
    converged : bool
        Convergence flag reported by the solver; nothing happens when truthy.
    method_name : str
        Solver name placed at the start of the warning text.
    max_iter : int
        Iteration budget quoted in the warning text.
    tol : float, optional
        Tolerance appended to the warning text as ``(tol=...)`` when given.
    stacklevel : int, default=3
        Forwarded to ``warnings.warn`` to choose the frame the warning is
        attributed to.
    """
    if not converged:
        tolerance_note = "" if tol is None else f" (tol={tol})"
        message = (
            f"{method_name} did not converge in {max_iter} iterations{tolerance_note}. "
            "Results may be inaccurate."
        )
        warnings.warn(message, UserWarning, stacklevel=stacklevel)
def compute_robust_se(
    X: np.ndarray, residuals: np.ndarray, cluster_ids: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Compute a heteroskedasticity-robust (HC1) or cluster-robust covariance matrix.

    Thin backward-compatibility wrapper: the three arguments are forwarded
    unchanged, positionally, to ``diff_diff.linalg.compute_robust_vcov`` and
    its result is returned as is.

    NOTE(review): despite the ``_se`` suffix, the documented return value is
    the full variance-covariance matrix rather than a vector of standard
    errors; callers presumably take ``sqrt(diag(...))`` themselves. The HC1 /
    cluster-robust behaviour is implemented in ``diff_diff.linalg`` and is not
    visible here -- confirm there.

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k).
    residuals : np.ndarray
        Residuals from regression of shape (n,).
    cluster_ids : np.ndarray, optional
        Cluster identifiers for the cluster-robust variant.

    Returns
    -------
    np.ndarray
        Variance-covariance matrix of shape (k, k).
    """
    return _compute_robust_vcov_linalg(X, residuals, cluster_ids)
def compute_confidence_interval(
    estimate: float, se: float, alpha: float = 0.05, df: Optional[int] = None
) -> Tuple[float, float]:
    """
    Build the symmetric confidence interval ``estimate +/- critical * se``.

    Parameters
    ----------
    estimate : float
        Point estimate (centre of the interval).
    se : float
        Standard error of the estimate.
    alpha : float
        Significance level (default 0.05 for a 95% CI).
    df : int, optional
        Degrees of freedom for a Student-t critical value. If None, the
        normal critical value is used.

    Returns
    -------
    tuple
        (lower_bound, upper_bound) of the confidence interval.
    """
    margin = _get_critical_value(alpha, df) * se
    return (estimate - margin, estimate + margin)
def compute_p_value(t_stat: float, df: Optional[int] = None, two_sided: bool = True) -> float:
    """
    Convert a t-statistic into a p-value.

    The upper-tail probability is evaluated at ``|t_stat|``; it is doubled for
    the two-sided case and returned unchanged otherwise.

    Parameters
    ----------
    t_stat : float
        T-statistic.
    df : int, optional
        Degrees of freedom of the Student-t reference distribution. If None,
        the standard normal distribution is used.
    two_sided : bool
        Whether to return the two-sided p-value (default True).

    Returns
    -------
    float
        P-value.
    """
    magnitude = np.abs(t_stat)
    if df is None:
        tail_prob = stats.norm.sf(magnitude)
    else:
        tail_prob = stats.t.sf(magnitude, df)
    return float(tail_prob * 2) if two_sided else float(tail_prob)
def safe_inference(effect, se, alpha=0.05, df=None):
    """Return ``(t_stat, p_value, conf_int)`` with NaN-safe gating.

    Whenever the standard error is non-finite, zero, or negative -- or ``df``
    is given and not positive -- EVERY inference field is NaN, so no
    misleading statistic is reported next to an unusable one.

    Scalar inputs only (not numpy arrays).

    Parameters
    ----------
    effect : float
        Point estimate (treatment effect or coefficient).
    se : float
        Standard error of the estimate.
    alpha : float, optional
        Significance level for the confidence interval (default 0.05).
    df : int, optional
        Degrees of freedom. If None, the normal distribution is used.

    Returns
    -------
    tuple
        ``(t_stat, p_value, (ci_lower, ci_upper))``; all NaN in the gated
        cases described above.
    """
    se_unusable = not np.isfinite(se) or se <= 0
    if se_unusable:
        return np.nan, np.nan, (np.nan, np.nan)

    # Undefined degrees of freedom (e.g., rank-deficient replicate design)
    df_undefined = df is not None and df <= 0
    if df_undefined:
        return np.nan, np.nan, (np.nan, np.nan)

    t_stat = effect / se
    return (
        t_stat,
        compute_p_value(t_stat, df=df),
        compute_confidence_interval(effect, se, alpha, df=df),
    )
def safe_inference_batch(effects, ses, alpha=0.05, df=None):
    """Vectorized inference over arrays of effects and standard errors.

    Entries whose standard error is non-finite, zero, or negative stay NaN in
    all four outputs; when ``df`` is given and not positive, every entry is NaN.

    Parameters
    ----------
    effects : np.ndarray
        Array of point estimates.
    ses : np.ndarray
        Array of standard errors.
    alpha : float, optional
        Significance level (default 0.05).
    df : int, optional
        Degrees of freedom. If None, the normal distribution is used.

    Returns
    -------
    t_stats : np.ndarray
    p_values : np.ndarray
        Two-sided p-values.
    ci_lowers : np.ndarray
    ci_uppers : np.ndarray
    """
    effects = np.asarray(effects, dtype=float)
    ses = np.asarray(ses, dtype=float)

    count = len(effects)
    # Four independent NaN-filled outputs; only usable entries get overwritten.
    t_stats, p_values, ci_lowers, ci_uppers = (np.full(count, np.nan) for _ in range(4))
    outputs = (t_stats, p_values, ci_lowers, ci_uppers)

    # Undefined df (e.g., rank-deficient replicate design) → all NaN
    if df is not None and df <= 0:
        return outputs

    usable = np.isfinite(ses) & (ses > 0)
    if not usable.any():
        return outputs

    effects_ok = effects[usable]
    ses_ok = ses[usable]
    t_ok = effects_ok / ses_ok
    t_stats[usable] = t_ok

    if df is None:
        upper_tail = stats.norm.sf(np.abs(t_ok))
    else:
        upper_tail = stats.t.sf(np.abs(t_ok), df)
    p_values[usable] = 2.0 * upper_tail

    margin = _get_critical_value(alpha, df) * ses_ok
    ci_lowers[usable] = effects_ok - margin
    ci_uppers[usable] = effects_ok + margin
    return outputs
# =============================================================================
# Wild Cluster Bootstrap
# =============================================================================
@dataclass
class WildBootstrapResults:
    """
    Results from wild cluster bootstrap inference.

    Attributes
    ----------
    se : float
        Analytical cluster-robust (CR1) standard error of the coefficient. The
        wild bootstrap studentizes the test with this SE; it is not a rescaled
        bootstrap dispersion.
    p_value : float
        Wild cluster bootstrap p-value (two-tailed or equal-tailed).
    t_stat_original : float
        Studentized statistic of the original estimate, ``(coef - null) / se``.
    ci_lower : float
        Lower bound of the confidence interval (by test inversion).
    ci_upper : float
        Upper bound of the confidence interval (by test inversion).
    n_clusters : int
        Number of clusters in the data.
    n_bootstrap : int
        Number of bootstrap replications.
    weight_type : str
        Type of bootstrap weights used ("rademacher", "webb", or "mammen").
    alpha : float
        Significance level used for confidence interval.
    p_val_type : str
        Test shape used ("two-tailed" or "equal-tailed").
    bootstrap_distribution : np.ndarray, optional
        Bootstrap distribution of the studentized statistic ``t*`` evaluated at
        the null (if requested). Excluded from ``repr``.

    References
    ----------
    Cameron, A. C., Gelbach, J. B., & Miller, D. L. (2008).
    Bootstrap-Based Improvements for Inference with Clustered Errors.
    The Review of Economics and Statistics, 90(3), 414-427.
    """

    se: float
    p_value: float
    t_stat_original: float
    ci_lower: float
    ci_upper: float
    n_clusters: int
    n_bootstrap: int
    weight_type: str
    alpha: float = 0.05
    p_val_type: str = "two-tailed"
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)

    def summary(self) -> str:
        """Generate formatted summary of bootstrap results.

        Returns
        -------
        str
            Newline-joined report: a title, a rule, then one line per reported
            quantity.
        """
        # Format the confidence level with ``:g`` instead of ``int(...)``.
        # ``int`` truncates, so alpha=0.025 was labelled "97%" and float error
        # turned alpha=0.8 into "19%"; ``:g`` yields "97.5" and "20" while the
        # usual levels still print as "95", "90", "99".
        confidence_pct = f"{(1 - self.alpha) * 100:g}"
        lines = [
            "Wild Cluster Bootstrap Results",
            "=" * 40,
            f"Cluster-robust SE: {self.se:.6f}",
            f"Bootstrap p-value: {self.p_value:.4f}",
            f"Studentized t-stat: {self.t_stat_original:.4f}",
            f"CI ({confidence_pct}%): [{self.ci_lower:.6f}, {self.ci_upper:.6f}]",
            f"Number of clusters: {self.n_clusters}",
            f"Bootstrap reps: {self.n_bootstrap}",
            f"Weight type: {self.weight_type}",
            f"Test type: {self.p_val_type}",
        ]
        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print formatted summary to stdout."""
        print(self.summary())
def _generate_rademacher_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
    """
    Draw one Rademacher weight per cluster: -1 or +1, each with probability 0.5.

    Parameters
    ----------
    n_clusters : int
        Number of weights to draw.
    rng : np.random.Generator
        Random number generator supplying the draws.

    Returns
    -------
    np.ndarray
        Array of ``n_clusters`` Rademacher weights.
    """
    sign_support = [-1.0, 1.0]
    draws = rng.choice(sign_support, size=n_clusters)
    return np.asarray(draws)
def _generate_webb_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
    """
    Draw one weight per cluster from Webb's 6-point distribution.

    Support: {-sqrt(3/2), -sqrt(2/2), -sqrt(1/2), sqrt(1/2), sqrt(2/2), sqrt(3/2)},
    each point with probability 1/6, so that E[w]=0 and Var(w)=1.0.

    This distribution is recommended for very few clusters (G < 10) as it
    provides better finite-sample properties than Rademacher weights.

    Parameters
    ----------
    n_clusters : int
        Number of weights to draw.
    rng : np.random.Generator
        Random number generator supplying the draws.

    Returns
    -------
    np.ndarray
        Array of ``n_clusters`` Webb weights.

    References
    ----------
    Webb, M. D. (2014). Reworking wild bootstrap based inference for
    clustered errors. Queen's Economics Department Working Paper No. 1315.

    Note: Uses equal probabilities (1/6 each) matching R's `did` package,
    which gives unit variance for consistency with other weight distributions.
    """
    # Build the symmetric support from its three magnitudes: negatives in
    # descending magnitude, then positives in ascending magnitude.
    magnitudes = np.sqrt(np.array([3.0, 2.0, 1.0]) / 2)
    support = np.concatenate((-magnitudes, magnitudes[::-1]))
    # No ``p`` argument: the six points are equally likely, giving Var(w) = 1.0
    return np.asarray(rng.choice(support, size=n_clusters))
def _generate_mammen_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
    """
    Draw one weight per cluster from Mammen's two-point distribution.

    Support: {-(sqrt(5)-1)/2, (sqrt(5)+1)/2}
    with probabilities {(sqrt(5)+1)/(2*sqrt(5)), (sqrt(5)-1)/(2*sqrt(5))}.

    This distribution satisfies E[v]=0, E[v^2]=1, E[v^3]=1, which provides
    asymptotic refinement for skewed error distributions.

    Parameters
    ----------
    n_clusters : int
        Number of weights to draw.
    rng : np.random.Generator
        Random number generator supplying the draws.

    Returns
    -------
    np.ndarray
        Array of ``n_clusters`` Mammen weights.

    References
    ----------
    Mammen, E. (1993). Bootstrap and Wild Bootstrap for High Dimensional
    Linear Models. The Annals of Statistics, 21(1), 255-285.
    """
    root5 = np.sqrt(5)
    # Support points from Mammen (1993)
    low_point = -(root5 - 1) / 2  # approximately -0.618
    high_point = (root5 + 1) / 2  # approximately 1.618 (golden ratio)
    # Probability mass on the negative point
    prob_low = (root5 + 1) / (2 * root5)  # approximately 0.724
    support = [low_point, high_point]
    masses = [prob_low, 1 - prob_low]
    return np.asarray(rng.choice(support, size=n_clusters, p=masses))
def _wild_weight_matrix(
    n_clusters: int,
    n_bootstrap: int,
    weight_type: str,
    rng: np.random.Generator,
) -> np.ndarray:
    """Build the ``(B, n_clusters)`` matrix of cluster-level bootstrap weights.

    Rademacher weights with few clusters: once ``n_bootstrap`` reaches the
    number of possible draws -- ``2**n_clusters <= n_bootstrap`` -- and
    ``n_clusters <= 20`` (a guard against pathological memory use), all
    ``2**n_clusters`` sign-vectors are enumerated deterministically and ``rng``
    is not consulted. This matches the full-enumeration switch of
    ``fwildclusterboot::boottest`` (verified: for ``G=10`` boottest samples at
    ``B=1023`` and enumerates at ``B=1024``); the returned matrix then has
    ``2**n_clusters`` rows, which is the ``n_bootstrap`` that gets reported.
    (Only ``2**(n_clusters-1)`` of those draws have distinct ``|t*|`` -- each
    draw and its all-signs-flipped mirror share ``|t*|`` -- but the full set is
    materialized.) Enumeration removes RNG dependence in the few-cluster regime
    where the wild bootstrap matters most.

    In every other case ``n_bootstrap`` weight vectors are sampled, one row per
    generator call. Webb/Mammen always sample: the sign-flip enumeration
    symmetry is Rademacher-specific (Mammen is asymmetric, Webb is a 6-point
    law).

    Parameters
    ----------
    n_clusters : int
        Number of clusters (columns of the result).
    n_bootstrap : int
        Requested number of bootstrap replications.
    weight_type : str
        ``"rademacher"``, ``"webb"`` or ``"mammen"``; any other value raises
        ``KeyError`` on the sampling path.
    rng : np.random.Generator
        Random number generator used on the sampling path.

    Returns
    -------
    np.ndarray
        Float matrix with one weight vector per row.
    """
    enumerate_all = (
        weight_type == "rademacher" and n_clusters <= 20 and 2**n_clusters <= n_bootstrap
    )
    if enumerate_all:
        # Row i encodes the binary digits of i: bit j set -> +1 for cluster j.
        draw_ids = np.arange(2**n_clusters)[:, None]
        sign_bits = (draw_ids >> np.arange(n_clusters)) & 1
        return np.where(sign_bits == 1, 1.0, -1.0)

    samplers = {
        "rademacher": _generate_rademacher_weights,
        "webb": _generate_webb_weights,
        "mammen": _generate_mammen_weights,
    }
    draw_row = samplers[weight_type]

    # One generator call per replication, in row order, so the sequence of RNG
    # draws for a given seed is fixed.
    matrix = np.empty((n_bootstrap, n_clusters))
    for row in range(n_bootstrap):
        matrix[row, :] = draw_row(n_clusters, rng)
    return matrix
def wild_bootstrap_se(
X: np.ndarray,
y: np.ndarray,
residuals: np.ndarray,
cluster_ids: np.ndarray,
coefficient_index: int,
n_bootstrap: int = 999,
weight_type: str = "rademacher",
null_hypothesis: float = 0.0,
alpha: float = 0.05,
seed: Optional[int] = None,
return_distribution: bool = False,
p_val_type: str = "two-tailed",
) -> WildBootstrapResults:
"""
Compute wild cluster bootstrap standard errors and p-values.
Implements the Wild Cluster Restricted (WCR) bootstrap of Cameron, Gelbach,
and Miller (2008), matching the defaults of R's ``fwildclusterboot::boottest``
(Roodman, MacKinnon, Nielsen & Webb 2019): the null ``H0: coefficient =
null_hypothesis`` is genuinely imposed by re-estimating the model with the
coefficient's column dropped, the bootstrap DGP resamples the *restricted*
residuals, and the confidence interval is obtained by **inverting the
bootstrap test** (the set of null values not rejected at level ``alpha``) so
that the p-value and CI are mutually consistent (``0 in CI`` iff
``p >= alpha``). For Rademacher weights with few clusters all
``2**n_clusters`` sign-vectors are enumerated (deterministic) when
``2**n_clusters <= n_bootstrap`` (the ``boottest`` full-enumeration trigger —
it switches to enumeration once ``n_bootstrap`` reaches the number of
possible draws) and ``n_clusters <= 20`` (a memory guard); the reported
``n_bootstrap`` is then ``2**n_clusters``. Otherwise signs are sampled.
The reported ``se`` is the analytical cluster-robust (CR1) standard error of
the original estimate — the studentized bootstrap drives the p-value and CI,
not a re-scaled bootstrap dispersion.
Parameters
----------
X : np.ndarray
Design matrix of shape (n, k).
y : np.ndarray
Outcome vector of shape (n,).
residuals : np.ndarray
Retained for backward compatibility and IGNORED by the WCR
implementation, which recomputes the original fit and the restricted
(null-imposed) residualization internally from ``X`` and ``y``.
cluster_ids : np.ndarray
Cluster identifiers of shape (n,).
coefficient_index : int
Index of the coefficient for which to compute bootstrap inference.
For DiD, this is typically 3 (the treatment*post interaction term).
n_bootstrap : int, default=999
Number of bootstrap replications. Odd numbers are recommended for
exact p-value computation.
weight_type : str, default="rademacher"
Type of bootstrap weights:
- "rademacher": +1 or -1 with equal probability (standard choice)
- "webb": 6-point distribution (recommended for <10 clusters)
- "mammen": Two-point distribution with skewness correction
null_hypothesis : float, default=0.0
Value of the null hypothesis for p-value computation.
alpha : float, default=0.05
Significance level for confidence interval.
seed : int, optional
Random seed for reproducibility. If None (default), results
will vary between runs.
return_distribution : bool, default=False
If True, include the bootstrap distribution of the studentized statistic
``t*`` (evaluated at the null) in the results.
p_val_type : str, default="two-tailed"
Shape of the test (mirrors ``boottest``'s ``p_val_type``):
- "two-tailed": test on ``|t*|``; two-tailed CI by inversion (the
interval need not be symmetric about the estimate).
- "equal-tailed": each tail tested at ``alpha/2``; equal-tailed CI.
Returns
-------
WildBootstrapResults
Dataclass containing bootstrap SE, p-value, confidence interval,
and other inference results.
Raises
------
ValueError
If weight_type is not recognized or if there are fewer than 2 clusters.
Warns
-----
UserWarning
If the number of clusters is less than 5, as bootstrap inference
may be unreliable.
Examples
--------
>>> from diff_diff.utils import wild_bootstrap_se
>>> results = wild_bootstrap_se(
... X, y, residuals, cluster_ids,
... coefficient_index=3, # ATT coefficient
... n_bootstrap=999,
... weight_type="rademacher",
... seed=42
... )
>>> print(f"Bootstrap SE: {results.se:.4f}")
>>> print(f"Bootstrap p-value: {results.p_value:.4f}")
References
----------
Cameron, A. C., Gelbach, J. B., & Miller, D. L. (2008).
Bootstrap-Based Improvements for Inference with Clustered Errors.
The Review of Economics and Statistics, 90(3), 414-427.
MacKinnon, J. G., & Webb, M. D. (2018). The wild bootstrap for
few (treated) clusters. The Econometrics Journal, 21(2), 114-135.
"""
# Validate inputs
valid_weight_types = ["rademacher", "webb", "mammen"]
if weight_type not in valid_weight_types:
raise ValueError(f"weight_type must be one of {valid_weight_types}, got '{weight_type}'")
valid_p_val_types = ["two-tailed", "equal-tailed"]
if p_val_type not in valid_p_val_types:
raise ValueError(f"p_val_type must be one of {valid_p_val_types}, got '{p_val_type}'")
unique_clusters = np.unique(cluster_ids)
n_clusters = len(unique_clusters)
if n_clusters < 2:
raise ValueError(f"Wild cluster bootstrap requires at least 2 clusters, got {n_clusters}")
if n_clusters < 5:
warnings.warn(
f"Only {n_clusters} clusters detected. Wild cluster bootstrap inference may be "
"unreliable with fewer than 5 clusters. With Rademacher weights all "
f"{2 ** n_clusters} sign-vectors are enumerated exactly when "
f"n_bootstrap >= 2**n_clusters = {2 ** n_clusters}; Webb weights "
"(weight_type='webb') improve finite-sample behaviour but are sampled, not "
"enumerated.",
UserWarning,
)
rng = np.random.default_rng(seed)
n = X.shape[0]
def _degenerate() -> WildBootstrapResults:
# All-or-nothing NaN contract (feedback_bootstrap_nan_on_invalid_contract):
# when the original fit or the bootstrap is degenerate, NaN the entire
# (se, t_stat, p_value, ci) inference family together rather than mixing
# analytical and bootstrap quantities on the same coefficient.
return WildBootstrapResults(
se=float("nan"),
p_value=float("nan"),
t_stat_original=float("nan"),
ci_lower=float("nan"),
ci_upper=float("nan"),
n_clusters=n_clusters,
n_bootstrap=n_bootstrap,
weight_type=weight_type,
alpha=alpha,
p_val_type=p_val_type,
bootstrap_distribution=None,
)
# Step 1: original fit. Establishes the analytical cluster-robust (CR1) SE
# that studentizes the test, and the set of identified (kept) columns so the
# bootstrap stays rank-robust (e.g. an always-treated unit dummy collinear
# with treated*post on the full-dummy TWFE path: solve_ols drops the nuisance
# column and reports it as NaN, while the identified ATT is retained).
# First fit WITHOUT the cluster-robust vcov: this identifies the kept
# (full-rank) columns and lets us reject a saturated design *before*
# requesting the cluster sandwich. The shared CR1 small-sample adjustment
# (n_eff-1)/(n_eff-k) divides by zero on a saturated design (n == rank), so
# routing the degenerate case here keeps the all-or-nothing NaN contract.
beta_hat, _, _ = _solve_ols_linalg(X, y, return_vcov=False)
original_coef = float(beta_hat[coefficient_index])
if not np.isfinite(original_coef):
return _degenerate()
kept = np.isfinite(beta_hat)
if not bool(kept[coefficient_index]):
return _degenerate()
X_eff = X[:, kept]
j_eff = int(np.sum(kept[:coefficient_index])) # position of the coef among kept columns
k_eff = X_eff.shape[1]
if n <= k_eff: # no residual degrees of freedom -> CR1 undefined
return _degenerate()
# Now the cluster-robust (CR1) vcov is well-defined; it studentizes the test.
_, _, vcov_original = _solve_ols_linalg(X, y, cluster_ids=cluster_ids, return_vcov=True)
if vcov_original is None:
return _degenerate()
se_a = float(np.sqrt(vcov_original[coefficient_index, coefficient_index]))
if not np.isfinite(se_a) or se_a <= 0:
return _degenerate()
# Projections on the (full-rank) effective design.
XtX_inv = np.linalg.inv(X_eff.T @ X_eff)
a_vec = X_eff @ XtX_inv[:, j_eff] # influence of each obs on beta_j: beta*_j = a_vec . y*
proj = XtX_inv @ X_eff.T # (k_eff, n) OLS projection onto coefficients
# Restricted residualization imposing H0: regress y and x_j on X_eff \ {col j}.
# The restricted residuals u(r) = M_{-j} y - r * M_{-j} x_j are linear in the
# candidate null r, so the whole test can be re-evaluated at any r cheaply.
xj = X_eff[:, j_eff]
X_reduced = np.delete(X_eff, j_eff, axis=1)
if X_reduced.shape[1] == 0:
# Single-regressor design: the reduced model has no regressors, so the
# restricted fit is identically 0 and the residuals are the variables
# themselves (solve_ols cannot fit a zero-column design).
fit_y_red = np.zeros(n)
fit_xj_red = np.zeros(n)
else:
_, _, fit_y_red, _ = _solve_ols_linalg(X_reduced, y, return_vcov=False, return_fitted=True)
_, _, fit_xj_red, _ = _solve_ols_linalg(
X_reduced, xj, return_vcov=False, return_fitted=True
)
m_y = y - fit_y_red
m_xj = xj - fit_xj_red
# CR1 small-sample correction. NOTE: this constant cancels in |t*| vs |t0|
# (it scales se* and se_a identically), so it affects only the reported SE,
# not the p-value or CI. Kept for fidelity with the analytical CR1 SE.
corr = (n_clusters / (n_clusters - 1)) * ((n - 1) / (n - k_eff))
# Cluster membership: indicator matrix C (G, n) for fast per-cluster score sums.
cluster_pos = {c: i for i, c in enumerate(unique_clusters)}
cl_idx = np.array([cluster_pos[c] for c in cluster_ids])
cluster_indicator = np.zeros((n_clusters, n))
cluster_indicator[cl_idx, np.arange(n)] = 1.0
# Fixed bootstrap weights, held constant across the whole test inversion so
# that p(r) is a stable (monotone, step) function amenable to root-finding.
weights = _wild_weight_matrix(n_clusters, n_bootstrap, weight_type, rng)
n_boot_eff = int(weights.shape[0])
weights_obs = weights[:, cl_idx] # (B, n)
def _t_star(r: float) -> np.ndarray:
    """Return the studentized bootstrap statistics t*(r) under H0: beta_j = r.

    Uses arrays from the enclosing scope (``m_y``, ``m_xj``, ``y``,
    ``weights_obs``, ``a_vec``, ``proj``, ``X_eff``, ``cluster_indicator``,
    ``corr``). Shape notes below follow the comments where those arrays are
    built.

    Parameters
    ----------
    r : float
        Candidate null value imposed on the coefficient of interest.

    Returns
    -------
    np.ndarray
        One statistic per bootstrap draw; entries whose bootstrap standard
        error is not strictly positive are set to NaN.
    """
    # Restricted residuals at the candidate null: linear in r.
    restricted_resid = m_y - r * m_xj  # (n,)
    # Bootstrap outcomes: restricted fit plus reweighted restricted residuals,
    # i.e. (y - u_r) + u_r * w, written as y - u_r * (1 - w).
    one_minus_w = 1.0 - weights_obs
    boot_y = y[None, :] - restricted_resid[None, :] * one_minus_w  # (B, n)
    boot_y_t = boot_y.T  # (n, B)
    # Bootstrap estimate of the coefficient of interest for every draw.
    boot_beta_j = boot_y @ a_vec  # (B,)
    # Full-model refit of every draw, then its residuals.
    boot_coefs = proj @ boot_y_t  # (k_eff, B)
    boot_resid = boot_y_t - X_eff @ boot_coefs  # (n, B)
    # Per-cluster sums of the influence-weighted residuals.
    weighted_resid = a_vec[:, None] * boot_resid
    cluster_scores = cluster_indicator @ weighted_resid  # (G, B)
    sum_sq = np.sum(cluster_scores**2, axis=0)
    boot_se = np.sqrt(corr * sum_sq)  # (B,)
    # Division by a zero/NaN SE is expected for degenerate draws; silence the
    # warnings and blank those entries out explicitly afterwards.
    with np.errstate(divide="ignore", invalid="ignore"):
        stats = (boot_beta_j - r) / boot_se
    degenerate = ~(boot_se > 0)
    stats[degenerate] = np.nan
    return stats
t_star = _t_star(null_hypothesis)
finite = np.isfinite(t_star)
n_valid = int(finite.sum())
if n_valid < 2:
return _degenerate()
t_star_valid = t_star[finite]
t0 = (original_coef - null_hypothesis) / se_a
# Strict-inequality tail counts, matching fwildclusterboot/boottest: a
# bootstrap statistic is counted only if it *exceeds* the observed one. In
# the fully-enumerated few-cluster case the all-(+1) / all-(-1) sign-vectors
# reproduce t* = +/- t0 exactly (the observed draw and its mirror); strict
# ">" excludes those boundary ties, as boottest does. The small relative
# guard (~1e-9) makes the exclusion robust to floating-point noise from the
# fast-form path so a true tie never sneaks in as a strict exceedance.
def _frac_gt(vals: np.ndarray, thresh: float) -> float:
return float(np.mean(vals > thresh + 1e-9 * max(1.0, abs(thresh))))
def _frac_lt(vals: np.ndarray, thresh: float) -> float:
return float(np.mean(vals < thresh - 1e-9 * max(1.0, abs(thresh))))
# p-value at the test null (two-tailed on |t*|, or equal-tailed).
if p_val_type == "two-tailed":
raw_p = _frac_gt(np.abs(t_star_valid), abs(t0))
else:
p_low = _frac_lt(t_star_valid, t0)
p_up = _frac_gt(t_star_valid, t0)
raw_p = 2.0 * min(p_low, p_up)
# Floor the reported p-value to avoid an exact zero (a documented departure
# from boottest, which can report p == 0) — but NEVER let the floor reach
# the significance level. With very few valid draws 1/(n_valid+1) can exceed
# alpha, and flooring there would flip a bootstrap-significant result (0
# outside the inverted CI) to "non-significant", re-creating the very
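# p-vs-CI contradiction this estimator fixes. When the floor would cross
# alpha we report the raw p-value instead. The floor can only bind when
# raw_p == 0 (any nonzero tail fraction is at least 1/n_valid > floor), so
# skipping it changes the reported value only for a raw p of 0 < alpha, and
# the significance verdict always agrees with the inverted CI.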
floor = 1.0 / (n_valid + 1)
p_value = max(raw_p, floor) if floor < alpha else raw_p
p_value = float(min(1.0, p_value))
# ---- Confidence interval by test inversion ------------------------------
# The CI is the set of nulls r not rejected at level alpha. The relevant
# rejection frequency is monotonically decreasing as r moves away from the
# point estimate, so each endpoint is found by outward bracketing + plain
# bisection — robust to the step-function nature of a finite bootstrap
# (unlike brentq, which assumes a continuous sign change).
def _reject_two_tailed(r: float) -> float:
    """Return the two-tailed bootstrap tail frequency at candidate null ``r``.

    This is the fraction of finite bootstrap statistics whose absolute value
    strictly exceeds the absolute observed statistic
    ``(original_coef - r) / se_a``. Returns 0.0 when fewer than two bootstrap
    statistics are finite.
    """
    draws = _t_star(r)
    usable = draws[np.isfinite(draws)]
    if usable.size >= 2:
        observed = (original_coef - r) / se_a
        return _frac_gt(np.abs(usable), abs(observed))
    # Too few usable draws to form a frequency.
    return 0.0
def _tail_freq(r: float, upper: bool) -> float:
    """Return a one-sided bootstrap tail frequency at candidate null ``r``.

    Parameters
    ----------
    r : float
        Candidate null value for the coefficient of interest.
    upper : bool
        If true, count finite bootstrap statistics strictly above the observed
        statistic ``(original_coef - r) / se_a``; otherwise count those
        strictly below it.

    Returns
    -------
    float
        The requested tail fraction, or 0.0 when fewer than two bootstrap
        statistics are finite.
    """
    draws = _t_star(r)
    usable = draws[np.isfinite(draws)]
    if usable.size < 2:
        return 0.0
    observed = (original_coef - r) / se_a
    if upper:
        return _frac_gt(usable, observed)
    return _frac_lt(usable, observed)
def _bisect(f: Any, level: float, direction: int) -> float:
# f(center) >= level; search outward (direction -1 lower, +1 upper) for
# the crossing f(r) = level (f decreasing in |r - center|), then bisect.
center = original_coef
scale = se_a if se_a > 0 else 1.0
step = scale
hi = center + direction * step
bracketed = False
for _ in range(64):
if f(hi) < level:
bracketed = True
break
step *= 2.0
hi = center + direction * step
if not bracketed:
# The test never rejects arbitrarily far out: the inverted CI is
# genuinely unbounded on this side. Represent it with a signed
# infinity (NOT NaN) so the (se, t, p, CI) inference family stays
# internally consistent — 0 still lies inside an unbounded interval
# exactly when the test fails to reject it, preserving
# 0 ∈ CI ⟺ p ≥ alpha.
return float(direction) * np.inf
lo = center # f(lo) >= level, f(hi) < level
for _ in range(100):
mid = 0.5 * (lo + hi)
if f(mid) >= level:
lo = mid
else:
hi = mid
if abs(hi - lo) <= 1e-10 * max(1.0, abs(center)):
break
return 0.5 * (lo + hi)
if p_val_type == "two-tailed":
ci_lower = _bisect(_reject_two_tailed, alpha, -1)
ci_upper = _bisect(_reject_two_tailed, alpha, +1)
else:
# equal-tailed: lower endpoint where the upper-tail frequency hits
# alpha/2; upper endpoint where the lower-tail frequency hits alpha/2.
ci_lower = _bisect(lambda r: _tail_freq(r, True), alpha / 2.0, -1)
ci_upper = _bisect(lambda r: _tail_freq(r, False), alpha / 2.0, +1)
return WildBootstrapResults(
se=se_a,
p_value=p_value,
t_stat_original=t0,
ci_lower=float(ci_lower),
ci_upper=float(ci_upper),
n_clusters=n_clusters,
n_bootstrap=n_boot_eff,
weight_type=weight_type,
alpha=alpha,
p_val_type=p_val_type,
bootstrap_distribution=t_star_valid if return_distribution else None,
)