Skip to content

Commit eda1137

Browse files
authored
Merge pull request #229 from DoubleML/s-fix-bugs-doc
Add Gain statistics and weights for ATTE
2 parents 5d9a4c2 + 0bbd655 commit eda1137

12 files changed

+528
-94
lines changed

doubleml/_utils_checks.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,12 +230,13 @@ def _check_benchmarks(benchmarks):
230230

231231
def _check_weights(weights, score, n_obs, n_rep):
232232
if weights is not None:
233-
if score != "ATE":
234-
raise NotImplementedError("weights can only be set for score type 'ATE'. "
235-
f"{score} was passed.")
233+
234+
# check general type
236235
if (not isinstance(weights, np.ndarray)) and (not isinstance(weights, dict)):
237236
raise TypeError("weights must be a numpy array or dictionary. "
238237
f"weights of type {str(type(weights))} was passed.")
238+
239+
# check shape
239240
if isinstance(weights, np.ndarray):
240241
if (weights.ndim != 1) or weights.shape[0] != n_obs:
241242
raise ValueError(f"weights must have shape ({n_obs},). "
@@ -245,7 +246,19 @@ def _check_weights(weights, score, n_obs, n_rep):
245246
if weights.sum() == 0:
246247
raise ValueError("At least one weight must be non-zero.")
247248

249+
# check special form for ATTE score
250+
if score == "ATTE":
251+
if not isinstance(weights, np.ndarray):
252+
raise TypeError("weights must be a numpy array for ATTE score. "
253+
f"weights of type {str(type(weights))} was passed.")
254+
255+
is_binary = np.all((np.power(weights, 2) - weights) == 0)
256+
if not is_binary:
257+
raise ValueError("weights must be binary for ATTE score.")
258+
259+
# check general form for ATE score
248260
if isinstance(weights, dict):
261+
assert score == "ATE"
249262
expected_keys = ["weights", "weights_bar"]
250263
if not set(weights.keys()) == set(expected_keys):
251264
raise ValueError(f"weights must have keys {expected_keys}. "

doubleml/double_ml.py

Lines changed: 6 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ._utils_checks import _check_in_zero_one, _check_integer, _check_float, _check_bool, _check_is_partition, \
2020
_check_all_smpls, _check_smpl_split, _check_smpl_split_tpl, _check_benchmarks
2121
from ._utils_plots import _sensitivity_contour_plot
22-
22+
from .utils.gain_statistics import gain_statistics
2323

2424
_implemented_data_backends = ['DoubleMLData', 'DoubleMLClusterData']
2525

@@ -272,7 +272,8 @@ def params_names(self):
272272
@property
273273
def predictions(self):
274274
"""
275-
The predictions of the nuisance models.
275+
The predictions of the nuisance models in form of a dictionary.
276+
Each key refers to a nuisance element with an array of values of shape ``(n_obs, n_rep, n_coefs)``.
276277
"""
277278
return self._predictions
278279

@@ -354,6 +355,7 @@ def psi(self):
354355
Values of the score function after calling :meth:`fit`;
355356
For models (e.g., PLR, IRM, PLIV, IIVM) with linear score (in the parameter)
356357
:math:`\\psi(W; \\theta, \\eta) = \\psi_a(W; \\eta) \\theta + \\psi_b(W; \\eta)`.
358+
The shape is ``(n_obs, n_rep, n_coefs)``.
357359
"""
358360
return self._psi
359361

@@ -364,6 +366,7 @@ def psi_deriv(self):
364366
after calling :meth:`fit`;
365367
For models (e.g., PLR, IRM, PLIV, IIVM) with linear score (in the parameter)
366368
:math:`\\psi_a(W; \\eta)`.
369+
The shape is ``(n_obs, n_rep, n_coefs)``.
367370
"""
368371
return self._psi_deriv
369372

@@ -1995,45 +1998,6 @@ def sensitivity_benchmark(self, benchmarking_set):
19951998
dml_short._dml_data.x_cols = x_list_short
19961999
dml_short.fit()
19972000

1998-
# save elements for readability
1999-
var_y = np.var(self._dml_data.y)
2000-
var_y_residuals_long = np.squeeze(self.sensitivity_elements['sigma2'], axis=0)
2001-
nu2_long = np.squeeze(self.sensitivity_elements['nu2'], axis=0)
2002-
var_y_residuals_short = np.squeeze(dml_short.sensitivity_elements['sigma2'], axis=0)
2003-
nu2_short = np.squeeze(dml_short.sensitivity_elements['nu2'], axis=0)
2004-
2005-
# compute nonparametric R2
2006-
R2_y_long = 1.0 - np.divide(var_y_residuals_long, var_y)
2007-
R2_y_short = 1.0 - np.divide(var_y_residuals_short, var_y)
2008-
R2_riesz = np.divide(nu2_short, nu2_long)
2009-
2010-
# Gain statistics
2011-
all_cf_y_benchmark = np.clip(np.divide((R2_y_long - R2_y_short), (1.0 - R2_y_long)), 0, 1)
2012-
all_cf_d_benchmark = np.clip(np.divide((1.0 - R2_riesz), R2_riesz), 0, 1)
2013-
cf_y_benchmark = np.median(all_cf_y_benchmark, axis=0)
2014-
cf_d_benchmark = np.median(all_cf_d_benchmark, axis=0)
2015-
2016-
# change in estimates (slightly different to paper)
2017-
all_delta_theta = np.transpose(dml_short.all_coef - self.all_coef)
2018-
delta_theta = np.median(all_delta_theta, axis=0)
2019-
2020-
# degree of adversity
2021-
var_g = var_y_residuals_short - var_y_residuals_long
2022-
var_riesz = nu2_long - nu2_short
2023-
denom = np.sqrt(np.multiply(var_g, var_riesz), out=np.zeros_like(var_g), where=(var_g > 0) & (var_riesz > 0))
2024-
rho_sign = np.sign(all_delta_theta)
2025-
rho_values = np.clip(np.divide(np.absolute(all_delta_theta),
2026-
denom,
2027-
out=np.ones_like(all_delta_theta),
2028-
where=denom != 0),
2029-
0.0, 1.0)
2030-
all_rho_benchmark = np.multiply(rho_values, rho_sign)
2031-
rho_benchmark = np.median(all_rho_benchmark, axis=0)
2032-
benchmark_dict = {
2033-
"cf_y": cf_y_benchmark,
2034-
"cf_d": cf_d_benchmark,
2035-
"rho": rho_benchmark,
2036-
"delta_theta": delta_theta,
2037-
}
2001+
benchmark_dict = gain_statistics(dml_long=self, dml_short=dml_short)
20382002
df_benchmark = pd.DataFrame(benchmark_dict, index=self._dml_data.d_cols)
20392003
return df_benchmark

doubleml/double_ml_irm.py

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,26 @@ def _initialize_weights(self, weights):
218218
assert isinstance(weights, dict)
219219
self._weights = weights
220220

221-
def _get_weights(self):
222-
weights = self._weights['weights']
223-
if 'weights_bar' not in self._weights.keys():
224-
weights_bar = self._weights['weights']
221+
def _get_weights(self, m_hat=None):
222+
# standard case for ATE
223+
if self.score == 'ATE':
224+
weights = self._weights['weights']
225+
if 'weights_bar' not in self._weights.keys():
226+
weights_bar = self._weights['weights']
227+
else:
228+
weights_bar = self._weights['weights_bar'][:, self._i_rep]
225229
else:
226-
weights_bar = self._weights['weights_bar'][:, self._i_rep]
230+
# special case for ATTE
231+
assert self.score == 'ATTE'
232+
assert m_hat is not None
233+
subgroup = self._weights['weights'] * self._dml_data.d
234+
subgroup_probability = np.mean(subgroup)
235+
weights = np.divide(subgroup, subgroup_probability)
236+
237+
weights_bar = np.divide(
238+
np.multiply(m_hat, self._weights['weights']),
239+
subgroup_probability)
240+
227241
return weights, weights_bar
228242

229243
def _check_data(self, obj_dml_data):
@@ -277,8 +291,13 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
277291
f'predictions obtained with the ml_g learner {str(self._learner["ml_g"])} are also '
278292
'observed to be binary with values 0 and 1. Make sure that for classifiers '
279293
'probabilities and not labels are predicted.')
294+
if self.score == 'ATTE':
295+
# skip g_hat1 estimation
296+
g_hat1 = {'preds': None,
297+
'targets': None,
298+
'models': None}
280299

281-
if g1_external:
300+
elif g1_external:
282301
# use external predictions
283302
g_hat1 = {'preds': external_predictions['ml_g1'],
284303
'targets': None,
@@ -291,7 +310,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
291310
# adjust target values to consider only compatible subsamples
292311
g_hat1['targets'] = _cond_targets(g_hat1['targets'], cond_sample=(d == 1))
293312

294-
if self._dml_data.binary_outcome:
313+
if self._dml_data.binary_outcome & (self.score != 'ATTE'):
295314
binary_preds = (type_of_target(g_hat1['preds']) == 'binary')
296315
zero_one_preds = np.all((np.power(g_hat1['preds'], 2) - g_hat1['preds']) == 0)
297316
if binary_preds & zero_one_preds:
@@ -334,11 +353,6 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
334353

335354
def _score_elements(self, y, d, g_hat0, g_hat1, m_hat, smpls):
336355

337-
# fraction of treated for ATTE
338-
p_hat = None
339-
if self.score == 'ATTE':
340-
p_hat = np.mean(d)
341-
342356
m_hat_adj = np.full_like(m_hat, np.nan, dtype='float64')
343357
if self.normalize_ipw:
344358
if self.dml_procedure == 'dml1':
@@ -351,24 +365,21 @@ def _score_elements(self, y, d, g_hat0, g_hat1, m_hat, smpls):
351365

352366
# compute residuals
353367
u_hat0 = y - g_hat0
354-
u_hat1 = None
355-
if self.score == 'ATE':
356-
u_hat1 = y - g_hat1
357-
358-
if isinstance(self.score, str):
368+
if self.score == 'ATTE':
369+
g_hat1 = y
370+
u_hat1 = y - g_hat1
371+
372+
if (self.score == 'ATE') or (self.score == 'ATTE'):
373+
weights, weights_bar = self._get_weights(m_hat=m_hat_adj)
374+
psi_b = weights * (g_hat1 - g_hat0) \
375+
+ weights_bar * (
376+
np.divide(np.multiply(d, u_hat1), m_hat_adj)
377+
- np.divide(np.multiply(1.0-d, u_hat0), 1.0 - m_hat_adj))
359378
if self.score == 'ATE':
360-
weights, weights_bar = self._get_weights()
361-
psi_b = weights * (g_hat1 - g_hat0) \
362-
+ weights_bar * (
363-
np.divide(np.multiply(d, u_hat1), m_hat_adj)
364-
- np.divide(np.multiply(1.0-d, u_hat0), 1.0 - m_hat_adj))
365379
psi_a = np.full_like(m_hat_adj, -1.0)
366380
else:
367381
assert self.score == 'ATTE'
368-
psi_b = np.divide(np.multiply(d, u_hat0), p_hat) \
369-
- np.divide(np.multiply(m_hat_adj, np.multiply(1.0-d, u_hat0)),
370-
np.multiply(p_hat, (1.0 - m_hat_adj)))
371-
psi_a = - np.divide(d, p_hat)
382+
psi_a = -1.0 * weights
372383
else:
373384
assert callable(self.score)
374385
psi_a, psi_b = self.score(y=y, d=d,
@@ -384,15 +395,14 @@ def _sensitivity_element_est(self, preds):
384395

385396
m_hat = preds['predictions']['ml_m']
386397
g_hat0 = preds['predictions']['ml_g0']
387-
g_hat1 = preds['predictions']['ml_g1']
388-
389-
# use weights make this extendable
390398
if self.score == 'ATE':
391-
weights, weights_bar = self._get_weights()
399+
g_hat1 = preds['predictions']['ml_g1']
392400
else:
393401
assert self.score == 'ATTE'
394-
weights = np.divide(d, np.mean(d))
395-
weights_bar = np.divide(m_hat, np.mean(d))
402+
g_hat1 = y
403+
404+
# use weights to make this extendable
405+
weights, weights_bar = self._get_weights(m_hat=m_hat)
396406

397407
sigma2_score_element = np.square(y - np.multiply(d, g_hat1) - np.multiply(1.0-d, g_hat0))
398408
sigma2 = np.mean(sigma2_score_element)

doubleml/tests/_utils_dml_cv_predict.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,21 @@
88
from sklearn.preprocessing import LabelEncoder
99
from sklearn.model_selection._validation import _fit_and_predict, _check_is_permutation
1010

11+
# Adapt _fit_and_predict for earlier sklearn versions
12+
from distutils.version import LooseVersion
13+
from sklearn import __version__ as sklearn_version
14+
15+
if LooseVersion(sklearn_version) < LooseVersion("1.4.0"):
16+
def _fit_and_predict_adapted(estimator, x, y, train, test, fit_params, method):
17+
res = _fit_and_predict(estimator, x, y, train, test,
18+
verbose=0,
19+
fit_params=fit_params,
20+
method=method)
21+
return res
22+
else:
23+
def _fit_and_predict_adapted(estimator, x, y, train, test, fit_params, method):
24+
return _fit_and_predict(estimator, x, y, train, test, fit_params, method)
25+
1126

1227
def _dml_cv_predict_ut_version(estimator, x, y, smpls=None,
1328
n_jobs=None, est_params=None, method='predict'):
@@ -22,18 +37,19 @@ def _dml_cv_predict_ut_version(estimator, x, y, smpls=None,
2237
train_index, test_index = smpls[0]
2338
# set some defaults aligned with cross_val_predict
2439
fit_params = None
25-
verbose = 0
2640
if method == 'predict_proba':
2741
predictions = np.full((len(y), 2), np.nan)
2842
else:
2943
predictions = np.full(len(y), np.nan)
3044
if est_params is None:
31-
xx = _fit_and_predict(clone(estimator),
32-
x, y, train_index, test_index, verbose, fit_params, method)
45+
xx = _fit_and_predict_adapted(
46+
clone(estimator),
47+
x, y, train_index, test_index, fit_params, method)
3348
else:
3449
assert isinstance(est_params, dict)
35-
xx = _fit_and_predict(clone(estimator).set_params(**est_params),
36-
x, y, train_index, test_index, verbose, fit_params, method)
50+
xx = _fit_and_predict_adapted(
51+
clone(estimator).set_params(**est_params),
52+
x, y, train_index, test_index, fit_params, method)
3753

3854
# implementation is (also at other parts) restricted to a sorted set of test_indices, but this could be fixed
3955
# inv_test_indices = np.argsort(test_indices)
@@ -61,22 +77,22 @@ def _dml_cv_predict_ut_version(estimator, x, y, smpls=None,
6177
pre_dispatch=pre_dispatch)
6278
# FixMe: Find a better way to handle the different combinations of parameters and smpls_is_partition
6379
if est_params is None:
64-
prediction_blocks = parallel(delayed(_fit_and_predict)(
80+
prediction_blocks = parallel(delayed(_fit_and_predict_adapted)(
6581
estimator,
66-
x, y, train_index, test_index, verbose, fit_params, method)
82+
x, y, train_index, test_index, fit_params, method)
6783
for idx, (train_index, test_index) in enumerate(smpls))
6884
elif isinstance(est_params, dict):
6985
# if no fold-specific parameters we redirect to the standard method
7086
# warnings.warn("Using the same (hyper-)parameters for all folds")
71-
prediction_blocks = parallel(delayed(_fit_and_predict)(
87+
prediction_blocks = parallel(delayed(_fit_and_predict_adapted)(
7288
clone(estimator).set_params(**est_params),
73-
x, y, train_index, test_index, verbose, fit_params, method)
89+
x, y, train_index, test_index, fit_params, method)
7490
for idx, (train_index, test_index) in enumerate(smpls))
7591
else:
7692
assert len(est_params) == len(smpls), 'provide one parameter setting per fold'
77-
prediction_blocks = parallel(delayed(_fit_and_predict)(
93+
prediction_blocks = parallel(delayed(_fit_and_predict_adapted)(
7894
clone(estimator).set_params(**est_params[idx]),
79-
x, y, train_index, test_index, verbose, fit_params, method)
95+
x, y, train_index, test_index, fit_params, method)
8096
for idx, (train_index, test_index) in enumerate(smpls))
8197

8298
# Concatenate the predictions

doubleml/tests/_utils_irm_manual.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,11 @@ def fit_sensitivity_elements_irm(y, d, all_coef, predictions, score, n_rep):
298298

299299
m_hat = predictions['ml_m'][:, i_rep, 0]
300300
g_hat0 = predictions['ml_g0'][:, i_rep, 0]
301-
g_hat1 = predictions['ml_g1'][:, i_rep, 0]
301+
if score == 'ATE':
302+
g_hat1 = predictions['ml_g1'][:, i_rep, 0]
303+
else:
304+
assert score == 'ATTE'
305+
g_hat1 = y
302306

303307
if score == 'ATE':
304308
weights = np.ones_like(d)

doubleml/tests/test_doubleml_exceptions.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -428,16 +428,17 @@ def test_doubleml_exception_trimming_rule():
428428

429429
@pytest.mark.ci
430430
def test_doubleml_exception_weights():
431-
msg = "weights can only be set for score type 'ATE'. ATTE was passed."
432-
with pytest.raises(NotImplementedError, match=msg):
433-
_ = DoubleMLIRM(dml_data_irm, Lasso(), LogisticRegression(),
434-
score='ATTE', weights=np.ones_like(dml_data_irm.d))
431+
435432
msg = "weights must be a numpy array or dictionary. weights of type <class 'int'> was passed."
436433
with pytest.raises(TypeError, match=msg):
437434
_ = DoubleMLIRM(dml_data_irm, Lasso(), LogisticRegression(), weights=1)
438435
msg = r"weights must have keys \['weights', 'weights_bar'\]. keys dict_keys\(\['d'\]\) were passed."
439436
with pytest.raises(ValueError, match=msg):
440437
_ = DoubleMLIRM(dml_data_irm, Lasso(), LogisticRegression(), weights={'d': [1, 2, 3]})
438+
msg = "weights must be a numpy array for ATTE score. weights of type <class 'dict'> was passed."
439+
with pytest.raises(TypeError, match=msg):
440+
_ = DoubleMLIRM(dml_data_irm, Lasso(), LogisticRegression(),
441+
score='ATTE', weights={'weights': np.ones_like(dml_data_irm.d)})
441442

442443
# shape checks
443444
msg = rf"weights must have shape \({n},\). weights of shape \(1,\) was passed."
@@ -485,6 +486,11 @@ def test_doubleml_exception_weights():
485486
weights={'weights': np.ones((dml_data_irm.d.shape[0], )),
486487
'weights_bar': np.zeros((dml_data_irm.d.shape[0], 1))})
487488

489+
msg = "weights must be binary for ATTE score."
490+
with pytest.raises(ValueError, match=msg):
491+
_ = DoubleMLIRM(dml_data_irm, Lasso(), LogisticRegression(),
492+
score='ATTE', weights=np.random.choice([0, 0.2], dml_data_irm.d.shape[0]))
493+
488494

489495
@pytest.mark.ci
490496
def test_doubleml_exception_quantiles():

doubleml/tests/test_irm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def dml_irm_weights_fixture(n_rep, dml_procedure):
278278

279279
# First stage estimation
280280
ml_g = LinearRegression()
281-
ml_m = LogisticRegression(penalty='none', random_state=42)
281+
ml_m = LogisticRegression(penalty='l2', random_state=42)
282282

283283
# ATE with and without weights
284284
dml_irm_obj_ate_no_weights = dml.DoubleMLIRM(

0 commit comments

Comments
 (0)