Add RMSEs and targets #182
Conversation
```diff
@@ -483,8 +507,10 @@ def fit(self, n_jobs_cv=None, store_predictions=False, store_models=False):

         self._set_score_elements(score_elements, self._i_rep, self._i_treat)

+        # calculate rmses and store predictions and targets of the nuisance models
+        self._calc_rmses(preds['predictions'], preds['targets'])
```
Hi @SvenKlaassen,
thanks for this PR! I think adding some diagnostics is really nice. It should be possible to make this a bit more general by using (maybe only a subset of) sklearn's metrics, either by letting users pass a callable for evaluating the nuisance predictions or by supporting the measures directly. For example, in case …
I guess that sklearn's measures have some built-in methods to handle exceptions that might occur ...
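For illustration (not from the PR): sklearn's metrics validate their inputs on their own and raise informative errors, so a wrapper may not need much extra exception handling.

```python
# Illustration only: sklearn metrics check their inputs themselves.
import numpy as np
from sklearn.metrics import mean_squared_error

y_true = np.array([1.0, 2.0, 3.0])
y_pred = np.array([1.1, 1.9])  # mismatched length on purpose

try:
    mean_squared_error(y_true, y_pred)
except ValueError as err:
    print(err)  # inconsistent numbers of samples: [3, 2]
```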
Thanks @PhilippBach, that would be helpful. I think I can add this, but keeping RMSE as the default would be useful, since otherwise one would have to specify a different metric for each learner, and RMSE is still informative for classification tasks.
Another option would be a separate method that evaluates the nuisance functions with a given metric, while keeping RMSE as the default for the summary.
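A minimal sketch of what such an RMSE default could look like (names here are illustrative, not the PR's internals):

```python
# Minimal sketch of an RMSE default metric; the PR's internal
# implementation (_calc_rmses) may differ in detail.
import numpy as np

def rmse(y_true, y_pred):
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return np.sqrt(np.mean((y_true - y_pred) ** 2))

# Works for regression outputs and for predicted probabilities alike,
# which is why a single default can cover all nuisance parts.
print(rmse([0, 1, 1], [0.2, 0.7, 0.9]))
```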
```diff
@@ -434,7 +456,7 @@ def __psi_deriv(self):
     def __all_se(self):
         return self._all_se[self._i_treat, self._i_rep]

-    def fit(self, n_jobs_cv=None, store_predictions=False, store_models=False):
+    def fit(self, n_jobs_cv=None, store_predictions=True, store_models=False):
```
Thanks @SvenKlaassen - I think it's a good idea to set the default for `store_predictions` to `True` 👍
Co-authored-by: PhilippBach <[email protected]>
```python
>>> obj_dml_data = dml.DoubleMLData(data, 'y', 'd')
>>> dml_irm_obj = dml.DoubleMLIRM(obj_dml_data, ml_g, ml_m)
>>> dml_irm_obj.fit()
>>> dml_irm_obj.evaluate_learners(metric=mean_absolute_error)
```
I'm afraid we run into a little problem here if we use callables for classification learners, for example if we run instead:
```python
import numpy as np
import doubleml as dml
from sklearn.metrics import log_loss
from doubleml.datasets import make_irm_data
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier

np.random.seed(3141)
ml_g = RandomForestRegressor(n_estimators=100, max_features=20, max_depth=5, min_samples_leaf=2)
ml_m = RandomForestClassifier(n_estimators=100, max_features=20, max_depth=5, min_samples_leaf=2)
data = make_irm_data(theta=0.5, n_obs=500, dim_x=20, return_type='DataFrame')
obj_dml_data = dml.DoubleMLData(data, 'y', 'd')
dml_irm_obj = dml.DoubleMLIRM(obj_dml_data, ml_g, ml_m)
dml_irm_obj.fit()
dml_irm_obj.evaluate_learners(metric=log_loss)
```
I think we'd have to check whether a learner is a regression or a classification learner and then (optionally) pass two callables, one for the regression tasks and one for the classification tasks. Alternatively, one could pass a keyword referring to the `learner_name`, but that's probably going to lead to a messy interface. I think the default can still be RMSE (or another regression measure) for all nuisance parts, but the option to use a classification measure is probably reasonable, what do you think?
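A rough sketch of that dispatch idea, assuming predictions and targets are available per nuisance part (all names here are hypothetical, not the package's API):

```python
# Hypothetical sketch, not DoubleML's API: pick a metric per nuisance
# part depending on whether the fitted learner is a classifier.
from sklearn.base import is_classifier
from sklearn.metrics import log_loss, mean_squared_error

def evaluate_learners_sketch(learners, targets, preds,
                             reg_metric=mean_squared_error,
                             clf_metric=log_loss):
    """learners/targets/preds: dicts keyed by nuisance part, e.g. 'ml_m'."""
    scores = {}
    for name, learner in learners.items():
        metric = clf_metric if is_classifier(learner) else reg_metric
        scores[name] = metric(targets[name], preds[name])
    return scores
```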
```python
train_preds = list()
train_targets = list()
```
As far as I can see, we're not using the `train_targets` right now, right? I guess that's alright and we can leave an in-sample (fold-wise cross-validated) vs. out-of-sample (cross-fitted) evaluation for later. I'm wondering how we can implement this in a clever way... 🤔
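For reference, here is the distinction in plain scikit-learn terms (illustrative only, unrelated to DoubleML internals):

```python
# Illustration of in-sample vs. out-of-sample (cross-fitted) evaluation
# using plain scikit-learn; not DoubleML code.
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import cross_val_predict

rng = np.random.default_rng(3141)
X = rng.normal(size=(500, 5))
y = X[:, 0] + rng.normal(size=500)

learner = RandomForestRegressor(n_estimators=50, random_state=0)
oos_preds = cross_val_predict(learner, X, y, cv=5)  # out-of-sample
ins_preds = learner.fit(X, y).predict(X)            # in-sample

print('out-of-sample MSE:', mean_squared_error(y, oos_preds))
print('in-sample MSE:', mean_squared_error(y, ins_preds))
```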
fix exception message
```python
# check metric
if not callable(metric):
    raise TypeError('metric should be either a callable. '
```
I fixed the message in 10b180e
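Presumably something along these lines (the exact wording in 10b180e may differ):

```python
# Sketch of the corrected check; the exact message in commit 10b180e
# may differ.
if not callable(metric):
    raise TypeError('metric should be a callable. '
                    '%r was passed.' % str(metric))
```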
We should also demonstrate the use of the new feature(s) in a short example, see DoubleML/doubleml-docs#114
Description

Add RMSE evaluations for each nuisance component to the models. The RMSEs can be accessed via `rmses` and are added to the `summary()`. Further, the targets for each component can be accessed via `nuisance_targets`, which returns a dictionary containing the nuisance targets for each nuisance component (as an array for each repetition and each coefficient). The new method `evaluate_learners` for `DoubleML` objects allows evaluating the nuisance learners with a callable metric based on vector-valued inputs `y_pred` and `y_true` of shape `(1, n)` (see e.g. scikit-learn).
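An illustrative end-to-end usage of these features, assembled from the snippets in the review thread above (learner choices are arbitrary):

```python
import numpy as np
import doubleml as dml
from doubleml.datasets import make_irm_data
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import mean_absolute_error

np.random.seed(3141)
ml_g = RandomForestRegressor(n_estimators=100, max_depth=5)
ml_m = RandomForestClassifier(n_estimators=100, max_depth=5)
data = make_irm_data(theta=0.5, n_obs=500, dim_x=20, return_type='DataFrame')

obj_dml_data = dml.DoubleMLData(data, 'y', 'd')
dml_irm_obj = dml.DoubleMLIRM(obj_dml_data, ml_g, ml_m)
dml_irm_obj.fit()

print(dml_irm_obj.rmses)             # RMSE per nuisance component
print(dml_irm_obj.nuisance_targets)  # targets per repetition and coefficient
print(dml_irm_obj.evaluate_learners(metric=mean_absolute_error))
```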
Notes
Fix a bug to correctly save models and predictions for the PLIV model #184