
Commit 0dcd89c

Merge pull request #169 from DoubleML/s-cate
Add GATE and CATE for IRM models
2 parents 4b7698e + a7263a4

File tree

10 files changed: +618 -9 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -26,3 +26,4 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
+*.idea

doubleml/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -5,12 +5,14 @@
 from .double_ml_irm import DoubleMLIRM
 from .double_ml_iivm import DoubleMLIIVM
 from .double_ml_data import DoubleMLData, DoubleMLClusterData
+from .double_ml_blp import DoubleMLBLP

 __all__ = ['DoubleMLPLR',
            'DoubleMLPLIV',
            'DoubleMLIRM',
            'DoubleMLIIVM',
            'DoubleMLData',
-           'DoubleMLClusterData']
+           'DoubleMLClusterData',
+           'DoubleMLBLP']

 __version__ = get_distribution('doubleml').version

doubleml/double_ml_blp.py

Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
import statsmodels.api as sm
import numpy as np
import pandas as pd

from scipy.stats import norm
from scipy.linalg import sqrtm


class DoubleMLBLP:
    """Best linear predictor (BLP) for DoubleML with orthogonal signals.

    Parameters
    ----------
    orth_signal : :class:`numpy.array`
        The orthogonal signal to be predicted. Has to be of shape ``(n_obs,)``,
        where ``n_obs`` is the number of observations.

    basis : :class:`pandas.DataFrame`
        The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
        where ``n_obs`` is the number of observations and ``d`` is the number of predictors.

    is_gate : bool
        Indicates whether the basis is constructed for GATEs (dummy-basis).
        Default is ``False``.
    """

    def __init__(self,
                 orth_signal,
                 basis,
                 is_gate=False):

        if not isinstance(orth_signal, np.ndarray):
            raise TypeError('The signal must be of np.ndarray type. '
                            f'Signal of type {str(type(orth_signal))} was passed.')

        if orth_signal.ndim != 1:
            raise ValueError('The signal must be of one dimensional. '
                             f'Signal of dimensions {str(orth_signal.ndim)} was passed.')

        if not isinstance(basis, pd.DataFrame):
            raise TypeError('The basis must be of DataFrame type. '
                            f'Basis of type {str(type(basis))} was passed.')

        if not basis.columns.is_unique:
            raise ValueError('Invalid pd.DataFrame: '
                             'Contains duplicate column names.')

        self._orth_signal = orth_signal
        self._basis = basis
        self._is_gate = is_gate

        # initialize the score and the covariance
        self._blp_model = None
        self._blp_omega = None

    def __str__(self):
        class_name = self.__class__.__name__
        header = f'================== {class_name} Object ==================\n'
        fit_summary = str(self.summary)
        res = header + \
            '\n------------------ Fit summary ------------------\n' + fit_summary
        return res

    @property
    def blp_model(self):
        """
        Best-Linear-Predictor model.
        """
        return self._blp_model

    @property
    def orth_signal(self):
        """
        Orthogonal signal.
        """
        return self._orth_signal

    @property
    def basis(self):
        """
        Basis.
        """
        return self._basis

    @property
    def blp_omega(self):
        """
        Covariance matrix.
        """
        return self._blp_omega

    @property
    def summary(self):
        """
        A summary for the best linear predictor effect after calling :meth:`fit`.
        """
        col_names = ['coef', 'std err', 't', 'P>|t|', '[0.025', '0.975]']
        if self.blp_model is None:
            df_summary = pd.DataFrame(columns=col_names)
        else:
            summary_stats = {'coef': self.blp_model.params,
                             'std err': self.blp_model.bse,
                             't': self.blp_model.tvalues,
                             'P>|t|': self.blp_model.pvalues,
                             '[0.025': self.blp_model.conf_int()[0],
                             '0.975]': self.blp_model.conf_int()[1]}
            df_summary = pd.DataFrame(summary_stats,
                                      columns=col_names)
        return df_summary

    def fit(self):
        """
        Estimate DoubleML models.

        Returns
        -------
        self : object
        """

        # fit the best-linear-predictor of the orthogonal signal with respect to the grid
        self._blp_model = sm.OLS(self._orth_signal, self._basis).fit()
        self._blp_omega = self._blp_model.cov_HC0

        return self

    def confint(self, basis=None, joint=False, level=0.95, n_rep_boot=500):
        """
        Confidence intervals for the BLP model.

        Parameters
        ----------
        basis : :class:`pandas.DataFrame`
            The basis for constructing the confidence interval. Has to have the same form as the basis from
            the construction. If ``None`` the basis for the construction of the model is used.
            Default is ``None``.

        joint : bool
            Indicates whether joint confidence intervals are computed.
            Default is ``False``.

        level : float
            The confidence level.
            Default is ``0.95``.

        n_rep_boot : int
            The number of bootstrap repetitions (only relevant for joint confidence intervals).
            Default is ``500``.

        Returns
        -------
        df_ci : pd.DataFrame
            A data frame with the confidence interval(s).
        """
        if not isinstance(joint, bool):
            raise TypeError('joint must be True or False. '
                            f'Got {str(joint)}.')

        if not isinstance(level, float):
            raise TypeError('The confidence level must be of float type. '
                            f'{str(level)} of type {str(type(level))} was passed.')
        if (level <= 0) | (level >= 1):
            raise ValueError('The confidence level must be in (0,1). '
                             f'{str(level)} was passed.')

        if not isinstance(n_rep_boot, int):
            raise TypeError('The number of bootstrap replications must be of int type. '
                            f'{str(n_rep_boot)} of type {str(type(n_rep_boot))} was passed.')
        if n_rep_boot < 1:
            raise ValueError('The number of bootstrap replications must be positive. '
                             f'{str(n_rep_boot)} was passed.')

        if self._blp_model is None:
            raise ValueError('Apply fit() before confint().')

        alpha = 1 - level
        gate_names = None
        # define basis if none is supplied
        if basis is None:
            if self._is_gate:
                # reduce to unique groups
                basis = pd.DataFrame(np.diag(v=np.full((self._basis.shape[1]), True)))
                gate_names = list(self._basis.columns.values)
            else:
                basis = self._basis
        elif not (basis.shape[1] == self._basis.shape[1]):
            raise ValueError('Invalid basis: DataFrame has to have the exact same number and ordering of columns.')
        elif not list(basis.columns.values) == list(self._basis.columns.values):
            raise ValueError('Invalid basis: DataFrame has to have the exact same number and ordering of columns.')

        # blp of the orthogonal signal
        g_hat = self._blp_model.predict(basis)

        np_basis = basis.to_numpy()
        # calculate se for basis elements
        blp_se = np.sqrt((np.dot(np_basis, self._blp_omega) * np_basis).sum(axis=1))

        if joint:
            # calculate the maximum t-statistic with bootstrap
            normal_samples = np.random.normal(size=[basis.shape[1], n_rep_boot])
            bootstrap_samples = np.multiply(np.dot(np_basis, np.dot(sqrtm(self._blp_omega), normal_samples)).T,
                                            (1.0 / blp_se))

            max_t_stat = np.quantile(np.max(np.abs(bootstrap_samples), axis=0), q=level)

            # Lower simultaneous CI
            g_hat_lower = g_hat - max_t_stat * blp_se
            # Upper simultaneous CI
            g_hat_upper = g_hat + max_t_stat * blp_se

        else:
            # Lower point-wise CI
            g_hat_lower = g_hat + norm.ppf(q=alpha / 2) * blp_se
            # Upper point-wise CI
            g_hat_upper = g_hat + norm.ppf(q=1 - alpha / 2) * blp_se

        ci = np.vstack((g_hat_lower, g_hat, g_hat_upper)).T
        df_ci = pd.DataFrame(ci,
                             columns=['{:.1f} %'.format(alpha/2 * 100), 'effect', '{:.1f} %'.format((1-alpha/2) * 100)],
                             index=basis.index)

        if self._is_gate and gate_names is not None:
            df_ci.index = gate_names

        return df_ci
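
Usage sketch (not part of the commit): DoubleMLBLP is normally created by the cate()/gate() helpers added to DoubleMLIRM below, but it can also be fitted directly on any one-dimensional orthogonal signal. The signal and basis here are simulated purely for illustration.

import numpy as np
import pandas as pd
from doubleml import DoubleMLBLP

# simulated stand-in for an orthogonal signal plus a small polynomial basis (illustrative only)
rng = np.random.default_rng(42)
x = rng.uniform(size=500)
orth_signal = 1.0 + 2.0 * x + rng.normal(scale=0.5, size=500)
basis = pd.DataFrame({'const': 1.0, 'x': x, 'x_sq': x ** 2})

blp = DoubleMLBLP(orth_signal, basis=basis).fit()
print(blp.summary)                               # coefficient table of the underlying OLS fit
print(blp.confint(joint=True, n_rep_boot=1000))  # simultaneous bands, based on the HC0 covariance from fit()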

doubleml/double_ml_data.py

Lines changed: 3 additions & 4 deletions
@@ -725,10 +725,9 @@ def from_arrays(cls, x, y, d, cluster_vars, z=None, use_other_treat_as_covariate

         data = pd.concat((pd.DataFrame(cluster_vars, columns=cluster_cols), dml_data.data), axis=1)

-        return(cls(data, dml_data.y_col, dml_data.d_cols,
-                   cluster_cols,
-                   dml_data.x_cols, dml_data.z_cols,
-                   dml_data.use_other_treat_as_covariate, dml_data.force_all_x_finite))
+        return (cls(data, dml_data.y_col, dml_data.d_cols, cluster_cols,
+                    dml_data.x_cols, dml_data.z_cols,
+                    dml_data.use_other_treat_as_covariate, dml_data.force_all_x_finite))

     @property
     def cluster_cols(self):

doubleml/double_ml_iivm.py

Lines changed: 2 additions & 2 deletions
@@ -206,7 +206,7 @@ def _check_data(self, obj_dml_data):
         one_treat = (obj_dml_data.n_treat == 1)
         binary_treat = (type_of_target(obj_dml_data.d) == 'binary')
         zero_one_treat = np.all((np.power(obj_dml_data.d, 2) - obj_dml_data.d) == 0)
-        if not(one_treat & binary_treat & zero_one_treat):
+        if not (one_treat & binary_treat & zero_one_treat):
             raise ValueError('Incompatible data. '
                              'To fit an IIVM model with DML '
                              'exactly one binary variable with values 0 and 1 '
@@ -219,7 +219,7 @@ def _check_data(self, obj_dml_data):
         if one_instr:
             binary_instr = (type_of_target(obj_dml_data.z) == 'binary')
             zero_one_instr = np.all((np.power(obj_dml_data.z, 2) - obj_dml_data.z) == 0)
-            if not(one_instr & binary_instr & zero_one_instr):
+            if not (one_instr & binary_instr & zero_one_instr):
                 raise ValueError(err_msg)
         else:
             raise ValueError(err_msg)

doubleml/double_ml_irm.py

Lines changed: 83 additions & 1 deletion
@@ -1,10 +1,15 @@
 import numpy as np
+import pandas as pd
+import warnings
 from sklearn.utils import check_X_y
 from sklearn.utils.multiclass import type_of_target

 from .double_ml import DoubleML
+
+from .double_ml_blp import DoubleMLBLP
 from .double_ml_data import DoubleMLData
 from .double_ml_score_mixins import LinearScoreMixin
+
 from ._utils import _dml_cv_predict, _get_cond_smpls, _dml_tune, _check_finite_predictions

@@ -171,7 +176,7 @@ def _check_data(self, obj_dml_data):
         one_treat = (obj_dml_data.n_treat == 1)
         binary_treat = (type_of_target(obj_dml_data.d) == 'binary')
         zero_one_treat = np.all((np.power(obj_dml_data.d, 2) - obj_dml_data.d) == 0)
-        if not(one_treat & binary_treat & zero_one_treat):
+        if not (one_treat & binary_treat & zero_one_treat):
             raise ValueError('Incompatible data. '
                              'To fit an IRM model with DML '
                              'exactly one binary variable with values 0 and 1 '
@@ -325,3 +330,80 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
                'tune_res': tune_res}

         return res
+
+    def cate(self, basis):
+        """
+        Calculate conditional average treatment effects (CATE) for a given basis.
+
+        Parameters
+        ----------
+        basis : :class:`pandas.DataFrame`
+            The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
+            where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
+
+        Returns
+        -------
+        model : :class:`doubleML.DoubleMLBLP`
+            Best linear Predictor model.
+        """
+        valid_score = ['ATE']
+        if self.score not in valid_score:
+            raise ValueError('Invalid score ' + self.score + '. ' +
+                             'Valid score ' + ' or '.join(valid_score) + '.')
+
+        if self.n_rep != 1:
+            raise NotImplementedError('Only implemented for one repetition. ' +
+                                      f'Number of repetitions is {str(self.n_rep)}.')
+
+        # define the orthogonal signal
+        orth_signal = self.psi_elements['psi_b'].reshape(-1)
+        # fit the best linear predictor
+        model = DoubleMLBLP(orth_signal, basis=basis).fit()
+
+        return model
+
+    def gate(self, groups):
+        """
+        Calculate group average treatment effects (GATE) for mutually exclusive groups.
+
+        Parameters
+        ----------
+        groups : :class:`pandas.DataFrame`
+            The group indicator for estimating the best linear predictor.
+            Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations
+            and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups.
+
+        Returns
+        -------
+        model : :class:`doubleML.DoubleMLBLPGATE`
+            Best linear Predictor model for Group Effects.
+        """
+        valid_score = ['ATE']
+        if self.score not in valid_score:
+            raise ValueError('Invalid score ' + self.score + '. ' +
+                             'Valid score ' + ' or '.join(valid_score) + '.')
+
+        if self.n_rep != 1:
+            raise NotImplementedError('Only implemented for one repetition. ' +
+                                      f'Number of repetitions is {str(self.n_rep)}.')
+
+        if not isinstance(groups, pd.DataFrame):
+            raise TypeError('Groups must be of DataFrame type. '
+                            f'Groups of type {str(type(groups))} was passed.')
+
+        if not all(groups.dtypes == bool) or all(groups.dtypes == int):
+            if groups.shape[1] == 1:
+                groups = pd.get_dummies(groups, prefix='Group', prefix_sep='_')
+            else:
+                raise TypeError('Columns of groups must be of bool type or int type (dummy coded). '
+                                'Alternatively, groups should only contain one column.')
+
+        if any(groups.sum(0) <= 5):
+            warnings.warn('At least one group effect is estimated with less than 6 observations.')
+
+        # define the orthogonal signal
+        orth_signal = self.psi_elements['psi_b'].reshape(-1)
+        # fit the best linear predictor for GATE (different confint() method)
+        model = DoubleMLBLP(orth_signal, basis=groups, is_gate=True).fit()
+
+        return model
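
Usage sketch (not part of the commit): how the new cate() and gate() methods might be called on a fitted DoubleMLIRM model. The simulated data, the learner choices, and the basis/group construction are illustrative assumptions; learners are passed positionally as (ml_g, ml_m).

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from doubleml import DoubleMLData, DoubleMLIRM

# illustrative simulated data: binary treatment d, outcome y, five covariates
rng = np.random.default_rng(123)
n_obs = 500
x = rng.normal(size=(n_obs, 5))
d = rng.binomial(1, 1 / (1 + np.exp(-x[:, 0])))
y = d * (1 + x[:, 0]) + x[:, 1] + rng.normal(size=n_obs)

dml_data = DoubleMLData.from_arrays(x, y, d)
dml_irm = DoubleMLIRM(dml_data, RandomForestRegressor(), RandomForestClassifier(), score='ATE')
dml_irm.fit()

# CATE: project the orthogonal signal on a polynomial basis in the first covariate
cate_basis = pd.DataFrame({'const': 1.0, 'x1': x[:, 0], 'x1_sq': x[:, 0] ** 2})
cate = dml_irm.cate(cate_basis)
print(cate.confint(cate_basis, joint=True, n_rep_boot=1000))

# GATE: a single labelled column is expanded to group dummies internally
groups = pd.DataFrame({'Group': np.where(x[:, 0] > 0, 'x1_pos', 'x1_neg')})
gate = dml_irm.gate(groups)
print(gate.confint())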
