Enable external predictions for short model in benchmark #239

Merged: 10 commits, Apr 11, 2024
10 changes: 8 additions & 2 deletions doubleml/double_ml.py
@@ -1735,7 +1735,7 @@ def sensitivity_plot(self, idx_treatment=0, value='theta', include_scenario=True
fill=fill)
return fig

def sensitivity_benchmark(self, benchmarking_set):
def sensitivity_benchmark(self, benchmarking_set, fit_args=None):
"""
Computes a benchmark for a given set of features.
Returns a DataFrame containing the corresponding values for cf_y, cf_d, rho and the change in estimates.
@@ -1757,12 +1757,18 @@ def sensitivity_benchmark(self, benchmarking_set):
if not set(benchmarking_set) <= set(x_list_long):
raise ValueError(f"benchmarking_set must be a subset of features {str(self._dml_data.x_cols)}. "
f'{str(benchmarking_set)} was passed.')
if fit_args is not None and not isinstance(fit_args, dict):
raise TypeError('fit_args must be a dict. '
f'{str(fit_args)} of type {type(fit_args)} was passed.')

# refit short form of the model
x_list_short = [x for x in x_list_long if x not in benchmarking_set]
dml_short = copy.deepcopy(self)
dml_short._dml_data.x_cols = x_list_short
dml_short.fit()
if fit_args is not None:
dml_short.fit(**fit_args)
else:
dml_short.fit()

benchmark_dict = gain_statistics(dml_long=self, dml_short=dml_short)
df_benchmark = pd.DataFrame(benchmark_dict, index=self._dml_data.d_cols)
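The new `fit_args` argument is forwarded to the short model's `fit()` call, which makes it possible to pass `external_predictions` and skip re-estimating the nuisance learners during benchmarking. A minimal sketch of the intended usage, modelled on the test added in this PR (data size, learners, and the benchmarked feature are illustrative assumptions, not part of the change):

```python
from sklearn.linear_model import LinearRegression, LogisticRegression
from doubleml import DoubleMLIRM, DoubleMLData
from doubleml.datasets import make_irm_data

# Simulate IRM data and fit the long model (all data-generating values and
# learners here are illustrative, not prescribed by the PR).
x, y, d = make_irm_data(n_obs=500, dim_x=5, theta=0.5, return_type="np.array")
dml_data = DoubleMLData.from_arrays(x=x, y=y, d=d)
dml_long = DoubleMLIRM(dml_data, ml_g=LinearRegression(), ml_m=LogisticRegression(), n_folds=2)
dml_long.fit(store_predictions=True)
dml_long.sensitivity_analysis()

# Fit the short model (benchmarking features removed) separately so its
# cross-fitted nuisance predictions can be reused.
benchmarking_set = ["X1"]
dml_data_short = DoubleMLData.from_arrays(x=x, y=y, d=d)
dml_data_short.x_cols = [c for c in dml_data.x_cols if c not in benchmarking_set]
dml_short = DoubleMLIRM(dml_data_short, ml_g=LinearRegression(), ml_m=LogisticRegression(), n_folds=2)
dml_short.fit(store_predictions=True)

# New in this PR: fit_args is forwarded to the short model's fit() inside
# sensitivity_benchmark, so the external predictions replace the internal refit.
fit_args = {"external_predictions": {"d": {"ml_m": dml_short.predictions["ml_m"][:, :, 0],
                                           "ml_g0": dml_short.predictions["ml_g0"][:, :, 0],
                                           "ml_g1": dml_short.predictions["ml_g1"][:, :, 0]}}}
df_benchmark = dml_long.sensitivity_benchmark(benchmarking_set=benchmarking_set, fit_args=fit_args)
print(df_benchmark)  # cf_y, cf_d, rho and the change in estimates
```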
15 changes: 14 additions & 1 deletion doubleml/tests/test_exceptions_ext_preds.py
@@ -1,8 +1,10 @@
import pytest
from doubleml import DoubleMLCVAR, DoubleMLQTE, DoubleMLData
from doubleml import DoubleMLCVAR, DoubleMLQTE, DoubleMLIRM, DoubleMLData
from doubleml.datasets import make_irm_data
from doubleml.utils import DMLDummyRegressor, DMLDummyClassifier

from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier

df_irm = make_irm_data(n_obs=10, dim_x=2, theta=0.5, return_type="DataFrame")
ext_predictions = {"d": {}}

@@ -21,3 +23,14 @@ def test_qte_external_prediction_exception():
with pytest.raises(NotImplementedError, match=msg):
qte = DoubleMLQTE(DoubleMLData(df_irm, "y", "d"), DMLDummyClassifier(), DMLDummyClassifier())
qte.fit(external_predictions=ext_predictions)


@pytest.mark.ci
def test_sensitivity_benchmark_external_prediction_exception():
msg = "fit_args must be a dict. "
with pytest.raises(TypeError, match=msg):
fit_args = []
irm = DoubleMLIRM(DoubleMLData(df_irm, "y", "d"), RandomForestRegressor(), RandomForestClassifier())
irm.fit()
irm.sensitivity_analysis()
irm.sensitivity_benchmark(benchmarking_set=["X1"], fit_args=fit_args)
63 changes: 62 additions & 1 deletion doubleml/tests/test_sensitivity.py
@@ -1,13 +1,21 @@
import pytest
import numpy as np
import copy

import doubleml as dml
from sklearn.linear_model import LinearRegression
from doubleml import DoubleMLIRM, DoubleMLData
from doubleml.datasets import make_irm_data
from sklearn.linear_model import LinearRegression, LogisticRegression

from ._utils_doubleml_sensitivity_manual import doubleml_sensitivity_manual, \
doubleml_sensitivity_benchmark_manual


@pytest.fixture(scope="module", params=[["X1"], ["X2"], ["X3"]])
def benchmarking_set(request):
return request.param


@pytest.fixture(scope='module',
params=[1, 3])
def n_rep(request):
@@ -99,3 +107,56 @@ def test_dml_sensitivity_benchmark(dml_sensitivity_multitreat_fixture):
assert all(dml_sensitivity_multitreat_fixture['benchmark'].index ==
dml_sensitivity_multitreat_fixture['d_cols'])
assert dml_sensitivity_multitreat_fixture['benchmark'].equals(dml_sensitivity_multitreat_fixture['benchmark_manual'])


@pytest.fixture(scope="module")
def test_dml_benchmark_fixture(benchmarking_set, n_rep):
random_state = 42
x, y, d = make_irm_data(n_obs=50, dim_x=5, theta=0, return_type="np.array")

classifier_class = LogisticRegression
regressor_class = LinearRegression

np.random.seed(3141)
dml_data = DoubleMLData.from_arrays(x=x, y=y, d=d)
x_list_long = copy.deepcopy(dml_data.x_cols)
dml_int = DoubleMLIRM(dml_data,
ml_m=classifier_class(random_state=random_state),
ml_g=regressor_class(),
n_folds=2,
n_rep=n_rep)
dml_int.fit(store_predictions=True)
dml_int.sensitivity_analysis()
dml_ext = copy.deepcopy(dml_int)
df_bm = dml_int.sensitivity_benchmark(benchmarking_set=benchmarking_set)

np.random.seed(3141)
dml_data_short = DoubleMLData.from_arrays(x=x, y=y, d=d)
dml_data_short.x_cols = [x for x in x_list_long if x not in benchmarking_set]
dml_short = DoubleMLIRM(dml_data_short,
ml_m=classifier_class(random_state=random_state),
ml_g=regressor_class(),
n_folds=2,
n_rep=n_rep)
dml_short.fit(store_predictions=True)
fit_args = {"external_predictions": {"d": {"ml_m": dml_short.predictions["ml_m"][:, :, 0],
"ml_g0": dml_short.predictions["ml_g0"][:, :, 0],
"ml_g1": dml_short.predictions["ml_g1"][:, :, 0],
}
},
}
dml_ext.sensitivity_analysis()
df_bm_ext = dml_ext.sensitivity_benchmark(benchmarking_set=benchmarking_set, fit_args=fit_args)

res_dict = {"default_benchmark": df_bm,
"external_benchmark": df_bm_ext}

return res_dict


@pytest.mark.ci
def test_dml_sensitivity_external_predictions(test_dml_benchmark_fixture):
assert np.allclose(test_dml_benchmark_fixture["default_benchmark"],
test_dml_benchmark_fixture["external_benchmark"],
rtol=1e-9,
atol=1e-4)