Skip to content

Commit df6be73

Browse files
committed
fix nuisance_evaluation example
1 parent f64630c commit df6be73

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

doubleml/double_ml.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,7 +1064,7 @@ def evaluate_learners(self, learners=None, metric=_rmse):
10641064
where ``n`` specifies the number of observations. Remark that some models like IRM are
10651065
not able to provide all values for ``y_true`` for all learners and might contain
10661066
some ``nan`` values in the target vector.
1067-
Default is the euclidean distance.
1067+
Default is the root-mean-square error.
10681068
10691069
Returns
10701070
-------
@@ -1085,10 +1085,13 @@ def evaluate_learners(self, learners=None, metric=_rmse):
10851085
>>> obj_dml_data = dml.DoubleMLData(data, 'y', 'd')
10861086
>>> dml_irm_obj = dml.DoubleMLIRM(obj_dml_data, ml_g, ml_m)
10871087
>>> dml_irm_obj.fit()
1088-
>>> dml_irm_obj.evaluate_learners(metric=mean_absolute_error)
1089-
{'ml_g0': array([[1.13318973]]),
1090-
'ml_g1': array([[0.91659939]]),
1091-
'ml_m': array([[0.36350912]])}
1088+
>>> def mae(y_true, y_pred):
1089+
>>> subset = np.logical_not(np.isnan(y_true))
1090+
>>> return mean_absolute_error(y_true[subset], y_pred[subset])
1091+
>>> dml_irm_obj.evaluate_learners(metric=mae)
1092+
{'ml_g0': array([[0.85974356]]),
1093+
'ml_g1': array([[0.85280376]]),
1094+
'ml_m': array([[0.35365143]])}
10921095
"""
10931096
# if no learners are provided try to evaluate all learners
10941097
if learners is None:

0 commit comments

Comments
 (0)