Skip to content

Commit 1bbdfa8

Browse files
igerberclaude
andcommitted
Use positional indices for cell membership, add duplicate-index test
Replace label-based index lookup with stable positional row tracking via _row_pos column, so duplicate DataFrame indices cannot break or mis-map cell aggregation. Add regression test verifying identical results with duplicated vs clean indices. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent fe77d84 commit 1bbdfa8

2 files changed

Lines changed: 36 additions & 3 deletions

File tree

diff_diff/prep.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1532,14 +1532,15 @@ def aggregate_survey(
15321532
y_arrays: Dict[str, np.ndarray] = {var: data[var].values.astype(np.float64) for var in all_vars}
15331533

15341534
# --- Per-cell computation ---
1535-
grouped = data.groupby(by_cols, sort=True)
1535+
# Use stable positional indices (safe with duplicate DataFrame indices)
1536+
row_positions = np.arange(n_total)
1537+
grouped = data.assign(_row_pos=row_positions).groupby(by_cols, sort=True)
15361538
rows: List[Dict[str, Any]] = []
15371539
srs_cells: List[str] = []
15381540
zero_var_cells: List[str] = []
15391541

15401542
for cell_key, cell_df in grouped:
1541-
cell_idx = np.array(cell_df.index)
1542-
pos_idx = data.index.get_indexer(cell_idx)
1543+
pos_idx = cell_df["_row_pos"].values
15431544

15441545
# Boolean mask for full-design domain estimation
15451546
cell_mask = np.zeros(n_total, dtype=bool)

tests/test_prep.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2452,6 +2452,38 @@ def test_error_all_missing_grouping_keys(self, design):
24522452
survey_design=design_simple,
24532453
)
24542454

2455+
def test_duplicate_index(self):
2456+
"""Duplicate DataFrame indices do not break aggregation."""
2457+
rng = np.random.RandomState(77)
2458+
n = 40
2459+
data = pd.DataFrame(
2460+
{
2461+
"geo": np.repeat(["A", "B"], n // 2),
2462+
"time": np.tile(np.repeat([0, 1], n // 4), 2),
2463+
"wt": np.ones(n),
2464+
"y": rng.normal(10, 2, n),
2465+
}
2466+
)
2467+
# Create duplicate indices (e.g., from concat without reset_index)
2468+
data.index = list(range(n // 2)) * 2 # 0..19, 0..19
2469+
2470+
design = SurveyDesign(weights="wt")
2471+
panel_dup, _ = aggregate_survey(
2472+
data, by=["geo", "time"], outcomes="y", survey_design=design
2473+
)
2474+
2475+
# Compare against clean-index version
2476+
data_clean = data.reset_index(drop=True)
2477+
panel_clean, _ = aggregate_survey(
2478+
data_clean, by=["geo", "time"], outcomes="y", survey_design=design
2479+
)
2480+
2481+
# Results should be identical
2482+
np.testing.assert_allclose(
2483+
panel_dup["y_mean"].values, panel_clean["y_mean"].values, rtol=1e-12
2484+
)
2485+
np.testing.assert_allclose(panel_dup["y_se"].values, panel_clean["y_se"].values, rtol=1e-12)
2486+
24552487
def test_domain_estimation_preserves_full_design(self):
24562488
"""Full-design domain estimation accounts for PSUs outside the cell.
24572489

0 commit comments

Comments
 (0)