Skip to content

Commit 9c51b93

Browse files
committed
Merge branch 'main' into s-restructure-doubleml
2 parents 8d34a73 + eda1137 commit 9c51b93

12 files changed

+528
-95
lines changed

doubleml/double_ml.py

Lines changed: 6 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .utils._checks import _check_in_zero_one, _check_integer, _check_float, _check_bool, _check_is_partition, \
2020
_check_all_smpls, _check_smpl_split, _check_smpl_split_tpl, _check_benchmarks, _check_external_predictions
2121
from .utils._plots import _sensitivity_contour_plot
22-
22+
from .utils.gain_statistics import gain_statistics
2323

2424
_implemented_data_backends = ['DoubleMLData', 'DoubleMLClusterData']
2525

@@ -272,7 +272,8 @@ def params_names(self):
272272
@property
273273
def predictions(self):
274274
"""
275-
The predictions of the nuisance models.
275+
The predictions of the nuisance models in form of a dictionary.
276+
Each key refers to a nuisance element with an array of values of shape ``(n_obs, n_rep, n_coefs)``.
276277
"""
277278
return self._predictions
278279

@@ -354,6 +355,7 @@ def psi(self):
354355
Values of the score function after calling :meth:`fit`;
355356
For models (e.g., PLR, IRM, PLIV, IIVM) with linear score (in the parameter)
356357
:math:`\\psi(W; \\theta, \\eta) = \\psi_a(W; \\eta) \\theta + \\psi_b(W; \\eta)`.
358+
The shape is ``(n_obs, n_rep, n_coefs)``.
357359
"""
358360
return self._psi
359361

@@ -364,6 +366,7 @@ def psi_deriv(self):
364366
after calling :meth:`fit`;
365367
For models (e.g., PLR, IRM, PLIV, IIVM) with linear score (in the parameter)
366368
:math:`\\psi_a(W; \\eta)`.
369+
The shape is ``(n_obs, n_rep, n_coefs)``.
367370
"""
368371
return self._psi_deriv
369372

@@ -1966,45 +1969,6 @@ def sensitivity_benchmark(self, benchmarking_set):
19661969
dml_short._dml_data.x_cols = x_list_short
19671970
dml_short.fit()
19681971

1969-
# save elements for readability
1970-
var_y = np.var(self._dml_data.y)
1971-
var_y_residuals_long = np.squeeze(self.sensitivity_elements['sigma2'], axis=0)
1972-
nu2_long = np.squeeze(self.sensitivity_elements['nu2'], axis=0)
1973-
var_y_residuals_short = np.squeeze(dml_short.sensitivity_elements['sigma2'], axis=0)
1974-
nu2_short = np.squeeze(dml_short.sensitivity_elements['nu2'], axis=0)
1975-
1976-
# compute nonparametric R2
1977-
R2_y_long = 1.0 - np.divide(var_y_residuals_long, var_y)
1978-
R2_y_short = 1.0 - np.divide(var_y_residuals_short, var_y)
1979-
R2_riesz = np.divide(nu2_short, nu2_long)
1980-
1981-
# Gain statistics
1982-
all_cf_y_benchmark = np.clip(np.divide((R2_y_long - R2_y_short), (1.0 - R2_y_long)), 0, 1)
1983-
all_cf_d_benchmark = np.clip(np.divide((1.0 - R2_riesz), R2_riesz), 0, 1)
1984-
cf_y_benchmark = np.median(all_cf_y_benchmark, axis=0)
1985-
cf_d_benchmark = np.median(all_cf_d_benchmark, axis=0)
1986-
1987-
# change in estimates (slightly different to paper)
1988-
all_delta_theta = np.transpose(dml_short.all_coef - self.all_coef)
1989-
delta_theta = np.median(all_delta_theta, axis=0)
1990-
1991-
# degree of adversity
1992-
var_g = var_y_residuals_short - var_y_residuals_long
1993-
var_riesz = nu2_long - nu2_short
1994-
denom = np.sqrt(np.multiply(var_g, var_riesz), out=np.zeros_like(var_g), where=(var_g > 0) & (var_riesz > 0))
1995-
rho_sign = np.sign(all_delta_theta)
1996-
rho_values = np.clip(np.divide(np.absolute(all_delta_theta),
1997-
denom,
1998-
out=np.ones_like(all_delta_theta),
1999-
where=denom != 0),
2000-
0.0, 1.0)
2001-
all_rho_benchmark = np.multiply(rho_values, rho_sign)
2002-
rho_benchmark = np.median(all_rho_benchmark, axis=0)
2003-
benchmark_dict = {
2004-
"cf_y": cf_y_benchmark,
2005-
"cf_d": cf_d_benchmark,
2006-
"rho": rho_benchmark,
2007-
"delta_theta": delta_theta,
2008-
}
1972+
benchmark_dict = gain_statistics(dml_long=self, dml_short=dml_short)
20091973
df_benchmark = pd.DataFrame(benchmark_dict, index=self._dml_data.d_cols)
20101974
return df_benchmark

doubleml/irm/irm.py

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,26 @@ def _initialize_weights(self, weights):
221221
assert isinstance(weights, dict)
222222
self._weights = weights
223223

224-
def _get_weights(self):
225-
weights = self._weights['weights']
226-
if 'weights_bar' not in self._weights.keys():
227-
weights_bar = self._weights['weights']
224+
def _get_weights(self, m_hat=None):
225+
# standard case for ATE
226+
if self.score == 'ATE':
227+
weights = self._weights['weights']
228+
if 'weights_bar' not in self._weights.keys():
229+
weights_bar = self._weights['weights']
230+
else:
231+
weights_bar = self._weights['weights_bar'][:, self._i_rep]
228232
else:
229-
weights_bar = self._weights['weights_bar'][:, self._i_rep]
233+
# special case for ATTE
234+
assert self.score == 'ATTE'
235+
assert m_hat is not None
236+
subgroup = self._weights['weights'] * self._dml_data.d
237+
subgroup_probability = np.mean(subgroup)
238+
weights = np.divide(subgroup, subgroup_probability)
239+
240+
weights_bar = np.divide(
241+
np.multiply(m_hat, self._weights['weights']),
242+
subgroup_probability)
243+
230244
return weights, weights_bar
231245

232246
def _check_data(self, obj_dml_data):
@@ -280,8 +294,13 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
280294
f'predictions obtained with the ml_g learner {str(self._learner["ml_g"])} are also '
281295
'observed to be binary with values 0 and 1. Make sure that for classifiers '
282296
'probabilities and not labels are predicted.')
297+
if self.score == 'ATTE':
298+
# skip g_hat1 estimation
299+
g_hat1 = {'preds': None,
300+
'targets': None,
301+
'models': None}
283302

284-
if g1_external:
303+
elif g1_external:
285304
# use external predictions
286305
g_hat1 = {'preds': external_predictions['ml_g1'],
287306
'targets': None,
@@ -294,7 +313,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
294313
# adjust target values to consider only compatible subsamples
295314
g_hat1['targets'] = _cond_targets(g_hat1['targets'], cond_sample=(d == 1))
296315

297-
if self._dml_data.binary_outcome:
316+
if self._dml_data.binary_outcome & (self.score != 'ATTE'):
298317
binary_preds = (type_of_target(g_hat1['preds']) == 'binary')
299318
zero_one_preds = np.all((np.power(g_hat1['preds'], 2) - g_hat1['preds']) == 0)
300319
if binary_preds & zero_one_preds:
@@ -338,11 +357,6 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
338357

339358
def _score_elements(self, y, d, g_hat0, g_hat1, m_hat, smpls):
340359

341-
# fraction of treated for ATTE
342-
p_hat = None
343-
if self.score == 'ATTE':
344-
p_hat = np.mean(d)
345-
346360
m_hat_adj = np.full_like(m_hat, np.nan, dtype='float64')
347361
if self.normalize_ipw:
348362
if self.dml_procedure == 'dml1':
@@ -355,24 +369,21 @@ def _score_elements(self, y, d, g_hat0, g_hat1, m_hat, smpls):
355369

356370
# compute residuals
357371
u_hat0 = y - g_hat0
358-
u_hat1 = None
359-
if self.score == 'ATE':
360-
u_hat1 = y - g_hat1
361-
362-
if isinstance(self.score, str):
372+
if self.score == 'ATTE':
373+
g_hat1 = y
374+
u_hat1 = y - g_hat1
375+
376+
if (self.score == 'ATE') or (self.score == 'ATTE'):
377+
weights, weights_bar = self._get_weights(m_hat=m_hat_adj)
378+
psi_b = weights * (g_hat1 - g_hat0) \
379+
+ weights_bar * (
380+
np.divide(np.multiply(d, u_hat1), m_hat_adj)
381+
- np.divide(np.multiply(1.0-d, u_hat0), 1.0 - m_hat_adj))
363382
if self.score == 'ATE':
364-
weights, weights_bar = self._get_weights()
365-
psi_b = weights * (g_hat1 - g_hat0) \
366-
+ weights_bar * (
367-
np.divide(np.multiply(d, u_hat1), m_hat_adj)
368-
- np.divide(np.multiply(1.0-d, u_hat0), 1.0 - m_hat_adj))
369383
psi_a = np.full_like(m_hat_adj, -1.0)
370384
else:
371385
assert self.score == 'ATTE'
372-
psi_b = np.divide(np.multiply(d, u_hat0), p_hat) \
373-
- np.divide(np.multiply(m_hat_adj, np.multiply(1.0-d, u_hat0)),
374-
np.multiply(p_hat, (1.0 - m_hat_adj)))
375-
psi_a = - np.divide(d, p_hat)
386+
psi_a = -1.0 * weights
376387
else:
377388
assert callable(self.score)
378389
psi_a, psi_b = self.score(y=y, d=d,
@@ -388,15 +399,14 @@ def _sensitivity_element_est(self, preds):
388399

389400
m_hat = preds['predictions']['ml_m']
390401
g_hat0 = preds['predictions']['ml_g0']
391-
g_hat1 = preds['predictions']['ml_g1']
392-
393-
# use weights make this extendable
394402
if self.score == 'ATE':
395-
weights, weights_bar = self._get_weights()
403+
g_hat1 = preds['predictions']['ml_g1']
396404
else:
397405
assert self.score == 'ATTE'
398-
weights = np.divide(d, np.mean(d))
399-
weights_bar = np.divide(m_hat, np.mean(d))
406+
g_hat1 = y
407+
408+
# use weights make this extendable
409+
weights, weights_bar = self._get_weights(m_hat=m_hat)
400410

401411
sigma2_score_element = np.square(y - np.multiply(d, g_hat1) - np.multiply(1.0-d, g_hat0))
402412
sigma2 = np.mean(sigma2_score_element)

doubleml/irm/tests/_utils_irm_manual.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,11 @@ def fit_sensitivity_elements_irm(y, d, all_coef, predictions, score, n_rep):
298298

299299
m_hat = predictions['ml_m'][:, i_rep, 0]
300300
g_hat0 = predictions['ml_g0'][:, i_rep, 0]
301-
g_hat1 = predictions['ml_g1'][:, i_rep, 0]
301+
if score == 'ATE':
302+
g_hat1 = predictions['ml_g1'][:, i_rep, 0]
303+
else:
304+
assert score == 'ATTE'
305+
g_hat1 = y
302306

303307
if score == 'ATE':
304308
weights = np.ones_like(d)

doubleml/irm/tests/test_irm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def dml_irm_weights_fixture(n_rep, dml_procedure):
278278

279279
# First stage estimation
280280
ml_g = LinearRegression()
281-
ml_m = LogisticRegression(penalty='none', random_state=42)
281+
ml_m = LogisticRegression(penalty='l2', random_state=42)
282282

283283
# ATE with and without weights
284284
dml_irm_obj_ate_no_weights = dml.DoubleMLIRM(

doubleml/tests/_utils_dml_cv_predict.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,21 @@
88
from sklearn.preprocessing import LabelEncoder
99
from sklearn.model_selection._validation import _fit_and_predict, _check_is_permutation
1010

11+
# Adapt _fit_and_predict for earlier sklearn versions
12+
from distutils.version import LooseVersion
13+
from sklearn import __version__ as sklearn_version
14+
15+
if LooseVersion(sklearn_version) < LooseVersion("1.4.0"):
16+
def _fit_and_predict_adapted(estimator, x, y, train, test, fit_params, method):
17+
res = _fit_and_predict(estimator, x, y, train, test,
18+
verbose=0,
19+
fit_params=fit_params,
20+
method=method)
21+
return res
22+
else:
23+
def _fit_and_predict_adapted(estimator, x, y, train, test, fit_params, method):
24+
return _fit_and_predict(estimator, x, y, train, test, fit_params, method)
25+
1126

1227
def _dml_cv_predict_ut_version(estimator, x, y, smpls=None,
1328
n_jobs=None, est_params=None, method='predict'):
@@ -22,18 +37,19 @@ def _dml_cv_predict_ut_version(estimator, x, y, smpls=None,
2237
train_index, test_index = smpls[0]
2338
# set some defaults aligned with cross_val_predict
2439
fit_params = None
25-
verbose = 0
2640
if method == 'predict_proba':
2741
predictions = np.full((len(y), 2), np.nan)
2842
else:
2943
predictions = np.full(len(y), np.nan)
3044
if est_params is None:
31-
xx = _fit_and_predict(clone(estimator),
32-
x, y, train_index, test_index, verbose, fit_params, method)
45+
xx = _fit_and_predict_adapted(
46+
clone(estimator),
47+
x, y, train_index, test_index, fit_params, method)
3348
else:
3449
assert isinstance(est_params, dict)
35-
xx = _fit_and_predict(clone(estimator).set_params(**est_params),
36-
x, y, train_index, test_index, verbose, fit_params, method)
50+
xx = _fit_and_predict_adapted(
51+
clone(estimator).set_params(**est_params),
52+
x, y, train_index, test_index, fit_params, method)
3753

3854
# implementation is (also at other parts) restricted to a sorted set of test_indices, but this could be fixed
3955
# inv_test_indices = np.argsort(test_indices)
@@ -61,22 +77,22 @@ def _dml_cv_predict_ut_version(estimator, x, y, smpls=None,
6177
pre_dispatch=pre_dispatch)
6278
# FixMe: Find a better way to handle the different combinations of parameters and smpls_is_partition
6379
if est_params is None:
64-
prediction_blocks = parallel(delayed(_fit_and_predict)(
80+
prediction_blocks = parallel(delayed(_fit_and_predict_adapted)(
6581
estimator,
66-
x, y, train_index, test_index, verbose, fit_params, method)
82+
x, y, train_index, test_index, fit_params, method)
6783
for idx, (train_index, test_index) in enumerate(smpls))
6884
elif isinstance(est_params, dict):
6985
# if no fold-specific parameters we redirect to the standard method
7086
# warnings.warn("Using the same (hyper-)parameters for all folds")
71-
prediction_blocks = parallel(delayed(_fit_and_predict)(
87+
prediction_blocks = parallel(delayed(_fit_and_predict_adapted)(
7288
clone(estimator).set_params(**est_params),
73-
x, y, train_index, test_index, verbose, fit_params, method)
89+
x, y, train_index, test_index, fit_params, method)
7490
for idx, (train_index, test_index) in enumerate(smpls))
7591
else:
7692
assert len(est_params) == len(smpls), 'provide one parameter setting per fold'
77-
prediction_blocks = parallel(delayed(_fit_and_predict)(
93+
prediction_blocks = parallel(delayed(_fit_and_predict_adapted)(
7894
clone(estimator).set_params(**est_params[idx]),
79-
x, y, train_index, test_index, verbose, fit_params, method)
95+
x, y, train_index, test_index, fit_params, method)
8096
for idx, (train_index, test_index) in enumerate(smpls))
8197

8298
# Concatenate the predictions

doubleml/tests/test_exceptions.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -428,16 +428,17 @@ def test_doubleml_exception_trimming_rule():
428428

429429
@pytest.mark.ci
430430
def test_doubleml_exception_weights():
431-
msg = "weights can only be set for score type 'ATE'. ATTE was passed."
432-
with pytest.raises(NotImplementedError, match=msg):
433-
_ = DoubleMLIRM(dml_data_irm, Lasso(), LogisticRegression(),
434-
score='ATTE', weights=np.ones_like(dml_data_irm.d))
431+
435432
msg = "weights must be a numpy array or dictionary. weights of type <class 'int'> was passed."
436433
with pytest.raises(TypeError, match=msg):
437434
_ = DoubleMLIRM(dml_data_irm, Lasso(), LogisticRegression(), weights=1)
438435
msg = r"weights must have keys \['weights', 'weights_bar'\]. keys dict_keys\(\['d'\]\) were passed."
439436
with pytest.raises(ValueError, match=msg):
440437
_ = DoubleMLIRM(dml_data_irm, Lasso(), LogisticRegression(), weights={'d': [1, 2, 3]})
438+
msg = "weights must be a numpy array for ATTE score. weights of type <class 'dict'> was passed."
439+
with pytest.raises(TypeError, match=msg):
440+
_ = DoubleMLIRM(dml_data_irm, Lasso(), LogisticRegression(),
441+
score='ATTE', weights={'weights': np.ones_like(dml_data_irm.d)})
441442

442443
# shape checks
443444
msg = rf"weights must have shape \({n},\). weights of shape \(1,\) was passed."
@@ -485,6 +486,11 @@ def test_doubleml_exception_weights():
485486
weights={'weights': np.ones((dml_data_irm.d.shape[0], )),
486487
'weights_bar': np.zeros((dml_data_irm.d.shape[0], 1))})
487488

489+
msg = "weights must be binary for ATTE score."
490+
with pytest.raises(ValueError, match=msg):
491+
_ = DoubleMLIRM(dml_data_irm, Lasso(), LogisticRegression(),
492+
score='ATTE', weights=np.random.choice([0, 0.2], dml_data_irm.d.shape[0]))
493+
488494

489495
@pytest.mark.ci
490496
def test_doubleml_exception_quantiles():

0 commit comments

Comments
 (0)