@@ -165,3 +165,63 @@ def test_stored_models():
165
165
166
166
assert np .all ([isinstance (mdl , plr_dml1 .learner ['ml_l' ].__class__ ) for mdl in plr_dml1 .models ['ml_l' ]['d' ][0 ]])
167
167
assert np .all ([isinstance (mdl , plr_dml1 .learner ['ml_m' ].__class__ ) for mdl in plr_dml1 .models ['ml_m' ]['d' ][0 ]])
168
+
169
+
170
+ @pytest .mark .ci
171
+ def test_stored_predictions ():
172
+ assert plr_dml1 .predictions ['ml_l' ].shape == (n_obs , n_rep , n_treat )
173
+ assert plr_dml1 .predictions ['ml_m' ].shape == (n_obs , n_rep , n_treat )
174
+
175
+ assert pliv_dml1 .predictions ['ml_l' ].shape == (n_obs , n_rep , n_treat )
176
+ assert pliv_dml1 .predictions ['ml_m' ].shape == (n_obs , n_rep , n_treat )
177
+ assert pliv_dml1 .predictions ['ml_r' ].shape == (n_obs , n_rep , n_treat )
178
+
179
+ assert irm_dml1 .predictions ['ml_g0' ].shape == (n_obs , n_rep , n_treat )
180
+ assert irm_dml1 .predictions ['ml_g1' ].shape == (n_obs , n_rep , n_treat )
181
+ assert irm_dml1 .predictions ['ml_m' ].shape == (n_obs , n_rep , n_treat )
182
+
183
+ assert iivm_dml1 .predictions ['ml_g0' ].shape == (n_obs , n_rep , n_treat )
184
+ assert iivm_dml1 .predictions ['ml_g1' ].shape == (n_obs , n_rep , n_treat )
185
+ assert iivm_dml1 .predictions ['ml_m' ].shape == (n_obs , n_rep , n_treat )
186
+ assert iivm_dml1 .predictions ['ml_r0' ].shape == (n_obs , n_rep , n_treat )
187
+ assert iivm_dml1 .predictions ['ml_r1' ].shape == (n_obs , n_rep , n_treat )
188
+
189
+
190
+ @pytest .mark .ci
191
+ def test_stored_nuisance_targets ():
192
+ assert plr_dml1 .nuisance_targets ['ml_l' ].shape == (n_obs , n_rep , n_treat )
193
+ assert plr_dml1 .nuisance_targets ['ml_m' ].shape == (n_obs , n_rep , n_treat )
194
+
195
+ assert pliv_dml1 .nuisance_targets ['ml_l' ].shape == (n_obs , n_rep , n_treat )
196
+ assert pliv_dml1 .nuisance_targets ['ml_m' ].shape == (n_obs , n_rep , n_treat )
197
+ assert pliv_dml1 .nuisance_targets ['ml_r' ].shape == (n_obs , n_rep , n_treat )
198
+
199
+ assert irm_dml1 .nuisance_targets ['ml_g0' ].shape == (n_obs , n_rep , n_treat )
200
+ assert irm_dml1 .nuisance_targets ['ml_g1' ].shape == (n_obs , n_rep , n_treat )
201
+ assert irm_dml1 .nuisance_targets ['ml_m' ].shape == (n_obs , n_rep , n_treat )
202
+
203
+ assert iivm_dml1 .nuisance_targets ['ml_g0' ].shape == (n_obs , n_rep , n_treat )
204
+ assert iivm_dml1 .nuisance_targets ['ml_g1' ].shape == (n_obs , n_rep , n_treat )
205
+ assert iivm_dml1 .nuisance_targets ['ml_m' ].shape == (n_obs , n_rep , n_treat )
206
+ assert iivm_dml1 .nuisance_targets ['ml_r0' ].shape == (n_obs , n_rep , n_treat )
207
+ assert iivm_dml1 .nuisance_targets ['ml_r1' ].shape == (n_obs , n_rep , n_treat )
208
+
209
+
210
+ @pytest .mark .ci
211
+ def test_rmses ():
212
+ assert plr_dml1 .rmses ['ml_l' ].shape == (n_rep , n_treat )
213
+ assert plr_dml1 .rmses ['ml_m' ].shape == (n_rep , n_treat )
214
+
215
+ assert pliv_dml1 .rmses ['ml_l' ].shape == (n_rep , n_treat )
216
+ assert pliv_dml1 .rmses ['ml_m' ].shape == (n_rep , n_treat )
217
+ assert pliv_dml1 .rmses ['ml_r' ].shape == (n_rep , n_treat )
218
+
219
+ assert irm_dml1 .rmses ['ml_g0' ].shape == (n_rep , n_treat )
220
+ assert irm_dml1 .rmses ['ml_g1' ].shape == (n_rep , n_treat )
221
+ assert irm_dml1 .rmses ['ml_m' ].shape == (n_rep , n_treat )
222
+
223
+ assert iivm_dml1 .rmses ['ml_g0' ].shape == (n_rep , n_treat )
224
+ assert iivm_dml1 .rmses ['ml_g1' ].shape == (n_rep , n_treat )
225
+ assert iivm_dml1 .rmses ['ml_m' ].shape == (n_rep , n_treat )
226
+ assert iivm_dml1 .rmses ['ml_r0' ].shape == (n_rep , n_treat )
227
+ assert iivm_dml1 .rmses ['ml_r1' ].shape == (n_rep , n_treat )
0 commit comments