
Commit 72fe6b8

add evaluate_learner()
1 parent db6845e commit 72fe6b8

File tree

3 files changed: +170 -0 lines changed


doubleml/double_ml.py

Lines changed: 66 additions & 0 deletions
@@ -3,6 +3,7 @@
 import warnings
 
 from sklearn.base import is_regressor, is_classifier
+from sklearn.metrics import mean_squared_error
 
 from scipy.stats import norm
 
@@ -1038,6 +1039,71 @@ def _store_models(self, models):
         for learner in self.params_names:
             self._models[learner][self._dml_data.d_cols[self._i_treat]][self._i_rep] = models[learner]
 
+    def evaluate_learners(self, learners=None, metric=mean_squared_error):
+        """
+        Evaluate fitted learners for DoubleML models on cross-validated predictions.
+
+        Parameters
+        ----------
+        learners : list
+            A list of strings which correspond to the nuisance functions of the model.
+            Default is ``None``, in which case all learners are evaluated.
+
+        metric : callable
+            A callable function with inputs ``y_pred`` and ``y_true``.
+            Default is the mean squared error.
+
+        Returns
+        -------
+        dist : dict
+            A dictionary containing the evaluated metric for each learner.
+
+        Examples
+        --------
+        >>> import numpy as np
+        >>> import doubleml as dml
+        >>> from sklearn.metrics import mean_absolute_error
+        >>> from doubleml.datasets import make_irm_data
+        >>> from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
+        >>> np.random.seed(3141)
+        >>> ml_g = RandomForestRegressor(n_estimators=100, max_features=20, max_depth=5, min_samples_leaf=2)
+        >>> ml_m = RandomForestClassifier(n_estimators=100, max_features=20, max_depth=5, min_samples_leaf=2)
+        >>> data = make_irm_data(theta=0.5, n_obs=500, dim_x=20, return_type='DataFrame')
+        >>> obj_dml_data = dml.DoubleMLData(data, 'y', 'd')
+        >>> dml_irm_obj = dml.DoubleMLIRM(obj_dml_data, ml_g, ml_m)
+        >>> dml_irm_obj.fit()
+        >>> dml_irm_obj.evaluate_learners(metric=mean_absolute_error)
+        {'ml_g0': array([[1.13318973]]),
+         'ml_g1': array([[0.91659939]]),
+         'ml_m': array([[0.36350912]])}
+        """
+        # if no learners are provided, try to evaluate all learners
+        if learners is None:
+            learners = self.params_names
+
+        # check metric
+        if not callable(metric):
+            raise TypeError('metric should be a callable. '
+                            '%r was passed.' % metric)
+
+        if all(learner in self.params_names for learner in learners):
+            if self.nuisance_targets is None:
+                raise ValueError('Apply fit() before evaluate_learners().')
+            else:
+                dist = {learner: np.full((self.n_rep, self._dml_data.n_coefs), np.nan)
+                        for learner in learners}
+                for learner in learners:
+                    for rep in range(self.n_rep):
+                        for coef_idx in range(self._dml_data.n_coefs):
+                            res = metric(y_pred=self.predictions[learner][:, rep, coef_idx].reshape(1, -1),
+                                         y_true=self.nuisance_targets[learner][:, rep, coef_idx].reshape(1, -1))
+                            if not np.isfinite(res):
+                                raise ValueError(f'Evaluation from learner {str(learner)} is not finite.')
+                            dist[learner][rep, coef_idx] = res
+                return dist
+        else:
+            raise ValueError(f'The learners have to be a subset of {str(self.params_names)}. '
+                             f'Learners {str(learners)} provided.')
+
     def draw_sample_splitting(self):
         """
         Draw sample splitting for DoubleML models.
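
For context (not part of the diff): a minimal usage sketch of the new method with a user-defined metric, assuming the fitted ``dml_irm_obj`` from the docstring example above. The callable receives the cross-fitted predictions and nuisance targets as 2d arrays via the keyword arguments ``y_pred`` and ``y_true``, so any function with that signature can be plugged in.

import numpy as np

def rmse(y_true, y_pred):
    # custom metric: root mean squared error over the cross-fitted predictions
    return np.sqrt(np.mean((y_true - y_pred) ** 2))

# assumes the fitted dml_irm_obj from the docstring example above;
# evaluate only a subset of the nuisance learners
dml_irm_obj.evaluate_learners(learners=['ml_m'], metric=rmse)
# returns {'ml_m': array([[...]])}, one array of shape (n_rep, n_coefs) per learner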
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+import pytest
+import numpy as np
+import doubleml as dml
+from sklearn.metrics import mean_absolute_error, mean_squared_error
+from doubleml.datasets import make_irm_data
+from sklearn.base import clone
+
+from sklearn.linear_model import LogisticRegression, LinearRegression
+from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
+
+
+np.random.seed(3141)
+data = make_irm_data(theta=0.5, n_obs=200, dim_x=5, return_type='DataFrame')
+obj_dml_data = dml.DoubleMLData(data, 'y', 'd')
+
+
+@pytest.fixture(scope='module',
+                params=[[LinearRegression(),
+                         LogisticRegression(solver='lbfgs', max_iter=250)],
+                        [RandomForestRegressor(max_depth=2, n_estimators=10),
+                         RandomForestClassifier(max_depth=2, n_estimators=10)]])
+def learner(request):
+    return request.param
+
+
+@pytest.fixture(scope='module',
+                params=[1, 5])
+def n_rep(request):
+    return request.param
+
+
+@pytest.fixture(scope='module',
+                params=['dml1', 'dml2'])
+def dml_procedure(request):
+    return request.param
+
+
+@pytest.fixture(scope='module',
+                params=[mean_absolute_error, mean_squared_error])
+def metric(request):
+    return request.param
+
+
+@pytest.fixture(scope='module',
+                params=[0.01, 0.05])
+def trimming_threshold(request):
+    return request.param
+
+
+@pytest.fixture(scope='module')
+def dml_irm_eval_learner_fixture(metric, learner, dml_procedure, trimming_threshold, n_rep):
+    # Set machine learning methods for m & g
+    ml_g = clone(learner[0])
+    ml_m = clone(learner[1])
+
+    np.random.seed(3141)
+    dml_irm_obj = dml.DoubleMLIRM(obj_dml_data,
+                                  ml_g, ml_m,
+                                  n_folds=2,
+                                  n_rep=n_rep,
+                                  dml_procedure=dml_procedure,
+                                  trimming_threshold=trimming_threshold)
+    dml_irm_obj.fit()
+    res = dml_irm_obj.evaluate_learners(metric=metric)
+    return res
+
+
+@pytest.mark.ci
+def test_dml_irm_eval_learner(dml_irm_eval_learner_fixture, n_rep):
+    assert dml_irm_eval_learner_fixture['ml_g0'].shape == (n_rep, 1)
+    assert dml_irm_eval_learner_fixture['ml_g1'].shape == (n_rep, 1)
+    assert dml_irm_eval_learner_fixture['ml_m'].shape == (n_rep, 1)
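
As a side note (not part of the diff), the shape contract asserted above can also be checked without the pytest fixtures. A sketch reusing the module-level ``obj_dml_data`` from this test file, with ad-hoc learner instances:

# sketch: every entry returned by evaluate_learners() has shape (n_rep, n_coefs)
ml_g = RandomForestRegressor(max_depth=2, n_estimators=10)
ml_m = RandomForestClassifier(max_depth=2, n_estimators=10)
dml_irm_obj = dml.DoubleMLIRM(obj_dml_data, ml_g, ml_m, n_folds=2, n_rep=5)
dml_irm_obj.fit()
res = dml_irm_obj.evaluate_learners(metric=mean_squared_error)
assert all(v.shape == (5, 1) for v in res.values())  # n_rep=5, one treatment variable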

doubleml/tests/test_doubleml_exceptions.py

Lines changed: 32 additions & 0 deletions
@@ -791,3 +791,35 @@ def test_doubleml_exception_cate():
     msg = 'Only implemented for one repetition. Number of repetitions is 2.'
     with pytest.raises(NotImplementedError, match=msg):
         dml_irm_obj.cate(basis=2)
+
+
+@pytest.mark.ci
+def test_double_ml_exception_evaluate_learner():
+    dml_irm_obj = DoubleMLIRM(dml_data_irm,
+                              ml_g=Lasso(),
+                              ml_m=LogisticRegression(),
+                              trimming_threshold=0.05,
+                              n_folds=5,
+                              score='ATTE')
+
+    msg = r'Apply fit\(\) before evaluate_learners\(\).'
+    with pytest.raises(ValueError, match=msg):
+        dml_irm_obj.evaluate_learners()
+
+    dml_irm_obj.fit()
+
+    msg = "metric should be a callable. 'mse' was passed."
+    with pytest.raises(TypeError, match=msg):
+        dml_irm_obj.evaluate_learners(metric="mse")
+
+    msg = (r"The learners have to be a subset of \['ml_g0', 'ml_g1', 'ml_m'\]. "
+           r"Learners \['ml_g', 'ml_m'\] provided.")
+    with pytest.raises(ValueError, match=msg):
+        dml_irm_obj.evaluate_learners(learners=['ml_g', 'ml_m'])
+
+    msg = 'Evaluation from learner ml_g0 is not finite.'
+
+    def eval_fct(y_pred, y_true):
+        return np.nan
+    with pytest.raises(ValueError, match=msg):
+        dml_irm_obj.evaluate_learners(metric=eval_fct)
