Commit c3b9da0

Merge pull request igerber#561 from igerber/perf/cs-bootstrap-chunking
perf(callaway-santanna): chunk multiplier-bootstrap weight generation to bound peak memory
2 parents 22580c8 + 173289b commit c3b9da0

8 files changed

Lines changed: 640 additions & 72 deletions

CHANGELOG.md

Lines changed: 13 additions & 0 deletions
@@ -31,6 +31,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   Korn-Graubard (1990), and Solon-Haider-Wooldridge (2015) to `docs/references.rst`.

 ### Changed
+- **CallawaySantAnna multiplier bootstrap now tiles weight generation over draws, cutting
+  peak memory at large `n_units`.** The dense `(n_bootstrap × n_units)` multiplier-weight
+  matrix (the dominant allocation for the default unit-level bootstrap — `cluster=None`,
+  equivalently `cluster="unit"` — where each unit is its own
+  PSU) is generated and consumed one draw-block at a time via the new
+  `diff_diff/bootstrap_chunking.py` helper instead of being materialized in full. Measured peak
+  RSS at 999 bootstrap reps drops ~79% at 500k units (11.6 GB → 2.4 GB) and ~68% at 1M units
+  (10.8 GB → 3.4 GB); the previously out-of-reach millions-of-units × 999-rep regime now stays
+  near the fit's memory floor. The weight *stream* is bit-identical on both backends (Rust
+  absolute per-row seeding; NumPy in-order stream); end-to-end bootstrap SEs match to within
+  floating-point reassociation of the BLAS reductions (~1 ULP, far below bootstrap Monte-Carlo
+  error). Stratified survey designs (few PSUs) are unchanged (full generation + sliced blocks);
+  see TODO.md for the deferred per-stratum tiling.
 - **`run_placebo_test`'s `fake_group` path now filters ever-treated units by default.** The
   dispatcher threads its `treatment` column into `placebo_group_test`, so the fake-group
   placebo runs on never-treated units only (a more-correct placebo). Calling
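
The "generated and consumed one draw-block at a time" idea in the changelog entry above can be illustrated with a minimal NumPy-only sketch. It is not the library's code: `draw_weights` and `chunked_bootstrap_stats` are hypothetical names, and deriving Rademacher weights from uniform draws is an assumption made only so the example is self-contained.

```python
import numpy as np


def draw_weights(rows: int, n_units: int, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical stand-in generator: Rademacher (+1/-1) weights, row-major."""
    # One uniform draw per weight, filled in row-major order, so consecutive
    # calls on the same generator continue the same stream.
    return np.where(rng.random((rows, n_units)) < 0.5, 1.0, -1.0)


def chunked_bootstrap_stats(
    influence: np.ndarray, n_bootstrap: int, block_size: int, seed: int
) -> np.ndarray:
    """Compute weights @ influence without holding the full weight matrix."""
    rng = np.random.default_rng(seed)
    out = np.empty(n_bootstrap)
    for start in range(0, n_bootstrap, block_size):
        rows = min(block_size, n_bootstrap - start)
        block = draw_weights(rows, influence.shape[0], rng)  # (rows, n_units)
        out[start:start + rows] = block @ influence  # consume, then drop the block
    return out


n_units, n_bootstrap, seed = 10_000, 99, 42
influence = np.random.default_rng(0).standard_normal(n_units)

# Reference path: materialize the full (n_bootstrap, n_units) matrix in one call.
full = draw_weights(n_bootstrap, n_units, np.random.default_rng(seed)) @ influence
tiled = chunked_bootstrap_stats(influence, n_bootstrap, block_size=16, seed=seed)

# Same weights in both paths, so results agree up to floating-point summation order.
print(np.allclose(full, tiled, rtol=1e-12, atol=1e-9))
```

Only one `(block_size, n_units)` block is alive at a time in the tiled path, while the output stays a small `(n_bootstrap,)` vector. The comparison uses a tolerance rather than exact equality because the matmul reduction order may differ between the two paths.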

TODO.md

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ generic sparse-FE, QR+SVD rank-detection redundancy, `check_finite` bypass — m
 | `SpilloverDiD` sparse cKDTree path for the staggered nearest-treated-distance helper (mirrors the static helper's sparse branch). `_compute_nearest_treated_distance_staggered` always builds dense `(n_units, n_treated_by_onset)` matrices per cohort; add a sparse branch gated on `n > _CONLEY_SPARSE_N_THRESHOLD`. | `spillover.py` | Wave B | Mid | Low |
 | `HeterogeneousAdoptionDiD` Phase 3 Stute: Appendix-D vectorized form replaces the per-iteration OLS refit with a single precomputed `M = I - X(X'X)^{-1}X'` applied to `eps*eta` (~2× faster, functionally identical). Shipped the literal-refit form to match paper text. | `had_pretests.py::stute_test` | Phase 3 | Mid | Low |
 | Rust faer SVD ndarray-to-faer conversion overhead (minimal vs SVD cost). | `rust/src/linalg.rs:67` | #115 | Quick | Low |
+| CallawaySantAnna multiplier-bootstrap weight chunking covers the **unstratified** survey-PSU generation (the default unit-level bootstrap — `cluster=None`, equivalently `cluster="unit"` — the large-`n_units` OOM case). Two gaps remain: (1) EfficientDiD and HAD bootstraps still materialize the full `(n_bootstrap × n_units)` weight matrix — wire them through `diff_diff/bootstrap_chunking.py`; (2) the **stratified** survey-PSU generator (`generate_survey_multiplier_weights_batch`, per-stratum + lonely-PSU pooling + FPC) still materializes the full `(n_bootstrap × n_psu)` matrix (consumed via sliced blocks). Stratified designs have few PSUs so this rarely OOMs; tile per-stratum generation over draws (each stratum's draws are independent → contiguous draw-blocks reproduce the stream bit-identically) if a large-PSU stratified design hits memory. | `diff_diff/bootstrap_chunking.py::iter_survey_multiplier_weight_blocks`, `efficient_did_bootstrap.py`, `had.py` | follow-up | Mid | Low |

 ### Testing / docs

diff_diff/bootstrap_chunking.py

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
+"""Memory-bounded chunking for multiplier-bootstrap weight matrices.
+
+The multiplier bootstrap perturbs cached influence functions with a dense
+``(n_bootstrap, n_units)`` weight matrix. At large ``n_units`` that matrix
+dominates peak memory (e.g. ``999 x 5_000_000 x 8`` bytes is ~40 GB). Every
+consumer is a left-multiply ``weights @ influence_vector`` whose result is small
+(``(n_bootstrap,)`` or ``(n_bootstrap, n_gt)``), so the bootstrap can be tiled
+over the *draw* dimension: generate and consume the weights in row-blocks of
+``B``, capping the live intermediate at ``(B, n_units)``. FLOPs are identical to
+the un-chunked path -- only the draw axis is tiled. The generated weight stream
+is *bit-identical* to the un-chunked matrix (see below); the downstream
+``weights @ influence`` matmuls go through BLAS, whose reduction order depends on
+the operand row-count, so the resulting statistics match the un-chunked path to
+within floating-point reassociation (typically <~1 ULP), far below bootstrap
+Monte-Carlo error -- not bit-for-bit.
+
+Bit-identity of the weight *generation* is preserved on **both** backends:
+
+- **Rust** seeds each row absolutely as ``base_seed + row_index``
+  (``rust/src/bootstrap.rs``), so calling the generator per block with base seed
+  ``base_seed + chunk_start`` reproduces the exact un-chunked rows. Exactly one
+  ``rng.integers`` draw is consumed, matching the un-chunked wrapper.
+- The **NumPy** fallback draws the matrix row-major from the ``Generator``
+  stream, so consuming it in contiguous, in-order blocks from the same generator
+  reproduces the identical sequence.
+"""
+
+from __future__ import annotations
+
+from typing import Iterator, Optional, Tuple
+
+import numpy as np
+
+from diff_diff._backend import HAS_RUST_BACKEND, _rust_bootstrap_weights
+from diff_diff.bootstrap_utils import generate_bootstrap_weights_batch_numpy
+
+# Byte ceiling for a single ``(B, n_units)`` float64 weight block. 256 MB keeps
+# the live intermediate small at millions of units while staying large enough
+# that the per-block matmuls remain BLAS-efficient and chunk overhead (a handful
+# of extra Python iterations / FFI calls) is negligible.
+_TARGET_BLOCK_BYTES = 256 * 1024 * 1024
+
+
+def compute_block_size(
+    n_units: int, n_bootstrap: int, target_bytes: int = _TARGET_BLOCK_BYTES
+) -> int:
+    """Number of bootstrap rows per block so a ``(B, n_units)`` float64 block
+    stays under ``target_bytes``. Always in ``[1, n_bootstrap]``."""
+    if n_units <= 0:
+        return max(1, n_bootstrap)
+    b = target_bytes // (n_units * 8)
+    return int(max(1, min(max(1, n_bootstrap), b)))
+
+
+def iter_weight_blocks(
+    n_bootstrap: int,
+    n_gen: int,
+    weight_type: str,
+    rng: np.random.Generator,
+    *,
+    expand_index: Optional[np.ndarray] = None,
+    block_size: Optional[int] = None,
+) -> Iterator[Tuple[int, np.ndarray]]:
+    """Yield ``(chunk_start, block)`` pairs covering all ``n_bootstrap`` draws.
+
+    ``block`` has shape ``(B, width)`` where ``width = len(expand_index)`` when
+    ``expand_index`` is given, else ``n_gen``. Weights are generated at width
+    ``n_gen`` (unit / cluster / PSU level) and, when ``expand_index`` is given,
+    expanded to unit level via ``block[:, expand_index]`` (cluster->unit or
+    PSU->unit fan-out). The concatenation of all yielded blocks is bit-identical
+    to a single ``generate_bootstrap_weights_batch(n_bootstrap, n_gen, ...)``
+    followed by the same expansion.
+
+    Generation is in-order and stateful on ``rng`` (NumPy fallback) -- the caller
+    must consume the iterator sequentially, which the chunk loop does.
+    """
+    width = n_gen if expand_index is None else int(len(expand_index))
+    if block_size is None:
+        block_size = compute_block_size(width, n_bootstrap)
+    if block_size < 1:
+        raise ValueError(f"block_size must be >= 1, got {block_size}")
+
+    rust_gen = (
+        _rust_bootstrap_weights
+        if (HAS_RUST_BACKEND and _rust_bootstrap_weights is not None)
+        else None
+    )
+    # Draw exactly one base seed (matching the un-chunked Rust wrapper); the
+    # NumPy fallback consumes the rng stream directly per block instead.
+    base_seed = int(rng.integers(0, 2**63 - 1)) if rust_gen is not None else 0
+
+    for chunk_start in range(0, n_bootstrap, block_size):
+        rows = min(block_size, n_bootstrap - chunk_start)
+        if rust_gen is not None:
+            block = rust_gen(rows, n_gen, weight_type, base_seed + chunk_start)
+        else:
+            block = generate_bootstrap_weights_batch_numpy(rows, n_gen, weight_type, rng)
+        if expand_index is not None:
+            block = block[:, expand_index]
+        yield chunk_start, block
+
+
+def iter_survey_multiplier_weight_blocks(
+    n_bootstrap: int,
+    resolved_survey: object,
+    weight_type: str,
+    rng: np.random.Generator,
+    *,
+    block_size: int,
+) -> Tuple[np.ndarray, Iterator[Tuple[int, np.ndarray]]]:
+    """Chunked PSU-level multiplier weights for the survey-aware bootstrap.
+
+    Returns ``(psu_ids, blocks)`` where ``blocks`` yields
+    ``(chunk_start, (B, n_psu))`` PSU-weight blocks covering all draws.
+
+    For UNSTRATIFIED designs (``strata is None``, ``n_psu >= 2``) the
+    ``(n_bootstrap, n_psu)`` matrix is generated one draw-block at a time via
+    :func:`iter_weight_blocks` plus the unstratified FPC scalar -- bit-identical
+    to the unstratified branch of
+    :func:`diff_diff.bootstrap_utils.generate_survey_multiplier_weights_batch`,
+    but the full matrix is never materialized. This is the path taken by
+    ``cluster="unit"`` (each unit its own PSU, ``n_psu == n_units``), the case
+    that otherwise dominates bootstrap memory at large n_units.
+
+    Stratified designs (and the ``n_psu < 2`` degenerate case) fall back to full
+    generation + sliced blocks: per-stratum / lonely-PSU generation is not tiled
+    here, but stratified designs have few PSUs so the full matrix is small.
+    """
+    from diff_diff.bootstrap_utils import generate_survey_multiplier_weights_batch
+
+    if block_size < 1:
+        raise ValueError(f"block_size must be >= 1, got {block_size}")
+
+    psu = getattr(resolved_survey, "psu", None)
+    strata = getattr(resolved_survey, "strata", None)
+    if psu is None:
+        n_psu = len(resolved_survey.weights)  # type: ignore[attr-defined]
+        psu_ids = np.arange(n_psu)
+    else:
+        psu_ids = np.unique(psu)
+        n_psu = len(psu_ids)
+
+    if strata is not None or n_psu < 2:
+        # Stratified or degenerate single-PSU: full generation (small here).
+        weights, psu_ids = generate_survey_multiplier_weights_batch(
+            n_bootstrap, resolved_survey, weight_type, rng
+        )
+
+        def _sliced() -> Iterator[Tuple[int, np.ndarray]]:
+            for chunk_start in range(0, n_bootstrap, block_size):
+                yield chunk_start, weights[chunk_start : chunk_start + block_size]
+
+        return psu_ids, _sliced()
+
+    # Unstratified, n_psu >= 2: tile the generation over draws. Mirror the
+    # unstratified FPC scaling from generate_survey_multiplier_weights_batch.
+    fpc = getattr(resolved_survey, "fpc", None)
+    fpc_scale = 1.0
+    fpc_zero = False
+    if fpc is not None:
+        # psu=None already sets n_psu = len(weights), so n_units_for_fpc == n_psu
+        # on both branches of the original generator.
+        n_units_for_fpc = n_psu
+        if fpc[0] < n_units_for_fpc:
+            raise ValueError(
+                f"FPC ({fpc[0]}) is less than the number of PSUs "
+                f"({n_units_for_fpc}). FPC must be >= number of PSUs."
+            )
+        f = n_units_for_fpc / fpc[0]
+        if f < 1.0:
+            fpc_scale = float(np.sqrt(1.0 - f))
+        else:
+            fpc_zero = True
+
+    def _generated() -> Iterator[Tuple[int, np.ndarray]]:
+        for chunk_start, block in iter_weight_blocks(
+            n_bootstrap, n_psu, weight_type, rng, block_size=block_size
+        ):
+            if fpc_zero:
+                block = np.zeros_like(block)
+            elif fpc_scale != 1.0:
+                block = block * fpc_scale
+            yield chunk_start, block
+
+    return psu_ids, _generated()
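
The module docstring argues that weight generation stays bit-identical under chunking for two different reasons, one per backend. The sketch below mimics both schemes in plain NumPy so the argument can be checked directly. It is an analogy under stated assumptions, not the project's implementation: `rows_absolute_seed` and `rows_in_order` are hypothetical helpers, Gaussian draws stand in for the real weight types, and the per-row seeding only imitates the `base_seed + row_index` rule described for the Rust backend.

```python
import numpy as np


def rows_absolute_seed(rows: int, n: int, base_seed: int) -> np.ndarray:
    """Analog of per-row absolute seeding: row i uses seed base_seed + i."""
    out = np.empty((rows, n))
    for i in range(rows):
        out[i] = np.random.default_rng(base_seed + i).standard_normal(n)
    return out


def rows_in_order(rows: int, n: int, rng: np.random.Generator) -> np.ndarray:
    """Analog of the stateful fallback: draw row-major from one shared stream."""
    return rng.standard_normal((rows, n))


n_bootstrap, n, block, base_seed = 10, 7, 4, 1234

# Absolute seeding: shift the base seed by chunk_start for each block.
seeded_blocks = [
    rows_absolute_seed(min(block, n_bootstrap - s), n, base_seed + s)
    for s in range(0, n_bootstrap, block)
]
seeded_full = rows_absolute_seed(n_bootstrap, n, base_seed)

# In-order stream: reuse one generator across contiguous, sequential blocks.
rng = np.random.default_rng(base_seed)
stream_blocks = [
    rows_in_order(min(block, n_bootstrap - s), n, rng)
    for s in range(0, n_bootstrap, block)
]
stream_full = rows_in_order(n_bootstrap, n, np.random.default_rng(base_seed))

# Both schemes reproduce the un-chunked matrix exactly, not just approximately.
print(np.array_equal(np.vstack(seeded_blocks), seeded_full))
print(np.array_equal(np.vstack(stream_blocks), stream_full))
```

The absolute-seeding scheme would tolerate blocks being generated in any order, whereas the shared-stream scheme relies on the blocks being drawn sequentially from the same generator, which is why the docstring requires the iterator to be consumed in order.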

diff_diff/staggered.py

Lines changed: 10 additions & 6 deletions
@@ -282,11 +282,16 @@ class CallawaySantAnna(
         Recommended: 999 or more for reliable inference.

         .. note:: Memory Usage
-            The bootstrap stores all weights in memory as a (n_bootstrap, n_units)
-            float64 array. For large datasets, this can be significant:
-            - 1K bootstrap × 10K units = ~80 MB
-            - 10K bootstrap × 100K units = ~8 GB
-            Consider reducing n_bootstrap if memory is constrained.
+            Bootstrap multiplier weights are generated and consumed one
+            draw-block at a time (see :mod:`diff_diff.bootstrap_chunking`), so the
+            full ``(n_bootstrap, n_units)`` weight matrix is never materialized.
+            The live weight intermediate is bounded by roughly
+            ``max(~256 MB, 8 * n_units)`` bytes -- a block holds at least one full
+            draw row -- independent of ``n_bootstrap``. Only the small bootstrap
+            *output* arrays (``(n_bootstrap, n_group_time)`` and ``(n_bootstrap,)``
+            per aggregation) stay fully in memory. Stratified survey designs are
+            the current exception (the full PSU-weight matrix is built up front,
+            but PSUs are few).

     bootstrap_weights : str, default="rademacher"
         Type of weights for multiplier bootstrap:
@@ -445,7 +450,6 @@ def __init__(
         pscore_fallback: str = "error",
         vcov_type: str = "hc1",
     ):
-        import warnings

         if control_group not in ["never_treated", "not_yet_treated"]:
             raise ValueError(
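
The memory note in the docstring hunk above bounds the live weight block by roughly `max(~256 MB, 8 * n_units)` bytes. As a worked check of that arithmetic, the sketch below copies the `compute_block_size` logic from the `bootstrap_chunking.py` diff into a standalone script; `live_block_bytes` is a hypothetical helper added only for this illustration, and the unit counts are arbitrary examples.

```python
TARGET_BLOCK_BYTES = 256 * 1024 * 1024  # same value as _TARGET_BLOCK_BYTES in the diff


def compute_block_size(
    n_units: int, n_bootstrap: int, target_bytes: int = TARGET_BLOCK_BYTES
) -> int:
    """Rows per block so a (B, n_units) float64 block stays under target_bytes."""
    if n_units <= 0:
        return max(1, n_bootstrap)
    b = target_bytes // (n_units * 8)
    return int(max(1, min(max(1, n_bootstrap), b)))


def live_block_bytes(n_units: int, n_bootstrap: int) -> int:
    """Hypothetical helper: bytes held by one float64 weight block."""
    return compute_block_size(n_units, n_bootstrap) * n_units * 8


# 10_000 units:     block size 999 (capped at n_bootstrap), 79_920_000 bytes
# 5_000_000 units:  block size 6,                           240_000_000 bytes
# 50_000_000 units: block size 1 (one full draw row),       400_000_000 bytes
for n_units in (10_000, 5_000_000, 50_000_000):
    b = compute_block_size(n_units, 999)
    print(n_units, b, live_block_bytes(n_units, 999))
```

The last case shows where the `8 * n_units` term in the bound comes from: once a single draw row exceeds the 256 MB target, the block size floors at one row and the block is exactly `8 * n_units` bytes.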
