Skip to content

Commit 6385162

Browse files
committed
refactor BLP, CATE and GATE
1 parent 44ce540 commit 6385162

File tree

7 files changed

+137
-59
lines changed

7 files changed

+137
-59
lines changed

.coverage

68 KB
Binary file not shown.

doubleml/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
from .double_ml_irm import DoubleMLIRM
66
from .double_ml_iivm import DoubleMLIIVM
77
from .double_ml_data import DoubleMLData, DoubleMLClusterData
8-
from .double_ml_blp import DoubleMLIRMBLP
8+
from .double_ml_blp import DoubleMLBLP
99

1010
__all__ = ['DoubleMLPLR',
1111
'DoubleMLPLIV',
1212
'DoubleMLIRM',
1313
'DoubleMLIIVM',
1414
'DoubleMLData',
1515
'DoubleMLClusterData',
16-
'DoubleMLIRMBLP']
16+
'DoubleMLBLP']
1717

1818
__version__ = get_distribution('doubleml').version

doubleml/double_ml_blp.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from scipy.linalg import sqrtm
77

88

9-
class DoubleMLIRMBLP:
10-
"""Best linear predictor for DoubleML IRM models
9+
class DoubleMLBLP:
10+
"""Best linear predictor (BLP) for DoubleML with orthogonal signals.
1111
1212
Parameters
1313
----------
@@ -17,11 +17,16 @@ class DoubleMLIRMBLP:
1717
basis : :class:`pandas.DataFrame`
1818
The basis for estimating the best linear predictor. Has to have the shape (n,d),
1919
where d is the number of predictors.
20+
21+
is_gate : bool
22+
Indicates whether the basis is constructed for GATEs (dummy-basis).
23+
Default is ``False``.
2024
"""
2125

2226
def __init__(self,
2327
orth_signal,
24-
basis):
28+
basis,
29+
is_gate=False):
2530

2631
if not isinstance(orth_signal, np.ndarray):
2732
raise TypeError('The signal must be of np.ndarray type. '
@@ -41,6 +46,7 @@ def __init__(self,
4146

4247
self._orth_signal = orth_signal
4348
self._basis = basis
49+
self._is_gate = is_gate
4450

4551
# initialize the score and the covariance
4652
self._blp_model = None
@@ -89,15 +95,16 @@ def fit(self):
8995

9096
return self
9197

92-
def confint(self, basis, joint=False, level=0.95, n_rep_boot=500):
98+
def confint(self, basis=None, joint=False, level=0.95, n_rep_boot=500):
9399
"""
94-
Confidence intervals for BLP for DoubleML IRM.
100+
Confidence intervals for the BLP model.
95101
96102
Parameters
97103
----------
98104
basis : :class:`pandas.DataFrame`
99105
The basis for constructing the confidence interval. Has to have the same form as the basis from
100-
the construction.
106+
the construction. If ``None`` the basis for the construction of the model is used.
107+
Default is ``None``.
101108
102109
joint : bool
103110
Indicates whether joint confidence intervals are computed.
@@ -138,6 +145,20 @@ def confint(self, basis, joint=False, level=0.95, n_rep_boot=500):
138145
raise ValueError('Apply fit() before confint().')
139146

140147
alpha = 1 - level
148+
gate_names = None
149+
# define basis if none is supplied
150+
if basis is None:
151+
if self._is_gate:
152+
# reduce to unique groups
153+
basis = pd.DataFrame(np.diag(v=np.full((self._basis.shape[1]), True)))
154+
gate_names = list(self._basis.columns.values)
155+
else:
156+
basis = self._basis
157+
elif not (basis.shape[1] == self._basis.shape[1]):
158+
raise ValueError('Invalid basis: DataFrame has to have the exact same number and ordering of columns.')
159+
elif not list(basis.columns.values) == list(self._basis.columns.values):
160+
raise ValueError('Invalid basis: DataFrame has to have the exact same number and ordering of columns.')
161+
141162
# blp of the orthogonal signal
142163
g_hat = self._blp_model.predict(basis)
143164

@@ -167,4 +188,8 @@ def confint(self, basis, joint=False, level=0.95, n_rep_boot=500):
167188
df_ci = pd.DataFrame(ci,
168189
columns=['{:.1f} %'.format(alpha/2 * 100), 'effect', '{:.1f} %'.format((1-alpha/2) * 100)],
169190
index=basis.index)
191+
192+
if self._is_gate and gate_names is not None:
193+
df_ci.index = gate_names
194+
170195
return df_ci

doubleml/double_ml_irm.py

Lines changed: 26 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import numpy as np
22
import pandas as pd
3+
import warnings
34
from sklearn.utils import check_X_y
45
from sklearn.utils.multiclass import type_of_target
56

67
from .double_ml import DoubleML
7-
from .double_ml_blp import DoubleMLIRMBLP
8-
8+
from .double_ml_blp import DoubleMLBLP
99
from ._utils import _dml_cv_predict, _get_cond_smpls, _dml_tune, _check_finite_predictions
1010

1111

@@ -334,43 +334,40 @@ def cate(self, basis):
334334
335335
Returns
336336
-------
337-
model : :class:`doubleML.DoubleMLIRMBLP`
337+
model : :class:`doubleML.DoubleMLBLP`
338338
Best linear Predictor model.
339339
"""
340+
valid_score = ['ATE']
341+
if self.score not in valid_score:
342+
raise ValueError('Invalid score ' + self.score + '. ' +
343+
'Valid score ' + ' or '.join(valid_score) + '.')
344+
340345
# define the orthogonal signal
341346
orth_signal = self.psi_b.reshape(-1)
342347
# fit the best linear predictor
343-
model = DoubleMLIRMBLP(orth_signal, basis=basis).fit()
348+
model = DoubleMLBLP(orth_signal, basis=basis).fit()
344349

345350
return model
346351

347-
def gate(self, groups, joint=False, level=0.95, n_rep_boot=500):
352+
def gate(self, groups):
348353
"""
349-
Calculate group average treatment effects (GATE) for a given basis.
354+
Calculate group average treatment effects (GATE) for mutually exclusive groups.
350355
351356
Parameters
352357
----------
353358
groups : :class:`pandas.DataFrame`
354-
The group indicator for estimating the best linear predictor. Has to have the shape (n,d),
359+
The group indicator for estimating the best linear predictor. Has to be either dummy coded with shape (n,d),
355360
where d is the number of groups, or of shape (n,1) and contain the corresponding group labels.
356361
357-
joint : bool
358-
Indicates whether joint confidence intervals are computed.
359-
Default is ``False``
360-
361-
level : float
362-
The confidence level.
363-
Default is ``0.95``.
364-
365-
n_rep_boot : int
366-
Number of bootstrap samples for joint confidence interval.
367-
Default is ``500``.
368-
369362
Returns
370363
-------
371-
df_ci : pd.DataFrame
372-
A data frame with the confidence interval(s).
364+
model : :class:`doubleML.DoubleMLBLP`
365+
Best linear Predictor model for Group Effects.
373366
"""
367+
valid_score = ['ATE']
368+
if self.score not in valid_score:
369+
raise ValueError('Invalid score ' + self.score + '. ' +
370+
'Valid score ' + ' or '.join(valid_score) + '.')
374371

375372
if not isinstance(groups, pd.DataFrame):
376373
raise TypeError('Groups must be of DataFrame type. '
@@ -380,17 +377,15 @@ def gate(self, groups, joint=False, level=0.95, n_rep_boot=500):
380377
if groups.shape[1] == 1:
381378
groups = pd.get_dummies(groups, prefix='Group', prefix_sep='_')
382379
else:
383-
raise TypeError('Columns must be of of bool or int type or the data frame only should contain '
384-
'one column.')
380+
raise TypeError('Columns of groups must be of bool type or int type (dummy coded). '
381+
'Alternatively, groups should only contain one column.')
382+
383+
if any(groups.sum(0) <= 5):
384+
warnings.warn('At least one group effect is estimated with less than 6 observations.')
385385

386386
# define the orthogonal signal
387387
orth_signal = self.psi_b.reshape(-1)
388-
# fit the best linear predictor
389-
model = DoubleMLIRMBLP(orth_signal, basis=groups).fit()
388+
# fit the best linear predictor for GATE (is_gate=True makes confint() default to one row per group)
389+
model = DoubleMLBLP(orth_signal, basis=groups, is_gate=True).fit()
390390

391-
# reduce to unique groups and create confidence interval
392-
unique_groups = pd.DataFrame(np.diag(v=np.full((groups.shape[1]), True)))
393-
df_ci = model.confint(unique_groups, joint, level, n_rep_boot)
394-
df_ci.index = groups.columns.values
395-
396-
return df_ci
391+
return model

doubleml/tests/test_blp.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ def dml_blp_fixture(ci_joint, ci_level):
2626
random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 3)))
2727
random_signal = np.random.normal(0, 1, size=(n, ))
2828

29-
blp = dml.DoubleMLIRMBLP(random_signal, random_basis).fit()
29+
blp = dml.DoubleMLBLP(random_signal, random_basis).fit()
3030
blp_manual = fit_blp(random_signal, random_basis)
3131

3232
np.random.seed(42)
33-
ci = blp.confint(random_basis, joint=ci_joint, level=ci_level, n_rep_boot=1000)
33+
ci_1 = blp.confint(random_basis, joint=ci_joint, level=ci_level, n_rep_boot=1000)
34+
np.random.seed(42)
35+
ci_2 = blp.confint(joint=ci_joint, level=ci_level, n_rep_boot=1000)
3436
np.random.seed(42)
3537
ci_manual = blp_confint(blp_manual, random_basis, joint=ci_joint, level=ci_level, n_rep_boot=1000)
3638

@@ -40,7 +42,8 @@ def dml_blp_fixture(ci_joint, ci_level):
4042
'values_manual': blp_manual.fittedvalues,
4143
'omega': blp.blp_omega,
4244
'omega_manual': blp_manual.cov_HC0,
43-
'ci': ci,
45+
'ci_1': ci_1,
46+
'ci_2': ci_2,
4447
'ci_manual': ci_manual}
4548

4649
return res_dict
@@ -68,7 +71,14 @@ def test_dml_blp_omega(dml_blp_fixture):
6871

6972

7073
@pytest.mark.ci
71-
def test_dml_blp_ci(dml_blp_fixture):
72-
assert np.allclose(dml_blp_fixture['ci'],
74+
def test_dml_blp_ci_1(dml_blp_fixture):
75+
assert np.allclose(dml_blp_fixture['ci_1'],
76+
dml_blp_fixture['ci_2'],
77+
rtol=1e-9, atol=1e-4)
78+
79+
80+
@pytest.mark.ci
81+
def test_dml_blp_ci_2(dml_blp_fixture):
82+
assert np.allclose(dml_blp_fixture['ci_1'],
7383
dml_blp_fixture['ci_manual'],
7484
rtol=1e-9, atol=1e-4)

doubleml/tests/test_doubleml_exceptions.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pandas as pd
33
import numpy as np
44

5-
from doubleml import DoubleMLPLR, DoubleMLIRM, DoubleMLIIVM, DoubleMLPLIV, DoubleMLData, DoubleMLClusterData, DoubleMLIRMBLP
5+
from doubleml import DoubleMLPLR, DoubleMLIRM, DoubleMLIIVM, DoubleMLPLIV, DoubleMLData, DoubleMLClusterData, DoubleMLBLP
66
from doubleml.datasets import make_plr_CCDDHNR2018, make_irm_data, make_pliv_CHS2015, make_iivm_data, \
77
make_pliv_multiway_cluster_CKMS2021
88

@@ -664,19 +664,24 @@ def test_doubleml_exception_blp():
664664

665665
msg = "The signal must be of np.ndarray type. Signal of type <class 'int'> was passed."
666666
with pytest.raises(TypeError, match=msg):
667-
DoubleMLIRMBLP(orth_signal=1, basis=random_basis)
667+
DoubleMLBLP(orth_signal=1, basis=random_basis)
668668
msg = 'The signal must be of one dimensional. Signal of dimensions 2 was passed.'
669669
with pytest.raises(ValueError, match=msg):
670-
DoubleMLIRMBLP(orth_signal=np.array([[1], [2]]), basis=random_basis)
670+
DoubleMLBLP(orth_signal=np.array([[1], [2]]), basis=random_basis)
671671
msg = "The basis must be of DataFrame type. Basis of type <class 'int'> was passed."
672672
with pytest.raises(TypeError, match=msg):
673-
DoubleMLIRMBLP(orth_signal=signal, basis=1)
673+
DoubleMLBLP(orth_signal=signal, basis=1)
674674
msg = 'Invalid pd.DataFrame: Contains duplicate column names.'
675675
with pytest.raises(ValueError, match=msg):
676-
DoubleMLIRMBLP(orth_signal=signal, basis=pd.DataFrame(np.array([[1, 2], [4, 5]]),
677-
columns=['x_1', 'x_1']))
676+
DoubleMLBLP(orth_signal=signal, basis=pd.DataFrame(np.array([[1, 2], [4, 5]]),
677+
columns=['a_1', 'a_1']))
678678

679-
dml_blp_confint = DoubleMLIRMBLP(orth_signal=signal, basis=random_basis)
679+
dml_blp_confint = DoubleMLBLP(orth_signal=signal, basis=random_basis)
680+
msg = r'Apply fit\(\) before confint\(\).'
681+
with pytest.raises(ValueError, match=msg):
682+
dml_blp_confint.confint(random_basis)
683+
684+
dml_blp_confint.fit()
680685
msg = 'joint must be True or False. Got 1.'
681686
with pytest.raises(TypeError, match=msg):
682687
dml_blp_confint.confint(random_basis, joint=1)
@@ -692,9 +697,12 @@ def test_doubleml_exception_blp():
692697
msg = 'The number of bootstrap replications must be positive. 0 was passed.'
693698
with pytest.raises(ValueError, match=msg):
694699
dml_blp_confint.confint(random_basis, n_rep_boot=0)
695-
msg = r'Apply fit\(\) before confint\(\).'
700+
msg = 'Invalid basis: DataFrame has to have the exact same number and ordering of columns.'
696701
with pytest.raises(ValueError, match=msg):
697-
dml_blp_confint.confint(random_basis)
702+
dml_blp_confint.confint(basis=pd.DataFrame(np.array([[1], [4]]), columns=['a_1']))
703+
msg = 'Invalid basis: DataFrame has to have the exact same number and ordering of columns.'
704+
with pytest.raises(ValueError, match=msg):
705+
dml_blp_confint.confint(basis=pd.DataFrame(np.array([[1, 2, 3], [4, 5, 6]]), columns=['x_1', 'x_2', 'x_3']))
698706

699707

700708
@pytest.mark.ci
@@ -709,6 +717,34 @@ def test_doubleml_exception_gate():
709717
msg = "Groups must be of DataFrame type. Groups of type <class 'int'> was passed."
710718
with pytest.raises(TypeError, match=msg):
711719
dml_irm_obj.gate(groups=2)
712-
msg = 'Columns must be of of bool or int type or the data frame only should contain one column.'
720+
msg = (r'Columns of groups must be of bool type or int type \(dummy coded\). '
721+
'Alternatively, groups should only contain one column.')
713722
with pytest.raises(TypeError, match=msg):
714-
dml_irm_obj.gate(groups=pd.DataFrame(np.random.normal(0, 1, size=(50, 3))))
723+
dml_irm_obj.gate(groups=pd.DataFrame(np.random.normal(0, 1, size=(dml_data_irm.n_obs, 3))))
724+
725+
dml_irm_obj = DoubleMLIRM(dml_data_irm,
726+
ml_g=Lasso(),
727+
ml_m=LogisticRegression(),
728+
trimming_threshold=0.05,
729+
n_folds=5,
730+
score='ATTE')
731+
dml_irm_obj.fit()
732+
733+
msg = 'Invalid score ATTE. Valid score ATE.'
734+
with pytest.raises(ValueError, match=msg):
735+
dml_irm_obj.gate(groups=2)
736+
737+
738+
@pytest.mark.ci
739+
def test_doubleml_exception_cate():
740+
dml_irm_obj = DoubleMLIRM(dml_data_irm,
741+
ml_g=Lasso(),
742+
ml_m=LogisticRegression(),
743+
trimming_threshold=0.05,
744+
n_folds=5,
745+
score='ATTE')
746+
dml_irm_obj.fit()
747+
748+
msg = 'Invalid score ATTE. Valid score ATE.'
749+
with pytest.raises(ValueError, match=msg):
750+
dml_irm_obj.cate(basis=2)

doubleml/tests/test_irm.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def test_dml_irm_boot(dml_irm_fixture):
124124

125125
@pytest.mark.ci
126126
def test_dml_irm_cate_gate():
127-
n = 50
127+
n = 9
128128
# collect data
129129
np.random.seed(42)
130130
obj_dml_data = make_irm_data(n_obs=n, dim_x=2)
@@ -143,12 +143,24 @@ def test_dml_irm_cate_gate():
143143
# create a random basis
144144
random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 5)))
145145
cate = dml_irm_obj.cate(random_basis)
146-
assert isinstance(cate, dml.double_ml_blp.DoubleMLIRMBLP)
146+
assert isinstance(cate, dml.double_ml_blp.DoubleMLBLP)
147+
assert isinstance(cate.confint(), pd.DataFrame)
147148

148149
groups_1 = pd.DataFrame(np.column_stack([obj_dml_data.data['X1'] <= 0,
149150
obj_dml_data.data['X1'] > 0.2]),
150151
columns=['Group 1', 'Group 2'])
151-
assert isinstance(dml_irm_obj.gate(groups_1), pd.DataFrame)
152+
msg = ('At least one group effect is estimated with less than 6 observations.')
153+
with pytest.warns(UserWarning, match=msg):
154+
gate_1 = dml_irm_obj.gate(groups_1)
155+
assert isinstance(gate_1, dml.double_ml_blp.DoubleMLBLP)
156+
assert isinstance(gate_1.confint(), pd.DataFrame)
157+
assert all(gate_1.confint().index == groups_1.columns)
152158

159+
np.random.seed(42)
153160
groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n))
154-
assert isinstance(dml_irm_obj.gate(groups_2), pd.DataFrame)
161+
msg = ('At least one group effect is estimated with less than 6 observations.')
162+
with pytest.warns(UserWarning, match=msg):
163+
gate_2 = dml_irm_obj.gate(groups_2)
164+
assert isinstance(gate_2, dml.double_ml_blp.DoubleMLBLP)
165+
assert isinstance(gate_2.confint(), pd.DataFrame)
166+
assert all(gate_2.confint().index == ["Group_1", "Group_2"])

0 commit comments

Comments
 (0)