Skip to content

Commit 369be3e

Browse files
committed
add warnings for multiple treatments
1 parent c9fc8b6 commit 369be3e

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

doubleml/double_ml_irm.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,10 @@ def cate(self, basis):
351351
raise ValueError('Invalid score ' + self.score + '. ' +
352352
'Valid score ' + ' or '.join(valid_score) + '.')
353353

354+
if self.n_rep != 1:
355+
raise NotImplementedError('Only implemented for one repetition. ' +
356+
f'Number of repetitions is {str(self.n_rep)}.')
357+
354358
# define the orthogonal signal
355359
orth_signal = self.psi_elements['psi_b'].reshape(-1)
356360
# fit the best linear predictor
@@ -378,6 +382,10 @@ def gate(self, groups):
378382
raise ValueError('Invalid score ' + self.score + '. ' +
379383
'Valid score ' + ' or '.join(valid_score) + '.')
380384

385+
if self.n_rep != 1:
386+
raise NotImplementedError('Only implemented for one repetition. ' +
387+
f'Number of repetitions is {str(self.n_rep)}.')
388+
381389
if not isinstance(groups, pd.DataFrame):
382390
raise TypeError('Groups must be of DataFrame type. '
383391
f'Groups of type {str(type(groups))} was passed.')

doubleml/tests/test_doubleml_exceptions.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,19 @@ def test_doubleml_exception_gate():
752752
with pytest.raises(ValueError, match=msg):
753753
dml_irm_obj.gate(groups=2)
754754

755+
dml_irm_obj = DoubleMLIRM(dml_data_irm,
756+
ml_g=Lasso(),
757+
ml_m=LogisticRegression(),
758+
trimming_threshold=0.05,
759+
n_folds=5,
760+
score='ATE',
761+
n_rep=2)
762+
dml_irm_obj.fit()
763+
764+
msg = 'Only implemented for one repetition. Number of repetitions is 2.'
765+
with pytest.raises(NotImplementedError, match=msg):
766+
dml_irm_obj.gate(groups=2)
767+
755768

756769
@pytest.mark.ci
757770
def test_doubleml_exception_cate():
@@ -766,3 +779,15 @@ def test_doubleml_exception_cate():
766779
msg = 'Invalid score ATTE. Valid score ATE.'
767780
with pytest.raises(ValueError, match=msg):
768781
dml_irm_obj.cate(basis=2)
782+
783+
dml_irm_obj = DoubleMLIRM(dml_data_irm,
784+
ml_g=Lasso(),
785+
ml_m=LogisticRegression(),
786+
trimming_threshold=0.05,
787+
n_folds=5,
788+
score='ATE',
789+
n_rep=2)
790+
dml_irm_obj.fit()
791+
msg = 'Only implemented for one repetition. Number of repetitions is 2.'
792+
with pytest.raises(NotImplementedError, match=msg):
793+
dml_irm_obj.cate(basis=2)

0 commit comments

Comments
 (0)