Skip to content

Commit a4c6c44

Browse files
committed
add return_type tests for pred, targets and rmses
1 parent 1e5f958 commit a4c6c44

File tree

3 files changed

+68
-7
lines changed

3 files changed

+68
-7
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,6 @@ share/python-wheels/
2525
*.egg-info/
2626
.installed.cfg
2727
*.egg
28+
*.vscode
2829
MANIFEST
2930
*.idea

doubleml/double_ml.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(self,
4747

4848
# initialize predictions and target to None which are only stored if method fit is called with store_predictions=True
4949
self._predictions = None
50-
self._nuisance_target = None
50+
self._nuisance_targets = None
5151
self._rmses = None
5252

5353
# initialize models to None which are only stored if method fit is called with store_models=True
@@ -231,11 +231,11 @@ def predictions(self):
231231
return self._predictions
232232

233233
@property
234-
def nuisance_target(self):
234+
def nuisance_targets(self):
235235
"""
236236
The outcome of the nuisance models.
237237
"""
238-
return self._nuisance_target
238+
return self._nuisance_targets
239239

240240
@property
241241
def rmses(self):
@@ -249,7 +249,7 @@ def models(self):
249249
"""
250250
The fitted nuisance models.
251251
"""
252-
return self._model
252+
return self._models
253253

254254
def get_params(self, learner):
255255
"""
@@ -1010,8 +1010,8 @@ def _initialize_boot_arrays(self, n_rep_boot):
10101010
def _initialize_predictions_and_targets(self):
10111011
self._predictions = {learner: np.full((self._dml_data.n_obs, self.n_rep, self._dml_data.n_coefs), np.nan)
10121012
for learner in self.params_names}
1013-
self._nuisance_target = {learner: np.full((self._dml_data.n_obs, self.n_rep, self._dml_data.n_coefs), np.nan)
1014-
for learner in self.params_names}
1013+
self._nuisance_targets = {learner: np.full((self._dml_data.n_obs, self.n_rep, self._dml_data.n_coefs), np.nan)
1014+
for learner in self.params_names}
10151015

10161016
def _initialize_rmses(self):
10171017
self._rmses = {learner: np.full((self.n_rep, self._dml_data.n_coefs), np.nan)
@@ -1024,7 +1024,7 @@ def _initialize_models(self):
10241024
def _store_predictions_and_targets(self, preds, targets):
10251025
for learner in self.params_names:
10261026
self._predictions[learner][:, self._i_rep, self._i_treat] = preds[learner]
1027-
self._nuisance_target[learner][:, self._i_rep, self._i_treat] = targets[learner]
1027+
self._nuisance_targets[learner][:, self._i_rep, self._i_treat] = targets[learner]
10281028

10291029
def _calc_rmses(self, preds, targets):
10301030
for learner in self.params_names:

doubleml/tests/test_doubleml_return_types.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,63 @@ def test_stored_models():
165165

166166
assert np.all([isinstance(mdl, plr_dml1.learner['ml_l'].__class__) for mdl in plr_dml1.models['ml_l']['d'][0]])
167167
assert np.all([isinstance(mdl, plr_dml1.learner['ml_m'].__class__) for mdl in plr_dml1.models['ml_m']['d'][0]])
168+
169+
170+
@pytest.mark.ci
171+
def test_stored_predictions():
172+
assert plr_dml1.predictions['ml_l'].shape == (n_obs, n_rep, n_treat)
173+
assert plr_dml1.predictions['ml_m'].shape == (n_obs, n_rep, n_treat)
174+
175+
assert pliv_dml1.predictions['ml_l'].shape == (n_obs, n_rep, n_treat)
176+
assert pliv_dml1.predictions['ml_m'].shape == (n_obs, n_rep, n_treat)
177+
assert pliv_dml1.predictions['ml_r'].shape == (n_obs, n_rep, n_treat)
178+
179+
assert irm_dml1.predictions['ml_g0'].shape == (n_obs, n_rep, n_treat)
180+
assert irm_dml1.predictions['ml_g1'].shape == (n_obs, n_rep, n_treat)
181+
assert irm_dml1.predictions['ml_m'].shape == (n_obs, n_rep, n_treat)
182+
183+
assert iivm_dml1.predictions['ml_g0'].shape == (n_obs, n_rep, n_treat)
184+
assert iivm_dml1.predictions['ml_g1'].shape == (n_obs, n_rep, n_treat)
185+
assert iivm_dml1.predictions['ml_m'].shape == (n_obs, n_rep, n_treat)
186+
assert iivm_dml1.predictions['ml_r0'].shape == (n_obs, n_rep, n_treat)
187+
assert iivm_dml1.predictions['ml_r1'].shape == (n_obs, n_rep, n_treat)
188+
189+
190+
@pytest.mark.ci
191+
def test_stored_nuisance_targets():
192+
assert plr_dml1.nuisance_targets['ml_l'].shape == (n_obs, n_rep, n_treat)
193+
assert plr_dml1.nuisance_targets['ml_m'].shape == (n_obs, n_rep, n_treat)
194+
195+
assert pliv_dml1.nuisance_targets['ml_l'].shape == (n_obs, n_rep, n_treat)
196+
assert pliv_dml1.nuisance_targets['ml_m'].shape == (n_obs, n_rep, n_treat)
197+
assert pliv_dml1.nuisance_targets['ml_r'].shape == (n_obs, n_rep, n_treat)
198+
199+
assert irm_dml1.nuisance_targets['ml_g0'].shape == (n_obs, n_rep, n_treat)
200+
assert irm_dml1.nuisance_targets['ml_g1'].shape == (n_obs, n_rep, n_treat)
201+
assert irm_dml1.nuisance_targets['ml_m'].shape == (n_obs, n_rep, n_treat)
202+
203+
assert iivm_dml1.nuisance_targets['ml_g0'].shape == (n_obs, n_rep, n_treat)
204+
assert iivm_dml1.nuisance_targets['ml_g1'].shape == (n_obs, n_rep, n_treat)
205+
assert iivm_dml1.nuisance_targets['ml_m'].shape == (n_obs, n_rep, n_treat)
206+
assert iivm_dml1.nuisance_targets['ml_r0'].shape == (n_obs, n_rep, n_treat)
207+
assert iivm_dml1.nuisance_targets['ml_r1'].shape == (n_obs, n_rep, n_treat)
208+
209+
210+
@pytest.mark.ci
211+
def test_rmses():
212+
assert plr_dml1.rmses['ml_l'].shape == (n_rep, n_treat)
213+
assert plr_dml1.rmses['ml_m'].shape == (n_rep, n_treat)
214+
215+
assert pliv_dml1.rmses['ml_l'].shape == (n_rep, n_treat)
216+
assert pliv_dml1.rmses['ml_m'].shape == (n_rep, n_treat)
217+
assert pliv_dml1.rmses['ml_r'].shape == (n_rep, n_treat)
218+
219+
assert irm_dml1.rmses['ml_g0'].shape == (n_rep, n_treat)
220+
assert irm_dml1.rmses['ml_g1'].shape == (n_rep, n_treat)
221+
assert irm_dml1.rmses['ml_m'].shape == (n_rep, n_treat)
222+
223+
assert iivm_dml1.rmses['ml_g0'].shape == (n_rep, n_treat)
224+
assert iivm_dml1.rmses['ml_g1'].shape == (n_rep, n_treat)
225+
assert iivm_dml1.rmses['ml_m'].shape == (n_rep, n_treat)
226+
assert iivm_dml1.rmses['ml_r0'].shape == (n_rep, n_treat)
227+
assert iivm_dml1.rmses['ml_r1'].shape == (n_rep, n_treat)

0 commit comments

Comments
 (0)