Skip to content

Commit ed50c11

Browse files
committed
docs: clarify DiD predict support contract
1 parent 7a7debf commit ed50c11

3 files changed

Lines changed: 37 additions & 7 deletions

File tree

diff_diff/estimators.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -968,26 +968,39 @@ def _validate_data(
968968

969969
def predict(self, data: pd.DataFrame) -> np.ndarray:
970970
"""
971-
Predict outcomes using fitted model.
971+
Predict outcomes using the fitted model.
972+
973+
Out-of-sample prediction is intentionally unsupported pending a broader
974+
post-estimation design for estimator result objects. For fitted
975+
training-data predictions, use ``results_.fitted_values`` after
976+
:meth:`fit`.
972977
973978
Parameters
974979
----------
975980
data : pd.DataFrame
976-
DataFrame with same structure as training data.
981+
Candidate prediction data. Currently unused because out-of-sample
982+
prediction is unsupported.
977983
978984
Returns
979985
-------
980986
np.ndarray
981987
Predicted values.
988+
989+
Raises
990+
------
991+
RuntimeError
992+
If called before :meth:`fit`.
993+
NotImplementedError
994+
Always raised after fitting until the broader post-estimation
995+
prediction contract is designed.
982996
"""
983997
if not self.is_fitted_:
984998
raise RuntimeError("Model must be fitted before calling predict()")
985999

986-
# This is a placeholder - would need to store column names
987-
# for full implementation
9881000
raise NotImplementedError(
989-
"predict() is not yet implemented. "
990-
"Use results_.fitted_values for training data predictions."
1001+
"out-of-sample predict() is unsupported pending a broader "
1002+
"post-estimation design. Use results_.fitted_values for fitted "
1003+
"training-data predictions."
9911004
)
9921005

9931006
def get_params(self) -> Dict[str, Any]:

docs/api/estimators.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ DifferenceInDifferences (alias: ``DiD``)
3030

3131
Basic 2x2 DiD estimator.
3232

33+
``DifferenceInDifferences.predict()`` is present for sklearn-like
34+
discoverability, but out-of-sample prediction is not currently supported. Use
35+
``results_.fitted_values`` for fitted training-data predictions until a broader
36+
post-estimation result-object contract is designed.
37+
3338
.. autoclass:: diff_diff.DifferenceInDifferences
3439
:no-index:
3540
:members:
@@ -42,6 +47,7 @@ Basic 2x2 DiD estimator.
4247
.. autosummary::
4348

4449
~DifferenceInDifferences.fit
50+
~DifferenceInDifferences.predict
4551
~DifferenceInDifferences.get_params
4652
~DifferenceInDifferences.set_params
4753

@@ -84,4 +90,3 @@ Synthetic control combined with DiD (Arkhangelsky et al. 2021).
8490
:undoc-members:
8591
:show-inheritance:
8692
:inherited-members:
87-

tests/test_methodology_did.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,6 +1550,18 @@ def test_residuals_and_fitted_values(self):
15501550
assert np.allclose(reconstructed, original), \
15511551
"Residuals + fitted should equal original outcome"
15521552

1553+
def test_predict_contract_points_to_fitted_values(self):
1554+
"""predict() is intentionally unsupported until post-estimation is designed."""
1555+
data, _ = generate_hand_calculable_data()
1556+
1557+
did = DifferenceInDifferences()
1558+
did.fit(data, outcome='outcome', treatment='treated', time='post')
1559+
1560+
with pytest.raises(
1561+
NotImplementedError,
1562+
match="out-of-sample.*post-estimation.*results_\\.fitted_values",
1563+
):
1564+
did.predict(data)
15531565

15541566
# =============================================================================
15551567
# Multi-absorb (N>1 FE) iterative alternating-projection demeaning

0 commit comments

Comments
 (0)