Skip to content

Commit 551ddfd

Browse files
authored
Merge pull request #158 from DoubleML/m-nonlinear-score-mixin
Score mixin classes: `LinearScoreMixin` and `NonLinearScoreMixin`
2 parents 7924273 + 0965951 commit 551ddfd

11 files changed

+679
-145
lines changed

doc/api/api.rst

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,16 @@ Dataset generators
5656
datasets.make_irm_data
5757
datasets.make_iivm_data
5858
datasets.make_plr_turrell2018
59-
datasets.make_pliv_multiway_cluster_CKMS2021
59+
datasets.make_pliv_multiway_cluster_CKMS2021
60+
61+
Score mixin classes for double machine learning models
62+
------------------------------------------------------
63+
64+
.. currentmodule:: doubleml
65+
66+
.. autosummary::
67+
:toctree: generated/
68+
:template: class.rst
69+
70+
double_ml_score_mixins.LinearScoreMixin
71+
double_ml_score_mixins.NonLinearScoreMixin

doubleml/double_ml.py

Lines changed: 103 additions & 110 deletions
Large diffs are not rendered by default.

doubleml/double_ml_iivm.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44

55
from .double_ml import DoubleML
66
from .double_ml_data import DoubleMLData
7+
from .double_ml_score_mixins import LinearScoreMixin
78
from ._utils import _dml_cv_predict, _get_cond_smpls, _dml_tune, _check_finite_predictions
89

910

10-
class DoubleMLIIVM(DoubleML):
11+
class DoubleMLIIVM(LinearScoreMixin, DoubleML):
1112
"""Double machine learning for interactive IV regression models
1213
1314
Parameters
@@ -290,6 +291,8 @@ def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
290291
psi_a, psi_b = self._score_elements(y, z, d,
291292
g_hat0['preds'], g_hat1['preds'], m_hat['preds'],
292293
r_hat0['preds'], r_hat1['preds'], smpls)
294+
psi_elements = {'psi_a': psi_a,
295+
'psi_b': psi_b}
293296
preds = {'predictions': {'ml_g0': g_hat0['preds'],
294297
'ml_g1': g_hat1['preds'],
295298
'ml_m': m_hat['preds'],
@@ -302,7 +305,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
302305
'ml_r1': r_hat1['models']}
303306
}
304307

305-
return psi_a, psi_b, preds
308+
return psi_elements, preds
306309

307310
def _score_elements(self, y, z, d, g_hat0, g_hat1, m_hat, r_hat0, r_hat1, smpls):
308311
# compute residuals

doubleml/double_ml_irm.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44

55
from .double_ml import DoubleML
66
from .double_ml_data import DoubleMLData
7+
from .double_ml_score_mixins import LinearScoreMixin
78
from ._utils import _dml_cv_predict, _get_cond_smpls, _dml_tune, _check_finite_predictions
89

910

10-
class DoubleMLIRM(DoubleML):
11+
class DoubleMLIRM(LinearScoreMixin, DoubleML):
1112
"""Double machine learning for interactive regression models
1213
1314
Parameters
@@ -225,6 +226,8 @@ def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
225226
psi_a, psi_b = self._score_elements(y, d,
226227
g_hat0['preds'], g_hat1['preds'], m_hat['preds'],
227228
smpls)
229+
psi_elements = {'psi_a': psi_a,
230+
'psi_b': psi_b}
228231
preds = {'predictions': {'ml_g0': g_hat0['preds'],
229232
'ml_g1': g_hat1['preds'],
230233
'ml_m': m_hat['preds']},
@@ -233,7 +236,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
233236
'ml_m': m_hat['models']}
234237
}
235238

236-
return psi_a, psi_b, preds
239+
return psi_elements, preds
237240

238241
def _score_elements(self, y, d, g_hat0, g_hat1, m_hat, smpls):
239242
# fraction of treated for ATTE

doubleml/double_ml_pliv.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from .double_ml import DoubleML
1212
from .double_ml_data import DoubleMLData
13+
from .double_ml_score_mixins import LinearScoreMixin
1314
from ._utils import _dml_cv_predict, _dml_tune, _check_finite_predictions
1415

1516

@@ -29,7 +30,7 @@ def wrapper(*args, **kwds):
2930
return wrapper
3031

3132

32-
class DoubleMLPLIV(DoubleML):
33+
class DoubleMLPLIV(LinearScoreMixin, DoubleML):
3334
"""Double machine learning for partially linear IV regression models
3435
3536
Parameters
@@ -310,14 +311,14 @@ def set_ml_nuisance_params(self, learner, treat_var, params):
310311

311312
def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
312313
if self.partialX & (not self.partialZ):
313-
psi_a, psi_b, preds = self._nuisance_est_partial_x(smpls, n_jobs_cv, return_models)
314+
psi_elements, preds = self._nuisance_est_partial_x(smpls, n_jobs_cv, return_models)
314315
elif (not self.partialX) & self.partialZ:
315-
psi_a, psi_b, preds = self._nuisance_est_partial_z(smpls, n_jobs_cv, return_models)
316+
psi_elements, preds = self._nuisance_est_partial_z(smpls, n_jobs_cv, return_models)
316317
else:
317318
assert (self.partialX & self.partialZ)
318-
psi_a, psi_b, preds = self._nuisance_est_partial_xz(smpls, n_jobs_cv, return_models)
319+
psi_elements, preds = self._nuisance_est_partial_xz(smpls, n_jobs_cv, return_models)
319320

320-
return psi_a, psi_b, preds
321+
return psi_elements, preds
321322

322323
def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv,
323324
search_mode, n_iter_randomized_search):
@@ -391,6 +392,8 @@ def _nuisance_est_partial_x(self, smpls, n_jobs_cv, return_models=False):
391392
psi_a, psi_b = self._score_elements(y, z, d,
392393
l_hat['preds'], m_hat['preds'], r_hat['preds'], g_hat['preds'],
393394
smpls)
395+
psi_elements = {'psi_a': psi_a,
396+
'psi_b': psi_b}
394397
preds = {'predictions': {'ml_l': l_hat['preds'],
395398
'ml_m': m_hat['preds'],
396399
'ml_r': r_hat['preds'],
@@ -401,7 +404,7 @@ def _nuisance_est_partial_x(self, smpls, n_jobs_cv, return_models=False):
401404
'ml_g': g_hat['models']}
402405
}
403406

404-
return psi_a, psi_b, preds
407+
return psi_elements, preds
405408

406409
def _score_elements(self, y, z, d, l_hat, m_hat, r_hat, g_hat, smpls):
407410
# compute residuals
@@ -463,10 +466,12 @@ def _nuisance_est_partial_z(self, smpls, n_jobs_cv, return_models=False):
463466
assert callable(self.score)
464467
raise NotImplementedError('Callable score not implemented for DoubleMLPLIV.partialZ.')
465468

469+
psi_elements = {'psi_a': psi_a,
470+
'psi_b': psi_b}
466471
preds = {'predictions': {'ml_r': r_hat['preds']},
467472
'models': {'ml_r': r_hat['models']}}
468473

469-
return psi_a, psi_b, preds
474+
return psi_elements, preds
470475

471476
def _nuisance_est_partial_xz(self, smpls, n_jobs_cv, return_models=False):
472477
x, y = check_X_y(self._dml_data.x, self._dml_data.y,
@@ -507,6 +512,8 @@ def _nuisance_est_partial_xz(self, smpls, n_jobs_cv, return_models=False):
507512
assert callable(self.score)
508513
raise NotImplementedError('Callable score not implemented for DoubleMLPLIV.partialXZ.')
509514

515+
psi_elements = {'psi_a': psi_a,
516+
'psi_b': psi_b}
510517
preds = {'predictions': {'ml_l': l_hat['preds'],
511518
'ml_m': m_hat['preds'],
512519
'ml_r': m_hat_tilde['preds']},
@@ -515,7 +522,7 @@ def _nuisance_est_partial_xz(self, smpls, n_jobs_cv, return_models=False):
515522
'ml_r': m_hat_tilde['models']}
516523
}
517524

518-
return psi_a, psi_b, preds
525+
return psi_elements, preds
519526

520527
# To be removed in version 0.6.0
521528
def tune(self,

doubleml/double_ml_plr.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from .double_ml import DoubleML
1010
from .double_ml_data import DoubleMLData
11+
from .double_ml_score_mixins import LinearScoreMixin
1112
from ._utils import _dml_cv_predict, _dml_tune, _check_finite_predictions
1213

1314

@@ -27,7 +28,7 @@ def wrapper(*args, **kwds):
2728
return wrapper
2829

2930

30-
class DoubleMLPLR(DoubleML):
31+
class DoubleMLPLR(LinearScoreMixin, DoubleML):
3132
"""Double machine learning for partially linear regression models
3233
3334
Parameters
@@ -242,14 +243,16 @@ def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
242243
_check_finite_predictions(g_hat['preds'], self._learner['ml_g'], 'ml_g', smpls)
243244

244245
psi_a, psi_b = self._score_elements(y, d, l_hat['preds'], m_hat['preds'], g_hat['preds'], smpls)
246+
psi_elements = {'psi_a': psi_a,
247+
'psi_b': psi_b}
245248
preds = {'predictions': {'ml_l': l_hat['preds'],
246249
'ml_m': m_hat['preds'],
247250
'ml_g': g_hat['preds']},
248251
'models': {'ml_l': l_hat['models'],
249252
'ml_m': m_hat['models'],
250253
'ml_g': g_hat['models']}}
251254

252-
return psi_a, psi_b, preds
255+
return psi_elements, preds
253256

254257
def _score_elements(self, y, d, l_hat, m_hat, g_hat, smpls):
255258
# compute residuals

0 commit comments

Comments
 (0)