Skip to content

Commit 9b08a91

Browse files
committed
update DoubleMLQTE with basic framework
1 parent 99c84eb commit 9b08a91

File tree

2 files changed

+56
-16
lines changed

2 files changed

+56
-16
lines changed

doubleml/double_ml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(self,
4444
self._is_cluster_data = True
4545
self._dml_data = obj_dml_data
4646

47-
# initialize framework which is set after the fit method is called
47+
# initialize framework which is constructed after the fit method is called
4848
self._framework = None
4949

5050
# initialize learners and parameters which are set model specific

doubleml/irm/qte.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .pq import DoubleMLPQ
1111
from .lpq import DoubleMLLPQ
1212
from .cvar import DoubleMLCVAR
13+
from ..double_ml_framework import concat
1314

1415
from ..utils._estimation import _draw_weights, _default_kde
1516
from ..utils.resampling import DoubleMLResampling
@@ -130,6 +131,9 @@ def __init__(self,
130131
self._is_cluster_data = True
131132
self._check_data(self._dml_data)
132133

134+
# initialize framework which is constructed after the fit method is called
135+
self._framework = None
136+
133137
# initialize and check trimming
134138
self._trimming_rule = trimming_rule
135139
self._trimming_threshold = trimming_threshold
@@ -180,6 +184,28 @@ def n_rep(self):
180184
"""
181185
return self._n_rep
182186

187+
@property
188+
def n_rep_boot(self):
189+
"""
190+
The number of bootstrap replications.
191+
"""
192+
if self._framework is None:
193+
n_rep_boot = None
194+
else:
195+
n_rep_boot = self._framework.n_rep_boot
196+
return n_rep_boot
197+
198+
@property
199+
def boot_method(self):
200+
"""
201+
The method to construct the bootstrap replications.
202+
"""
203+
if self._framework is None:
204+
method = None
205+
else:
206+
method = self._framework.boot_method
207+
return method
208+
183209
@property
184210
def smpls(self):
185211
"""
@@ -191,6 +217,13 @@ def smpls(self):
191217
raise ValueError(err_msg)
192218
return self._smpls
193219

220+
@property
221+
def framework(self):
222+
"""
223+
The corresponding :class:`doubleml.DoubleMLFramework` object.
224+
"""
225+
return self._framework
226+
194227
@property
195228
def quantiles(self):
196229
"""
@@ -247,6 +280,10 @@ def coef(self):
247280
"""
248281
return self._coef
249282

283+
@coef.setter
284+
def coef(self, value):
285+
self._coef = value
286+
250287
@property
251288
def all_coef(self):
252289
"""
@@ -261,6 +298,10 @@ def se(self):
261298
"""
262299
return self._se
263300

301+
@se.setter
302+
def se(self, value):
303+
self._se = value
304+
264305
@property
265306
def t_stat(self):
266307
"""
@@ -277,26 +318,16 @@ def pval(self):
277318
pval = 2 * norm.cdf(-np.abs(self.t_stat))
278319
return pval
279320

280-
@property
281-
def n_rep_boot(self):
282-
"""
283-
The number of bootstrap replications.
284-
"""
285-
return self._n_rep_boot
286-
287-
@property
288-
def boot_coef(self):
289-
"""
290-
Bootstrapped coefficients for the causal parameter(s) after calling :meth:`fit` and :meth:`bootstrap`.
291-
"""
292-
return self._boot_coef
293-
294321
@property
295322
def boot_t_stat(self):
296323
"""
297324
Bootstrapped t-statistics for the causal parameter(s) after calling :meth:`fit` and :meth:`bootstrap`.
298325
"""
299-
return self._boot_t_stat
326+
if self._framework is None:
327+
boot_t_stat = None
328+
else:
329+
boot_t_stat = self._framework.boot_t_stat
330+
return boot_t_stat
300331

301332
@property
302333
def modellist_0(self):
@@ -393,12 +424,18 @@ def fit(self, n_jobs_models=None, n_jobs_cv=None, store_predictions=True, store_
393424
for i_quant in range(self.n_quantiles))
394425

395426
# combine the estimates and scores
427+
framework_list = [None] * self.n_quantiles
428+
396429
for i_quant in range(self.n_quantiles):
397430
self._i_quant = i_quant
398431
# save the parallel fitted models in the right list
399432
self._modellist_0[self._i_quant] = fitted_models[self._i_quant][0]
400433
self._modellist_1[self._i_quant] = fitted_models[self._i_quant][1]
401434

435+
# set up the framework
436+
framework_list[self._i_quant] = self._modellist_1[self._i_quant].framework - \
437+
self._modellist_0[self._i_quant].framework
438+
402439
# treatment Effects
403440
self._all_coef[self._i_quant, :] = self.modellist_1[self._i_quant].all_coef - \
404441
self.modellist_0[self._i_quant].all_coef
@@ -419,6 +456,9 @@ def fit(self, n_jobs_models=None, n_jobs_cv=None, store_predictions=True, store_
419456
# aggregated parameter estimates and standard errors from repeated cross-fitting
420457
self._agg_cross_fit()
421458

459+
# aggregate all frameworks
460+
self._framework = concat(framework_list)
461+
422462
return self
423463

424464
def bootstrap(self, method='normal', n_rep_boot=500):

0 commit comments

Comments
 (0)