Fix for [Bug]: KeyError in DoubleMLPLIV.fit() with multiple instruments and store_predictions=True #185
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
In the case of multiple instruments, the function DoubleMLPLIV.fit() throws an error when executed with the parameter 'store_predictions=True', as reported in [Bug]: KeyError in DoubleMLPLIV.fit() with multiple instruments and store_predictions=True.
This occurs because the list of learner names is defined as follows in the case of multiple instruments:
line 269 in double_ml_pliv.py: param_names = ['ml_l', 'ml_r'] + ['ml_m_' + z_col for z_col in self._dml_data.z_cols]
The predictions for the nuisance functions are however saved using the keys 'ml_l', 'ml_m', 'ml_r', 'ml_g'.
To store the predictions if store_predictions=True in DoubleMLPLIV.fit(), we use preds['predictions'] but iterate over param_names. Hence, the reported KeyError occurred.
To solve this error, I propose to add the predictions and models made for each instrument in the case of multiple instruments, i.e. for the list of learners
['ml_m_' + z_col for z_col in self._dml_data.z_cols]
,to the dictionaries preds['predictions'] and preds['models'], respectively.
The proposed code passes all (unit) tests.
Reference to Issues or PRs
[Bug]: KeyError in DoubleMLPLIV.fit() with multiple instruments and store_predictions=True #184