10
10
from .pq import DoubleMLPQ
11
11
from .lpq import DoubleMLLPQ
12
12
from .cvar import DoubleMLCVAR
13
+ from ..double_ml_framework import concat
13
14
14
15
from ..utils ._estimation import _draw_weights , _default_kde
15
16
from ..utils .resampling import DoubleMLResampling
@@ -130,6 +131,9 @@ def __init__(self,
130
131
self ._is_cluster_data = True
131
132
self ._check_data (self ._dml_data )
132
133
134
+ # initialize framework which is constructed after the fit method is called
135
+ self ._framework = None
136
+
133
137
# initialize and check trimming
134
138
self ._trimming_rule = trimming_rule
135
139
self ._trimming_threshold = trimming_threshold
@@ -180,6 +184,28 @@ def n_rep(self):
180
184
"""
181
185
return self ._n_rep
182
186
187
+ @property
188
+ def n_rep_boot (self ):
189
+ """
190
+ The number of bootstrap replications.
191
+ """
192
+ if self ._framework is None :
193
+ n_rep_boot = None
194
+ else :
195
+ n_rep_boot = self ._framework .n_rep_boot
196
+ return n_rep_boot
197
+
198
+ @property
199
+ def boot_method (self ):
200
+ """
201
+ The method to construct the bootstrap replications.
202
+ """
203
+ if self ._framework is None :
204
+ method = None
205
+ else :
206
+ method = self ._framework .boot_method
207
+ return method
208
+
183
209
@property
184
210
def smpls (self ):
185
211
"""
@@ -191,6 +217,13 @@ def smpls(self):
191
217
raise ValueError (err_msg )
192
218
return self ._smpls
193
219
220
+ @property
221
+ def framework (self ):
222
+ """
223
+ The corresponding :class:`doubleml.DoubleMLFramework` object.
224
+ """
225
+ return self ._framework
226
+
194
227
@property
195
228
def quantiles (self ):
196
229
"""
@@ -247,6 +280,10 @@ def coef(self):
247
280
"""
248
281
return self ._coef
249
282
283
+ @coef .setter
284
+ def coef (self , value ):
285
+ self ._coef = value
286
+
250
287
@property
251
288
def all_coef (self ):
252
289
"""
@@ -261,6 +298,10 @@ def se(self):
261
298
"""
262
299
return self ._se
263
300
301
+ @se .setter
302
+ def se (self , value ):
303
+ self ._se = value
304
+
264
305
@property
265
306
def t_stat (self ):
266
307
"""
@@ -277,26 +318,16 @@ def pval(self):
277
318
pval = 2 * norm .cdf (- np .abs (self .t_stat ))
278
319
return pval
279
320
280
- @property
281
- def n_rep_boot (self ):
282
- """
283
- The number of bootstrap replications.
284
- """
285
- return self ._n_rep_boot
286
-
287
- @property
288
- def boot_coef (self ):
289
- """
290
- Bootstrapped coefficients for the causal parameter(s) after calling :meth:`fit` and :meth:`bootstrap`.
291
- """
292
- return self ._boot_coef
293
-
294
321
@property
295
322
def boot_t_stat (self ):
296
323
"""
297
324
Bootstrapped t-statistics for the causal parameter(s) after calling :meth:`fit` and :meth:`bootstrap`.
298
325
"""
299
- return self ._boot_t_stat
326
+ if self ._framework is None :
327
+ boot_t_stat = None
328
+ else :
329
+ boot_t_stat = self ._framework .boot_t_stat
330
+ return boot_t_stat
300
331
301
332
@property
302
333
def modellist_0 (self ):
@@ -393,12 +424,18 @@ def fit(self, n_jobs_models=None, n_jobs_cv=None, store_predictions=True, store_
393
424
for i_quant in range (self .n_quantiles ))
394
425
395
426
# combine the estimates and scores
427
+ framework_list = [None ] * self .n_quantiles
428
+
396
429
for i_quant in range (self .n_quantiles ):
397
430
self ._i_quant = i_quant
398
431
# save the parallel fitted models in the right list
399
432
self ._modellist_0 [self ._i_quant ] = fitted_models [self ._i_quant ][0 ]
400
433
self ._modellist_1 [self ._i_quant ] = fitted_models [self ._i_quant ][1 ]
401
434
435
+ # set up the framework
436
+ framework_list [self ._i_quant ] = self ._modellist_1 [self ._i_quant ].framework - \
437
+ self ._modellist_0 [self ._i_quant ].framework
438
+
402
439
# treatment Effects
403
440
self ._all_coef [self ._i_quant , :] = self .modellist_1 [self ._i_quant ].all_coef - \
404
441
self .modellist_0 [self ._i_quant ].all_coef
@@ -419,6 +456,9 @@ def fit(self, n_jobs_models=None, n_jobs_cv=None, store_predictions=True, store_
419
456
# aggregated parameter estimates and standard errors from repeated cross-fitting
420
457
self ._agg_cross_fit ()
421
458
459
+ # aggregate all frameworks
460
+ self ._framework = concat (framework_list )
461
+
422
462
return self
423
463
424
464
def bootstrap (self , method = 'normal' , n_rep_boot = 500 ):
0 commit comments