@@ -1,11 +1,31 @@
 import numpy as np
 from sklearn.utils import check_X_y
 from sklearn.utils.multiclass import type_of_target
+from sklearn.base import clone
+
+import warnings
+from functools import wraps
 
 from .double_ml import DoubleML
 from ._utils import _dml_cv_predict, _dml_tune, _check_finite_predictions
 
 
+# To be removed in version 0.6.0
+def changed_api_decorator(f):
+    @wraps(f)
+    def wrapper(*args, **kwds):
+        ml_l_missing = (len(set(kwds).intersection({'obj_dml_data', 'ml_l', 'ml_m'})) + len(args)) < 4
+        if ml_l_missing & ('ml_g' in kwds):
+            warnings.warn(("The required positional argument ml_g was renamed to ml_l. "
+                           "Please adapt the argument name accordingly. "
+                           "ml_g is redirected to ml_l. "
+                           "The redirection will be removed in a future version."),
+                          DeprecationWarning, stacklevel=2)
+            kwds['ml_l'] = kwds.pop('ml_g')
+        return f(*args, **kwds)
+    return wrapper
+
+
 class DoubleMLPLR(DoubleML):
     """Double machine learning for partially linear regression models
 
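# The decorator above keeps the old keyword working for one deprecation cycle:
# if ml_l is absent and ml_g is passed by name, a DeprecationWarning is raised
# and the learner is forwarded as ml_l. A minimal, runnable sketch (the data
# generator make_plr_CCDDHNR2018 ships with the package; the learner choice is
# arbitrary):

import warnings
from sklearn.linear_model import LassoCV
from doubleml import DoubleMLPLR
from doubleml.datasets import make_plr_CCDDHNR2018

data = make_plr_CCDDHNR2018(n_obs=200, dim_x=10)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    dml_plr_old = DoubleMLPLR(data, ml_m=LassoCV(), ml_g=LassoCV())  # deprecated keyword
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
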
@@ -14,9 +34,9 @@ class DoubleMLPLR(DoubleML):
     obj_dml_data : :class:`DoubleMLData` object
         The :class:`DoubleMLData` object providing the data and specifying the variables for the causal model.
 
-    ml_g : estimator implementing ``fit()`` and ``predict()``
+    ml_l : estimator implementing ``fit()`` and ``predict()``
         A machine learner implementing ``fit()`` and ``predict()`` methods (e.g.
-        :py:class:`sklearn.ensemble.RandomForestRegressor`) for the nuisance function :math:`g_0(X) = E[Y|X]`.
+        :py:class:`sklearn.ensemble.RandomForestRegressor`) for the nuisance function :math:`\\ell_0(X) = E[Y|X]`.
 
     ml_m : estimator implementing ``fit()`` and ``predict()``
         A machine learner implementing ``fit()`` and ``predict()`` methods (e.g.
@@ -25,6 +45,13 @@ class DoubleMLPLR(DoubleML):
         ``predict_proba()`` can also be specified. If :py:func:`sklearn.base.is_classifier` returns ``True``,
         ``predict_proba()`` is used otherwise ``predict()``.
 
+    ml_g : estimator implementing ``fit()`` and ``predict()``
+        A machine learner implementing ``fit()`` and ``predict()`` methods (e.g.
+        :py:class:`sklearn.ensemble.RandomForestRegressor`) for the nuisance function
+        :math:`g_0(X) = E[Y - D \\theta_0|X]`.
+        Note: The learner ``ml_g`` is only required for the score ``'IV-type'``. Optionally, it can be specified and
+        estimated for callable scores.
+
     n_folds : int
         Number of folds.
         Default is ``5``.
@@ -35,7 +62,7 @@ class DoubleMLPLR(DoubleML):
 
     score : str or callable
         A str (``'partialling out'`` or ``'IV-type'``) specifying the score function
-        or a callable object / function with signature ``psi_a, psi_b = score(y, d, g_hat, m_hat, smpls)``.
+        or a callable object / function with signature ``psi_a, psi_b = score(y, d, l_hat, m_hat, g_hat, smpls)``.
         Default is ``'partialling out'``.
 
     dml_procedure : str
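# Callable scores now receive the renamed nuisance parts l_hat, m_hat and g_hat
# (g_hat is None unless an ml_g learner is set). A minimal sketch of a custom
# score matching the new signature; it simply reproduces the built-in
# 'partialling out' elements and ignores g_hat and smpls:

import numpy as np

def my_partialling_out_score(y, d, l_hat, m_hat, g_hat, smpls):
    u_hat = y - l_hat  # outcome residual
    v_hat = d - m_hat  # treatment residual
    psi_a = -np.multiply(v_hat, v_hat)
    psi_b = np.multiply(v_hat, u_hat)
    return psi_a, psi_b

# usage: DoubleMLPLR(obj_dml_data, ml_l, ml_m, score=my_partialling_out_score)
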
@@ -81,10 +108,12 @@ class DoubleMLPLR(DoubleML):
     The high-dimensional vector :math:`X = (X_1, \\ldots, X_p)` consists of other confounding covariates,
     and :math:`\\zeta` and :math:`V` are stochastic errors.
     """
+    @changed_api_decorator
     def __init__(self,
                  obj_dml_data,
-                 ml_g,
+                 ml_l,
                  ml_m,
+                 ml_g=None,
                  n_folds=5,
                  n_rep=1,
                  score='partialling out',
@@ -101,22 +130,41 @@ def __init__(self,
 
         self._check_data(self._dml_data)
         self._check_score(self.score)
-        _ = self._check_learner(ml_g, 'ml_g', regressor=True, classifier=False)
+
+        _ = self._check_learner(ml_l, 'ml_l', regressor=True, classifier=False)
         ml_m_is_classifier = self._check_learner(ml_m, 'ml_m', regressor=True, classifier=True)
-        self._learner = {'ml_g': ml_g, 'ml_m': ml_m}
+        self._learner = {'ml_l': ml_l, 'ml_m': ml_m}
+
+        if ml_g is not None:
+            if (isinstance(self.score, str) & (self.score == 'IV-type')) | callable(self.score):
+                _ = self._check_learner(ml_g, 'ml_g', regressor=True, classifier=False)
+                self._learner['ml_g'] = ml_g
+            else:
+                assert (isinstance(self.score, str) & (self.score == 'partialling out'))
+                warnings.warn(('A learner ml_g has been provided for score = "partialling out" but will be ignored. '
+                               'A learner ml_g is not required for estimation.'))
+        elif isinstance(self.score, str) & (self.score == 'IV-type'):
+            warnings.warn(("For score = 'IV-type', learners ml_l and ml_g should be specified. "
+                           "Set ml_g = clone(ml_l)."))
+            self._learner['ml_g'] = clone(ml_l)
+
+        self._predict_method = {'ml_l': 'predict'}
+        if 'ml_g' in self._learner:
+            self._predict_method['ml_g'] = 'predict'
         if ml_m_is_classifier:
-            if obj_dml_data.binary_treats.all():
-                self._predict_method = {'ml_g': 'predict', 'ml_m': 'predict_proba'}
+            if self._dml_data.binary_treats.all():
+                self._predict_method['ml_m'] = 'predict_proba'
             else:
                 raise ValueError(f'The ml_m learner {str(ml_m)} was identified as classifier '
                                  'but at least one treatment variable is not binary with values 0 and 1.')
         else:
-            self._predict_method = {'ml_g': 'predict', 'ml_m': 'predict'}
+            self._predict_method['ml_m'] = 'predict'
 
         self._initialize_ml_nuisance_params()
 
     def _initialize_ml_nuisance_params(self):
-        self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols} for learner in ['ml_g', 'ml_m']}
+        self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols}
+                        for learner in self._learner}
 
     def _check_score(self, score):
         if isinstance(score, str):
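# Constructor behavior under the new API, as a runnable sketch (data and
# learners are arbitrary; make_plr_CCDDHNR2018 ships with the package):

import numpy as np
from sklearn.base import clone
from sklearn.ensemble import RandomForestRegressor
from doubleml import DoubleMLPLR
from doubleml.datasets import make_plr_CCDDHNR2018

np.random.seed(3141)
data = make_plr_CCDDHNR2018(alpha=0.5, n_obs=500, dim_x=20)
learner = RandomForestRegressor(n_estimators=100, max_depth=5)

# 'partialling out' needs only ml_l and ml_m; a supplied ml_g would be ignored with a warning
dml_plr = DoubleMLPLR(data, ml_l=clone(learner), ml_m=clone(learner))
dml_plr.fit()

# 'IV-type' also uses ml_g; omitting it falls back to ml_g = clone(ml_l) with a warning
dml_plr_iv = DoubleMLPLR(data, ml_l=clone(learner), ml_m=clone(learner),
                         ml_g=clone(learner), score='IV-type')
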
@@ -138,16 +186,27 @@ def _check_data(self, obj_dml_data):
                              'To fit a partially linear IV regression model use DoubleMLPLIV instead of DoubleMLPLR.')
         return
 
+    # To be removed in version 0.6.0
+    def set_ml_nuisance_params(self, learner, treat_var, params):
+        if isinstance(self.score, str) & (self.score == 'partialling out') & (learner == 'ml_g'):
+            warnings.warn(("Learner ml_g was renamed to ml_l. "
+                           "Please adapt the argument learner accordingly. "
+                           "The provided parameters are set for ml_l. "
+                           "The redirection will be removed in a future version."),
+                          DeprecationWarning, stacklevel=2)
+            learner = 'ml_l'
+        super(DoubleMLPLR, self).set_ml_nuisance_params(learner, treat_var, params)
+
     def _nuisance_est(self, smpls, n_jobs_cv):
         x, y = check_X_y(self._dml_data.x, self._dml_data.y,
                          force_all_finite=False)
         x, d = check_X_y(x, self._dml_data.d,
                          force_all_finite=False)
 
-        # nuisance g
-        g_hat = _dml_cv_predict(self._learner['ml_g'], x, y, smpls=smpls, n_jobs=n_jobs_cv,
-                                est_params=self._get_params('ml_g'), method=self._predict_method['ml_g'])
-        _check_finite_predictions(g_hat, self._learner['ml_g'], 'ml_g', smpls)
+        # nuisance l
+        l_hat = _dml_cv_predict(self._learner['ml_l'], x, y, smpls=smpls, n_jobs=n_jobs_cv,
+                                est_params=self._get_params('ml_l'), method=self._predict_method['ml_l'])
+        _check_finite_predictions(l_hat, self._learner['ml_l'], 'ml_l', smpls)
 
         # nuisance m
         m_hat = _dml_cv_predict(self._learner['ml_m'], x, d, smpls=smpls, n_jobs=n_jobs_cv,
@@ -163,31 +222,79 @@ def _nuisance_est(self, smpls, n_jobs_cv):
                               'observed to be binary with values 0 and 1. Make sure that for classifiers '
                               'probabilities and not labels are predicted.')
 
-        psi_a, psi_b = self._score_elements(y, d, g_hat, m_hat, smpls)
-        preds = {'ml_g': g_hat,
-                 'ml_m': m_hat}
+        # an estimate of g is obtained for the IV-type score and callable scores
+        g_hat = None
+        if 'ml_g' in self._learner:
+            # get an initial estimate for theta using the partialling out score
+            psi_a = -np.multiply(d - m_hat, d - m_hat)
+            psi_b = np.multiply(d - m_hat, y - l_hat)
+            theta_initial = -np.nanmean(psi_b) / np.nanmean(psi_a)
+            # nuisance g
+            g_hat = _dml_cv_predict(self._learner['ml_g'], x, y - theta_initial*d, smpls=smpls, n_jobs=n_jobs_cv,
+                                    est_params=self._get_params('ml_g'), method=self._predict_method['ml_g'])
+            _check_finite_predictions(g_hat, self._learner['ml_g'], 'ml_g', smpls)
+
+        psi_a, psi_b = self._score_elements(y, d, l_hat, m_hat, g_hat, smpls)
+        preds = {'ml_l': l_hat,
+                 'ml_m': m_hat,
+                 'ml_g': g_hat}
 
         return psi_a, psi_b, preds
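
# Note on the block above: the IV-type target g_0(X) = E[Y - D*theta_0|X] depends
# on the unknown theta_0, so a preliminary theta is first solved from the
# partialling-out moments and ml_g is then cross-fitted on the adjusted outcome:
#
#     theta_initial = mean((d - m_hat) * (y - l_hat)) / mean((d - m_hat)**2)
#     target_g = y - theta_initial * d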
 
-    def _score_elements(self, y, d, g_hat, m_hat, smpls):
+    def _score_elements(self, y, d, l_hat, m_hat, g_hat, smpls):
         # compute residuals
-        u_hat = y - g_hat
+        u_hat = y - l_hat
         v_hat = d - m_hat
-        v_hatd = np.multiply(v_hat, d)
 
         if isinstance(self.score, str):
             if self.score == 'IV-type':
-                psi_a = -v_hatd
+                psi_a = -np.multiply(v_hat, d)
+                psi_b = np.multiply(v_hat, y - g_hat)
             else:
                 assert self.score == 'partialling out'
                 psi_a = -np.multiply(v_hat, v_hat)
-            psi_b = np.multiply(v_hat, u_hat)
+                psi_b = np.multiply(v_hat, u_hat)
         else:
             assert callable(self.score)
-            psi_a, psi_b = self.score(y, d, g_hat, m_hat, smpls)
+            psi_a, psi_b = self.score(y=y, d=d,
+                                      l_hat=l_hat, m_hat=m_hat, g_hat=g_hat,
+                                      smpls=smpls)
 
         return psi_a, psi_b
 
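# Both built-in scores are linear in theta, psi = psi_a * theta + psi_b, so the
# point estimate solves the empirical moment condition via
#
#     theta_hat = -mean(psi_b) / mean(psi_a)
#
# For 'partialling out' this is the residual-on-residual slope
# mean((d - m_hat) * (y - l_hat)) / mean((d - m_hat)**2); for 'IV-type' the
# treatment residual acts as instrument, giving
# mean((d - m_hat) * (y - g_hat)) / mean((d - m_hat) * d).
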
+    # To be removed in version 0.6.0
+    def tune(self,
+             param_grids,
+             tune_on_folds=False,
+             scoring_methods=None,  # if None the estimator's score method is used
+             n_folds_tune=5,
+             search_mode='grid_search',
+             n_iter_randomized_search=100,
+             n_jobs_cv=None,
+             set_as_params=True,
+             return_tune_res=False):
+
+        if isinstance(self.score, str) and (self.score == 'partialling out') and (param_grids is not None) and \
+                ('ml_g' in param_grids) and ('ml_l' not in param_grids):
+            warnings.warn(("Learner ml_g was renamed to ml_l. "
+                           "Please adapt the key of param_grids accordingly. "
+                           "The provided param_grids for ml_g are set for ml_l. "
+                           "The redirection will be removed in a future version."),
+                          DeprecationWarning, stacklevel=2)
+            param_grids['ml_l'] = param_grids.pop('ml_g')
+
+        if isinstance(self.score, str) and (self.score == 'partialling out') and (scoring_methods is not None) and \
+                ('ml_g' in scoring_methods) and ('ml_l' not in scoring_methods):
+            warnings.warn(("Learner ml_g was renamed to ml_l. "
+                           "Please adapt the key of scoring_methods accordingly. "
+                           "The provided scoring_methods for ml_g are set for ml_l. "
+                           "The redirection will be removed in a future version."),
+                          DeprecationWarning, stacklevel=2)
+            scoring_methods['ml_l'] = scoring_methods.pop('ml_g')
+
+        super(DoubleMLPLR, self).tune(param_grids, tune_on_folds, scoring_methods, n_folds_tune, search_mode,
+                                      n_iter_randomized_search, n_jobs_cv, set_as_params, return_tune_res)
+
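# Tuning under the new naming, continuing the sketch from the constructor
# example above (RandomForest learners, so the grids below are valid; for
# score='IV-type' an 'ml_g' grid would be needed as well):

param_grids = {'ml_l': {'n_estimators': [50, 100]},
               'ml_m': {'n_estimators': [50, 100]}}
dml_plr.tune(param_grids)
dml_plr.fit()

# a grid still keyed by the old name 'ml_g' is redirected to 'ml_l' for
# score='partialling out', with a DeprecationWarning, as implemented above
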
     def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv,
                          search_mode, n_iter_randomized_search):
         x, y = check_X_y(self._dml_data.x, self._dml_data.y,
@@ -196,25 +303,48 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
                          force_all_finite=False)
 
         if scoring_methods is None:
-            scoring_methods = {'ml_g': None,
-                               'ml_m': None}
+            scoring_methods = {'ml_l': None,
+                               'ml_m': None,
+                               'ml_g': None}
 
         train_inds = [train_index for (train_index, _) in smpls]
-        g_tune_res = _dml_tune(y, x, train_inds,
-                               self._learner['ml_g'], param_grids['ml_g'], scoring_methods['ml_g'],
+        l_tune_res = _dml_tune(y, x, train_inds,
+                               self._learner['ml_l'], param_grids['ml_l'], scoring_methods['ml_l'],
                                n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
         m_tune_res = _dml_tune(d, x, train_inds,
                                self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'],
                                n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
 
-        g_best_params = [xx.best_params_ for xx in g_tune_res]
+        l_best_params = [xx.best_params_ for xx in l_tune_res]
         m_best_params = [xx.best_params_ for xx in m_tune_res]
 
-        params = {'ml_g': g_best_params,
-                  'ml_m': m_best_params}
-
-        tune_res = {'g_tune': g_tune_res,
-                    'm_tune': m_tune_res}
+        # an ML model for g is obtained for the IV-type score and callable scores
+        if 'ml_g' in self._learner:
+            # construct an initial theta estimate from the tuned models using the partialling out score
+            l_hat = np.full_like(y, np.nan)
+            m_hat = np.full_like(d, np.nan)
+            for idx, (train_index, _) in enumerate(smpls):
+                l_hat[train_index] = l_tune_res[idx].predict(x[train_index, :])
+                m_hat[train_index] = m_tune_res[idx].predict(x[train_index, :])
+            psi_a = -np.multiply(d - m_hat, d - m_hat)
+            psi_b = np.multiply(d - m_hat, y - l_hat)
+            theta_initial = -np.nanmean(psi_b) / np.nanmean(psi_a)
+            g_tune_res = _dml_tune(y - theta_initial*d, x, train_inds,
+                                   self._learner['ml_g'], param_grids['ml_g'], scoring_methods['ml_g'],
+                                   n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
+
+            g_best_params = [xx.best_params_ for xx in g_tune_res]
+            params = {'ml_l': l_best_params,
+                      'ml_m': m_best_params,
+                      'ml_g': g_best_params}
+            tune_res = {'l_tune': l_tune_res,
+                        'm_tune': m_tune_res,
+                        'g_tune': g_tune_res}
+        else:
+            params = {'ml_l': l_best_params,
+                      'ml_m': m_best_params}
+            tune_res = {'l_tune': l_tune_res,
+                        'm_tune': m_tune_res}
 
         res = {'params': params,
                'tune_res': tune_res}