Skip to content

Commit 3870ac7

Browse files
authored
Merge pull request #182 from DoubleML/s-add-rmse
Add RMSEs and targets
2 parents 32298ee + 10b180e commit 3870ac7

15 files changed

+349
-34
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ share/python-wheels/
2525
*.egg-info/
2626
.installed.cfg
2727
*.egg
28+
*.vscode
2829
MANIFEST
2930
*.idea
3031
*.vscode

doubleml/_utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def _dml_cv_predict(estimator, x, y, smpls=None,
111111
res['preds'] = preds[:, 1]
112112
else:
113113
res['preds'] = preds
114+
res['targets'] = y
114115
else:
115116
if not smpls_is_partition:
116117
assert not fold_specific_target, 'combination of fold-specific y and no cross-fitting not implemented yet'
@@ -150,7 +151,9 @@ def _dml_cv_predict(estimator, x, y, smpls=None,
150151
for idx, (train_index, test_index) in enumerate(smpls))
151152

152153
preds = np.full(n_obs, np.nan)
154+
targets = np.full(n_obs, np.nan)
153155
train_preds = list()
156+
train_targets = list()
154157
for idx, (train_index, test_index) in enumerate(smpls):
155158
assert idx == fitted_models[idx][1]
156159
pred_fun = getattr(fitted_models[idx][0], method)
@@ -159,12 +162,21 @@ def _dml_cv_predict(estimator, x, y, smpls=None,
159162
else:
160163
preds[test_index] = pred_fun(x[test_index, :])
161164

165+
if fold_specific_target:
166+
# targets not available for fold specific target
167+
targets = None
168+
else:
169+
targets[test_index] = y[test_index]
170+
162171
if return_train_preds:
163172
train_preds.append(pred_fun(x[train_index, :]))
173+
train_targets.append(y[train_index])
164174

165175
res['preds'] = preds
176+
res['targets'] = targets
166177
if return_train_preds:
167178
res['train_preds'] = train_preds
179+
res['train_targets'] = train_targets
168180
if return_models:
169181
fold_ids = [xx[1] for xx in fitted_models]
170182
if not np.alltrue(fold_ids == np.arange(len(smpls))):
@@ -222,4 +234,4 @@ def _check_is_propensity(preds, learner, learner_name, smpls, eps=1e-12):
222234
if any((preds[test_indices] < eps) | (preds[test_indices] > 1 - eps)):
223235
warnings.warn(f'Propensity predictions from learner {str(learner)} for'
224236
f' {learner_name} are close to zero or one (eps={eps}).')
225-
return
237+
return

doubleml/double_ml.py

Lines changed: 114 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import warnings
44

55
from sklearn.base import is_regressor, is_classifier
6+
from sklearn.metrics import mean_squared_error
67

78
from scipy.stats import norm
89

@@ -45,8 +46,10 @@ def __init__(self,
4546
self._learner = None
4647
self._params = None
4748

48-
# initialize predictions to None which are only stored if method fit is called with store_predictions=True
49+
        # initialize predictions and targets to None; they are only stored if method fit is called with store_predictions=True
4950
self._predictions = None
51+
self._nuisance_targets = None
52+
self._rmses = None
5053

5154
# initialize models to None which are only stored if method fit is called with store_models=True
5255
self._models = None
@@ -129,6 +132,11 @@ def __str__(self):
129132
learner_info = ''
130133
for key, value in self.learner.items():
131134
learner_info += f'Learner {key}: {str(value)}\n'
135+
if self.rmses is not None:
136+
learner_info += 'Out-of-sample Performance:\n'
137+
for learner in self.params_names:
138+
learner_info += f'Learner {learner} RMSE: {self.rmses[learner]}\n'
139+
132140
if self._is_cluster_data:
133141
resampling_info = f'No. folds per cluster: {self._n_folds_per_cluster}\n' \
134142
f'No. folds: {self.n_folds}\n' \
@@ -231,6 +239,20 @@ def predictions(self):
231239
"""
232240
return self._predictions
233241

242+
@property
243+
def nuisance_targets(self):
244+
"""
245+
The outcome of the nuisance models.
246+
"""
247+
return self._nuisance_targets
248+
249+
@property
250+
def rmses(self):
251+
"""
252+
The root-mean-squared-errors of the nuisance models.
253+
"""
254+
return self._rmses
255+
234256
@property
235257
def models(self):
236258
"""
@@ -434,7 +456,7 @@ def __psi_deriv(self):
434456
def __all_se(self):
435457
return self._all_se[self._i_treat, self._i_rep]
436458

437-
def fit(self, n_jobs_cv=None, store_predictions=False, store_models=False):
459+
def fit(self, n_jobs_cv=None, store_predictions=True, store_models=False):
438460
"""
439461
Estimate DoubleML models.
440462
@@ -471,8 +493,11 @@ def fit(self, n_jobs_cv=None, store_predictions=False, store_models=False):
471493
raise TypeError('store_models must be True or False. '
472494
f'Got {str(store_models)}.')
473495

496+
# initialize rmse arrays for nuisance functions evaluation
497+
self._initialize_rmses()
498+
474499
if store_predictions:
475-
self._initialize_predictions()
500+
self._initialize_predictions_and_targets()
476501

477502
if store_models:
478503
self._initialize_models()
@@ -491,8 +516,10 @@ def fit(self, n_jobs_cv=None, store_predictions=False, store_models=False):
491516

492517
self._set_score_elements(score_elements, self._i_rep, self._i_treat)
493518

519+
# calculate rmses and store predictions and targets of the nuisance models
520+
self._calc_rmses(preds['predictions'], preds['targets'])
494521
if store_predictions:
495-
self._store_predictions(preds['predictions'])
522+
self._store_predictions_and_targets(preds['predictions'], preds['targets'])
496523
if store_models:
497524
self._store_models(preds['models'])
498525

@@ -990,22 +1017,103 @@ def _initialize_boot_arrays(self, n_rep_boot):
9901017
boot_t_stat = np.full((self._dml_data.n_coefs, n_rep_boot * self.n_rep), np.nan)
9911018
return n_rep_boot, boot_coef, boot_t_stat
9921019

993-
def _initialize_predictions(self):
1020+
def _initialize_predictions_and_targets(self):
9941021
self._predictions = {learner: np.full((self._dml_data.n_obs, self.n_rep, self._dml_data.n_coefs), np.nan)
9951022
for learner in self.params_names}
1023+
self._nuisance_targets = {learner: np.full((self._dml_data.n_obs, self.n_rep, self._dml_data.n_coefs), np.nan)
1024+
for learner in self.params_names}
1025+
1026+
def _initialize_rmses(self):
1027+
self._rmses = {learner: np.full((self.n_rep, self._dml_data.n_coefs), np.nan)
1028+
for learner in self.params_names}
9961029

9971030
def _initialize_models(self):
9981031
self._models = {learner: {treat_var: [None] * self.n_rep for treat_var in self._dml_data.d_cols}
9991032
for learner in self.params_names}
10001033

1001-
def _store_predictions(self, preds):
1034+
def _store_predictions_and_targets(self, preds, targets):
10021035
for learner in self.params_names:
10031036
self._predictions[learner][:, self._i_rep, self._i_treat] = preds[learner]
1037+
self._nuisance_targets[learner][:, self._i_rep, self._i_treat] = targets[learner]
1038+
1039+
def _calc_rmses(self, preds, targets):
1040+
for learner in self.params_names:
1041+
if targets[learner] is None:
1042+
self._rmses[learner][self._i_rep, self._i_treat] = np.nan
1043+
else:
1044+
sq_error = np.power(targets[learner] - preds[learner], 2)
1045+
self._rmses[learner][self._i_rep, self._i_treat] = np.sqrt(np.mean(sq_error, 0))
10041046

10051047
def _store_models(self, models):
10061048
for learner in self.params_names:
10071049
self._models[learner][self._dml_data.d_cols[self._i_treat]][self._i_rep] = models[learner]
10081050

1051+
def evaluate_learners(self, learners=None, metric=mean_squared_error):
1052+
"""
1053+
Evaluate fitted learners for DoubleML models on cross-validated predictions.
1054+
1055+
Parameters
1056+
----------
1057+
learners : list
1058+
A list of strings which correspond to the nuisance functions of the model.
1059+
1060+
metric : callable
1061+
A callable function with inputs ``y_pred`` and ``y_true`` of shape ``(1, n)``,
1062+
where ``n`` specifies the number of observations.
1063+
            Default is the mean squared error.
1064+
1065+
Returns
1066+
-------
1067+
dist : dict
1068+
A dictionary containing the evaluated metric for each learner.
1069+
1070+
Examples
1071+
--------
1072+
>>> import numpy as np
1073+
>>> import doubleml as dml
1074+
>>> from sklearn.metrics import mean_absolute_error
1075+
>>> from doubleml.datasets import make_irm_data
1076+
>>> from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
1077+
>>> np.random.seed(3141)
1078+
>>> ml_g = RandomForestRegressor(n_estimators=100, max_features=20, max_depth=5, min_samples_leaf=2)
1079+
>>> ml_m = RandomForestClassifier(n_estimators=100, max_features=20, max_depth=5, min_samples_leaf=2)
1080+
>>> data = make_irm_data(theta=0.5, n_obs=500, dim_x=20, return_type='DataFrame')
1081+
>>> obj_dml_data = dml.DoubleMLData(data, 'y', 'd')
1082+
>>> dml_irm_obj = dml.DoubleMLIRM(obj_dml_data, ml_g, ml_m)
1083+
>>> dml_irm_obj.fit()
1084+
>>> dml_irm_obj.evaluate_learners(metric=mean_absolute_error)
1085+
{'ml_g0': array([[1.13318973]]),
1086+
'ml_g1': array([[0.91659939]]),
1087+
'ml_m': array([[0.36350912]])}
1088+
"""
1089+
# if no learners are provided try to evaluate all learners
1090+
if learners is None:
1091+
learners = self.params_names
1092+
1093+
# check metric
1094+
if not callable(metric):
1095+
raise TypeError('metric should be a callable. '
1096+
'%r was passed.' % metric)
1097+
1098+
if all(learner in self.params_names for learner in learners):
1099+
if self.nuisance_targets is None:
1100+
raise ValueError('Apply fit() before evaluate_learners().')
1101+
else:
1102+
dist = {learner: np.full((self.n_rep, self._dml_data.n_coefs), np.nan)
1103+
for learner in learners}
1104+
for learner in learners:
1105+
for rep in range(self.n_rep):
1106+
for coef_idx in range(self._dml_data.n_coefs):
1107+
res = metric(y_pred=self.predictions[learner][:, rep, coef_idx].reshape(1, -1),
1108+
y_true=self.nuisance_targets[learner][:, rep, coef_idx].reshape(1, -1))
1109+
if not np.isfinite(res):
1110+
raise ValueError(f'Evaluation from learner {str(learner)} is not finite.')
1111+
dist[learner][rep, coef_idx] = res
1112+
return dist
1113+
else:
1114+
raise ValueError(f'The learners have to be a subset of {str(self.params_names)}. '
1115+
f'Learners {str(learners)} provided.')
1116+
10091117
def draw_sample_splitting(self):
10101118
"""
10111119
Draw sample splitting for DoubleML models.

doubleml/double_ml_iivm.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class DoubleMLIIVM(LinearScoreMixin, DoubleML):
6161
6262
trimming_threshold : float
6363
The threshold used for trimming.
64-
Default is ``1e-12``.
64+
Default is ``1e-2``.
6565
6666
draw_sample_splitting : bool
6767
Indicates whether the sample splitting should be drawn during initialization of the object.
@@ -129,7 +129,7 @@ def __init__(self,
129129
subgroups=None,
130130
dml_procedure='dml2',
131131
trimming_rule='truncate',
132-
trimming_threshold=1e-12,
132+
trimming_threshold=1e-2,
133133
draw_sample_splitting=True,
134134
apply_cross_fitting=True):
135135
super().__init__(obj_dml_data,
@@ -282,15 +282,15 @@ def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
282282
est_params=self._get_params('ml_r0'), method=self._predict_method['ml_r'],
283283
return_models=return_models)
284284
else:
285-
r_hat0 = {'preds': np.zeros_like(d), 'models': None}
285+
r_hat0 = {'preds': np.zeros_like(d), 'targets': np.zeros_like(d), 'models': None}
286286
_check_finite_predictions(r_hat0['preds'], self._learner['ml_r'], 'ml_r', smpls)
287287

288288
if self.subgroups['never_takers']:
289289
r_hat1 = _dml_cv_predict(self._learner['ml_r'], x, d, smpls=smpls_z1, n_jobs=n_jobs_cv,
290290
est_params=self._get_params('ml_r1'), method=self._predict_method['ml_r'],
291291
return_models=return_models)
292292
else:
293-
r_hat1 = {'preds': np.ones_like(d), 'models': None}
293+
r_hat1 = {'preds': np.ones_like(d), 'targets': np.ones_like(d), 'models': None}
294294
_check_finite_predictions(r_hat1['preds'], self._learner['ml_r'], 'ml_r', smpls)
295295

296296
psi_a, psi_b = self._score_elements(y, z, d,
@@ -303,6 +303,11 @@ def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
303303
'ml_m': m_hat['preds'],
304304
'ml_r0': r_hat0['preds'],
305305
'ml_r1': r_hat1['preds']},
306+
'targets': {'ml_g0': g_hat0['targets'],
307+
'ml_g1': g_hat1['targets'],
308+
'ml_m': m_hat['targets'],
309+
'ml_r0': r_hat0['targets'],
310+
'ml_r1': r_hat1['targets']},
306311
'models': {'ml_g0': g_hat0['models'],
307312
'ml_g1': g_hat1['models'],
308313
'ml_m': m_hat['models'],

doubleml/double_ml_irm.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class DoubleMLIRM(LinearScoreMixin, DoubleML):
5555
5656
trimming_threshold : float
5757
The threshold used for trimming.
58-
Default is ``1e-12``.
58+
Default is ``1e-2``.
5959
6060
draw_sample_splitting : bool
6161
Indicates whether the sample splitting should be drawn during initialization of the object.
@@ -114,7 +114,7 @@ def __init__(self,
114114
score='ATE',
115115
dml_procedure='dml2',
116116
trimming_rule='truncate',
117-
trimming_threshold=1e-12,
117+
trimming_threshold=1e-2,
118118
draw_sample_splitting=True,
119119
apply_cross_fitting=True):
120120
super().__init__(obj_dml_data,
@@ -206,7 +206,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
206206
'observed to be binary with values 0 and 1. Make sure that for classifiers '
207207
'probabilities and not labels are predicted.')
208208

209-
g_hat1 = {'preds': None, 'models': None}
209+
g_hat1 = {'preds': None, 'targets': None, 'models': None}
210210
if (self.score == 'ATE') | callable(self.score):
211211
g_hat1 = _dml_cv_predict(self._learner['ml_g'], x, y, smpls=smpls_d1, n_jobs=n_jobs_cv,
212212
est_params=self._get_params('ml_g1'), method=self._predict_method['ml_g'],
@@ -237,6 +237,9 @@ def _nuisance_est(self, smpls, n_jobs_cv, return_models=False):
237237
preds = {'predictions': {'ml_g0': g_hat0['preds'],
238238
'ml_g1': g_hat1['preds'],
239239
'ml_m': m_hat['preds']},
240+
'targets': {'ml_g0': g_hat0['targets'],
241+
'ml_g1': g_hat1['targets'],
242+
'ml_m': m_hat['targets']},
240243
'models': {'ml_g0': g_hat0['models'],
241244
'ml_g1': g_hat1['models'],
242245
'ml_m': m_hat['models']}

0 commit comments

Comments
 (0)