Skip to content

Commit 5b09f66

Browse files
committed
add test for ssm tuning
1 parent 1103d48 commit 5b09f66

File tree

3 files changed

+197
-2
lines changed

3 files changed

+197
-2
lines changed

doubleml/irm/ssm.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,11 @@ def __init__(self,
134134
self._check_data(self._dml_data)
135135
_check_score(self.score, ['missing-at-random', 'nonignorable'])
136136

137+
# for both score function stratification by d and s is viable
138+
self._strata = self._dml_data.d.reshape(-1, 1) + 2 * self._dml_data.s.reshape(-1, 1)
139+
if draw_sample_splitting:
140+
self.draw_sample_splitting()
141+
137142
ml_g_is_classifier = self._check_learner(ml_g, 'ml_g', regressor=True, classifier=True)
138143
_ = self._check_learner(ml_pi, 'ml_pi', regressor=False, classifier=True)
139144
_ = self._check_learner(ml_m, 'ml_m', regressor=False, classifier=True)
@@ -405,7 +410,11 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
405410
# time indicator is used for selection (selection not available in DoubleMLData yet)
406411
x, s = check_X_y(x, self._dml_data.s, force_all_finite=False)
407412

408-
dx = np.column_stack((d, x))
413+
if self._score == 'nonignorable':
414+
z, _ = check_X_y(self._dml_data.z, y, force_all_finite=False)
415+
dx = np.column_stack((x, d, z))
416+
else:
417+
dx = np.column_stack((x, d))
409418

410419
if scoring_methods is None:
411420
scoring_methods = {'ml_g': None,

doubleml/irm/tests/_utils_ssm_manual.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from sklearn.base import clone
33
from sklearn.model_selection import train_test_split
44

5-
from ...tests._utils import fit_predict, fit_predict_proba
5+
from ...tests._utils import fit_predict, fit_predict_proba, tune_grid_search
66
from ...utils._estimation import _predict_zero_one_propensity, _trimm
77

88

@@ -235,3 +235,30 @@ def var_selection(theta, psi_a, psi_b, n_obs):
235235
J = np.mean(psi_a)
236236
var = 1/n_obs * np.mean(np.power(np.multiply(psi_a, theta) + psi_b, 2)) / np.power(J, 2)
237237
return var
238+
239+
240+
def tune_nuisance_ssm(y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, score, n_folds_tune,
241+
param_grid_g, param_grid_pi, param_grid_m):
242+
d0_s1 = np.intersect1d(np.where(d == 0)[0], np.where(s == 1)[0])
243+
d1_s1 = np.intersect1d(np.where(d == 1)[0], np.where(s == 1)[0])
244+
245+
g0_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune,
246+
train_cond=d0_s1)
247+
g1_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune,
248+
train_cond=d1_s1)
249+
250+
if score == 'nonignorable':
251+
dx = np.column_stack((x, d, z))
252+
else:
253+
dx = np.column_stack((x, d))
254+
255+
pi_tune_res = tune_grid_search(s, dx, ml_pi, smpls, param_grid_pi, n_folds_tune)
256+
257+
m_tune_res = tune_grid_search(d, x, ml_m, smpls, param_grid_m, n_folds_tune)
258+
259+
g0_best_params = [xx.best_params_ for xx in g0_tune_res]
260+
g1_best_params = [xx.best_params_ for xx in g1_tune_res]
261+
pi_best_params = [xx.best_params_ for xx in pi_tune_res]
262+
m_best_params = [xx.best_params_ for xx in m_tune_res]
263+
264+
return g0_best_params, g1_best_params, pi_best_params, m_best_params

doubleml/irm/tests/test_ssm_tune.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import numpy as np
2+
import pytest
3+
import math
4+
5+
from sklearn.base import clone
6+
7+
from sklearn.linear_model import LogisticRegression
8+
from sklearn.ensemble import RandomForestRegressor
9+
10+
import doubleml as dml
11+
12+
from ...tests._utils import draw_smpls
13+
from ._utils_ssm_manual import fit_selection, tune_nuisance_ssm
14+
15+
16+
@pytest.fixture(scope='module',
17+
params=[RandomForestRegressor(random_state=42)])
18+
def learner_g(request):
19+
return request.param
20+
21+
22+
@pytest.fixture(scope='module',
23+
params=[LogisticRegression(random_state=42)])
24+
def learner_m(request):
25+
return request.param
26+
27+
28+
@pytest.fixture(scope='module',
29+
params=['missing-at-random', 'nonignorable'])
30+
def score(request):
31+
return request.param
32+
33+
34+
@pytest.fixture(scope='module',
35+
params=[True, False])
36+
def normalize_ipw(request):
37+
return request.param
38+
39+
40+
@pytest.fixture(scope='module',
41+
params=[True, False])
42+
def tune_on_folds(request):
43+
return request.param
44+
45+
46+
def get_par_grid(learner):
47+
if learner.__class__ in [RandomForestRegressor]:
48+
par_grid = {'n_estimators': [5, 10, 20]}
49+
else:
50+
assert learner.__class__ in [LogisticRegression]
51+
par_grid = {'C': np.logspace(-2, 2, 10)}
52+
return par_grid
53+
54+
55+
@pytest.fixture(scope='module')
56+
def dml_ssm_fixture(generate_data_selection_mar, generate_data_selection_nonignorable,
57+
learner_g, learner_m, score,
58+
normalize_ipw, tune_on_folds):
59+
par_grid = {'ml_g': get_par_grid(learner_g),
60+
'ml_pi': get_par_grid(learner_m),
61+
'ml_m': get_par_grid(learner_m)}
62+
n_folds_tune = 4
63+
n_folds = 2
64+
65+
# collect data
66+
np.random.seed(42)
67+
if score == 'missing-at-random':
68+
(x, y, d, z, s) = generate_data_selection_mar
69+
else:
70+
(x, y, d, z, s) = generate_data_selection_nonignorable
71+
72+
n_obs = len(y)
73+
all_smpls = draw_smpls(n_obs, n_folds)
74+
75+
ml_g = clone(learner_g)
76+
ml_pi = clone(learner_m)
77+
ml_m = clone(learner_m)
78+
79+
np.random.seed(42)
80+
if score == 'missing-at-random':
81+
obj_dml_data = dml.DoubleMLData.from_arrays(x, y, d, z=None, s=s)
82+
dml_sel_obj = dml.DoubleMLSSM(obj_dml_data,
83+
ml_g, ml_pi, ml_m,
84+
n_folds=n_folds,
85+
score=score,
86+
normalize_ipw=normalize_ipw,
87+
draw_sample_splitting=False)
88+
else:
89+
assert score == 'nonignorable'
90+
obj_dml_data = dml.DoubleMLData.from_arrays(x, y, d, z=z, s=s)
91+
dml_sel_obj = dml.DoubleMLSSM(obj_dml_data,
92+
ml_g, ml_pi, ml_m,
93+
n_folds=n_folds,
94+
score=score,
95+
normalize_ipw=normalize_ipw,
96+
draw_sample_splitting=False)
97+
98+
# synchronize the sample splitting
99+
np.random.seed(42)
100+
dml_sel_obj.set_sample_splitting(all_smpls=all_smpls)
101+
102+
np.random.seed(42)
103+
# tune hyperparameters
104+
tune_res = dml_sel_obj.tune(par_grid, tune_on_folds=tune_on_folds, n_folds_tune=n_folds_tune,
105+
return_tune_res=False)
106+
assert isinstance(tune_res, dml.DoubleMLSSM)
107+
108+
dml_sel_obj.fit()
109+
110+
np.random.seed(42)
111+
smpls = all_smpls[0]
112+
if tune_on_folds:
113+
g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm(
114+
y, x, d, z, s,
115+
clone(learner_g), clone(learner_m), clone(learner_m),
116+
smpls, score, n_folds_tune,
117+
par_grid['ml_g'], par_grid['ml_pi'], par_grid['ml_m'])
118+
119+
else:
120+
xx = [(np.arange(len(y)), np.array([]))]
121+
g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm(
122+
y, x, d, z, s,
123+
clone(learner_g), clone(learner_m), clone(learner_m),
124+
xx, score, n_folds_tune,
125+
par_grid['ml_g'], par_grid['ml_pi'], par_grid['ml_m'])
126+
127+
g0_best_params = g0_best_params * n_folds
128+
g1_best_params = g1_best_params * n_folds
129+
pi_best_params = pi_best_params * n_folds
130+
m_best_params = m_best_params * n_folds
131+
132+
np.random.seed(42)
133+
res_manual = fit_selection(y, x, d, z, s,
134+
clone(learner_g), clone(learner_m), clone(learner_m),
135+
all_smpls, score,
136+
normalize_ipw=normalize_ipw,
137+
g_d0_params=g0_best_params, g_d1_params=g1_best_params,
138+
pi_params=pi_best_params, m_params=m_best_params)
139+
140+
res_dict = {'coef': dml_sel_obj.coef[0],
141+
'coef_manual': res_manual['theta'],
142+
'se': dml_sel_obj.se[0],
143+
'se_manual': res_manual['se']}
144+
145+
return res_dict
146+
147+
148+
@pytest.mark.ci
149+
def test_dml_ssm_coef(dml_ssm_fixture):
150+
assert math.isclose(dml_ssm_fixture['coef'],
151+
dml_ssm_fixture['coef_manual'],
152+
rel_tol=1e-9, abs_tol=1e-4)
153+
154+
155+
@pytest.mark.ci
156+
def test_dml_ssm_se(dml_ssm_fixture):
157+
assert math.isclose(dml_ssm_fixture['se'],
158+
dml_ssm_fixture['se_manual'],
159+
rel_tol=1e-9, abs_tol=1e-4)

0 commit comments

Comments
 (0)