Skip to content

Commit 8d34a73

Browse files
committed
update trimming for external predictions
1 parent b7127c2 commit 8d34a73

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

doubleml/irm/cvar.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
import copy
32
from sklearn.base import clone
43
from sklearn.utils import check_X_y
54
from sklearn.model_selection import StratifiedKFold, train_test_split
@@ -214,7 +213,10 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
214213
'targets': np.full(shape=self._dml_data.n_obs, fill_value=np.nan),
215214
'preds': np.full(shape=self._dml_data.n_obs, fill_value=np.nan)
216215
}
217-
m_hat = copy.deepcopy(g_hat)
216+
m_hat = {'models': None,
217+
'targets': np.full(shape=self._dml_data.n_obs, fill_value=np.nan),
218+
'preds': np.full(shape=self._dml_data.n_obs, fill_value=np.nan)
219+
}
218220

219221
# initialize models
220222
fitted_models = {}

doubleml/irm/iivm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,8 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
320320
return_models=return_models)
321321
_check_finite_predictions(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls)
322322
_check_is_propensity(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls, eps=1e-12)
323-
m_hat['preds'] = _trimm(m_hat['preds'], self.trimming_rule, self.trimming_threshold)
323+
# also trimm external predictions
324+
m_hat['preds'] = _trimm(m_hat['preds'], self.trimming_rule, self.trimming_threshold)
324325

325326
# nuisance r
326327
r0 = external_predictions['ml_r0'] is not None

doubleml/irm/irm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,8 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
315315
return_models=return_models)
316316
_check_finite_predictions(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls)
317317
_check_is_propensity(m_hat['preds'], self._learner['ml_m'], 'ml_m', smpls, eps=1e-12)
318-
m_hat['preds'] = _trimm(m_hat['preds'], self.trimming_rule, self.trimming_threshold)
318+
# also trimm external predictions
319+
m_hat['preds'] = _trimm(m_hat['preds'], self.trimming_rule, self.trimming_threshold)
319320

320321
psi_a, psi_b = self._score_elements(y, d,
321322
g_hat0['preds'], g_hat1['preds'], m_hat['preds'],

0 commit comments

Comments
 (0)