Add a DoubleML Framework Class #226

Merged
merged 116 commits into from
Apr 2, 2024
de5ee08
move tests for dummy learners to utils
SvenKlaassen Dec 21, 2023
69282b0
move and rename utils
SvenKlaassen Dec 21, 2023
da17874
move blp and policytree to utils
SvenKlaassen Dec 22, 2023
b8764d0
update utils init
SvenKlaassen Dec 22, 2023
441c7fa
first iv_model submodule
SvenKlaassen Dec 22, 2023
b185bd8
fix plr and irm blp model
SvenKlaassen Dec 22, 2023
f3224b8
rename sensitivity not implemented error
SvenKlaassen Dec 22, 2023
a6b47a2
move pliv tests to submodule
SvenKlaassen Dec 22, 2023
93c246b
fix unit tests
SvenKlaassen Dec 22, 2023
969b240
create did_model submodule
SvenKlaassen Dec 22, 2023
01246d3
shorten submodule names
SvenKlaassen Dec 22, 2023
cdc2354
updated submodule description
SvenKlaassen Dec 22, 2023
53786fa
create plr submodule
SvenKlaassen Dec 22, 2023
6cb1c7a
create irm submodule
SvenKlaassen Dec 22, 2023
6c22f02
move tests to irm submodule
SvenKlaassen Dec 22, 2023
d679446
rename plr to plm and move pliv to plm submodule
SvenKlaassen Dec 22, 2023
5cdfad2
move iivm to irm submodule
SvenKlaassen Dec 22, 2023
62be459
remove iv submodule
SvenKlaassen Dec 22, 2023
7e4f1c5
remove doubleml from test names
SvenKlaassen Dec 29, 2023
b9bf726
Reduce deprecation warnings did
SvenKlaassen Dec 29, 2023
54d3863
fix tests
SvenKlaassen Jan 2, 2024
fffa952
fix codacy issues
SvenKlaassen Jan 2, 2024
ee6f755
Update _utils_irm_manual.py
SvenKlaassen Jan 2, 2024
f2b3113
simplify checks and initialization in doubleml fit
SvenKlaassen Jan 2, 2024
557741e
simplify doubleml fit method
SvenKlaassen Jan 2, 2024
6952ced
start removing apply cross-fitting
SvenKlaassen Jan 2, 2024
0347057
update resampling for stratification
SvenKlaassen Jan 2, 2024
0698a48
fix unit tests for sample splits
SvenKlaassen Jan 2, 2024
c684ee2
Merge branch 's-restructure-doubleml' into s-remove-apply-crossfitting
SvenKlaassen Jan 2, 2024
5994c1a
fix unit tests
SvenKlaassen Jan 3, 2024
22e16a4
fix unit tests
SvenKlaassen Jan 3, 2024
b7127c2
fix propensity score adjustments
SvenKlaassen Jan 3, 2024
8d34a73
update trimming for external predictions
SvenKlaassen Jan 5, 2024
5c428a6
Merge branch 's-restructure-doubleml' into s-remove-apply-crossfitting
SvenKlaassen Jan 12, 2024
4b3090b
Remove dml_procedure from doubleml
SvenKlaassen Jan 14, 2024
4254da2
remove dml_procedure from did classes
SvenKlaassen Jan 14, 2024
6b41b63
fix did_cs tests
SvenKlaassen Jan 14, 2024
bfacfb4
adapt did tests
SvenKlaassen Jan 15, 2024
57cbd14
remove dml1 from did_manual
SvenKlaassen Jan 15, 2024
7e78dfe
remove dml_procedure from plr model
SvenKlaassen Jan 15, 2024
9d1fb0d
remove dml_procedure from pliv model
SvenKlaassen Jan 15, 2024
e7b9129
remove dml_procedure from irm model
SvenKlaassen Jan 15, 2024
832967a
remove dml_procedure from iivm model
SvenKlaassen Jan 15, 2024
fc870d1
remove dml_procedure from cvar model
SvenKlaassen Jan 15, 2024
5732392
remove dml_procedure from pq model
SvenKlaassen Jan 15, 2024
94359fe
fix pq unit tests
SvenKlaassen Jan 15, 2024
00adaf6
remove dml_procedure from lpq model
SvenKlaassen Jan 15, 2024
2bc2894
remove dml_procedure from qte model
SvenKlaassen Jan 15, 2024
4cb1144
remove dml_procedure from remaining tests
SvenKlaassen Jan 15, 2024
20a3f28
move manual policy tree
SvenKlaassen Jan 15, 2024
0823821
Create test_var_est_and_aggregation.py
SvenKlaassen Jan 19, 2024
770fd48
add doubleml framework class
SvenKlaassen Jan 19, 2024
348f755
Update __init__.py
SvenKlaassen Jan 19, 2024
1f8fd51
add unit tests for framework
SvenKlaassen Jan 19, 2024
7cd9dc7
add coverage unit tests
SvenKlaassen Jan 19, 2024
6ff117a
Update __init__.py
SvenKlaassen Jan 19, 2024
ef10764
add clustering flag to framework
SvenKlaassen Jan 19, 2024
60e6c6e
adapt scaling factors
SvenKlaassen Jan 19, 2024
4c97aff
add cluster option to framework
SvenKlaassen Jan 19, 2024
9eeab72
fix doubleml param est with cluster
SvenKlaassen Jan 22, 2024
85dda2d
add assertion error message and remove cluster not implemented
SvenKlaassen Jan 22, 2024
5790ec3
add first exception tests
SvenKlaassen Jan 22, 2024
0f9bf30
fix penalty in logistic reg test
SvenKlaassen Jan 22, 2024
81aa251
extend input exception tests
SvenKlaassen Jan 22, 2024
7d930a4
add consistency checks on doubleml_framework
SvenKlaassen Jan 23, 2024
cf9b5de
extend unit tests on framework exceptions
SvenKlaassen Jan 23, 2024
46f0144
Update test_framework_exceptions.py
SvenKlaassen Jan 23, 2024
017d4d6
remove verbose from _fit_and_predict due to sklearn change
SvenKlaassen Jan 23, 2024
91330d0
further remove verbose from _fit_predict()
SvenKlaassen Jan 23, 2024
938142d
Adapt _fit_and_predict for earlier sklearn versions (1.3.X)
SvenKlaassen Jan 23, 2024
09fb09a
Merge branch 's-remove-apply-crossfitting' into s-add-dml-framework
SvenKlaassen Jan 29, 2024
73e2432
fix unit test utils
SvenKlaassen Jan 29, 2024
9c51b93
Merge branch 'main' into s-restructure-doubleml
SvenKlaassen Feb 2, 2024
7bb829c
move test for weighted irm
SvenKlaassen Feb 2, 2024
12cad97
Merge branch 'main' into s-restructure-doubleml
SvenKlaassen Feb 2, 2024
ece1b45
update github workflow actions to node.js 20
SvenKlaassen Feb 2, 2024
618c260
Merge branch 's-restructure-doubleml' into s-remove-apply-crossfitting
SvenKlaassen Feb 2, 2024
adbfa19
Merge branch 's-remove-apply-crossfitting' into s-add-dml-framework
SvenKlaassen Feb 2, 2024
ff9a5f9
add tstats and pvals to framework
SvenKlaassen Feb 9, 2024
ca98d33
Create test_framework_pval_corrections.py
SvenKlaassen Feb 9, 2024
016d36c
add p_adjust (basic version) with unit tests
SvenKlaassen Feb 9, 2024
b4e4959
p-val median aggregation
SvenKlaassen Feb 9, 2024
8d6a4ca
exception test for p_adjust
SvenKlaassen Feb 9, 2024
d65b8f8
implement romano wolf method
SvenKlaassen Feb 9, 2024
2bdf249
update p_adjust unit test and ass all_ p_values to p_adjust
SvenKlaassen Feb 9, 2024
9eae7a8
extend framework properties and add shape descriptions
SvenKlaassen Feb 9, 2024
50d9fdc
fix typo and tests
SvenKlaassen Feb 9, 2024
435e57a
remove boot_coef
SvenKlaassen Feb 11, 2024
546a512
remove boot_theta from boot_manual
SvenKlaassen Feb 11, 2024
6c7e7d2
remove boot_coef from did tests
SvenKlaassen Feb 11, 2024
4a3278c
remove boot_coefs from plm tests
SvenKlaassen Feb 12, 2024
270e77b
remove boot coefs from irm model
SvenKlaassen Feb 12, 2024
216f9b9
remove boot_coefs from return type tests
SvenKlaassen Feb 12, 2024
00f2fa0
Merge branch 's-remove-apply-crossfitting' into s-add-dml-framework
SvenKlaassen Feb 12, 2024
da7ebfe
add construct framework method
SvenKlaassen Feb 12, 2024
1057279
add framework attribute to DoubleML class
SvenKlaassen Feb 12, 2024
e9ce321
add boot_method, boot_t_stat and n_rep_boot attributes to DoubleMLFra…
SvenKlaassen Feb 12, 2024
20679bf
update bootstrap and attributes to framework object
SvenKlaassen Feb 12, 2024
5a807b7
adjust boot_t tests for did models
SvenKlaassen Feb 12, 2024
9f4ed63
adjust boot_t tests for irm models
SvenKlaassen Feb 12, 2024
ea5adb6
adjust boot_t tests for plm models
SvenKlaassen Feb 12, 2024
740946a
fix model default tests
SvenKlaassen Feb 12, 2024
70f9495
remove bootstrap arrays from DoubleML class
SvenKlaassen Feb 12, 2024
670577f
exchange confint in DoubleML class
SvenKlaassen Feb 12, 2024
76389f4
add separate exceptions for DoubleMLQTE
SvenKlaassen Feb 13, 2024
97c4bfe
extend exception message for joint ci
SvenKlaassen Feb 13, 2024
6ae7c9b
move p_adjust from DoubleML to DoubleMLFramework
SvenKlaassen Feb 13, 2024
84b1c28
fix QTE score docstring
SvenKlaassen Feb 13, 2024
99c84eb
remove setters and use kwargs for DoubleMLQTE
SvenKlaassen Feb 13, 2024
9b08a91
update DoubleMLQTE with basic framework
SvenKlaassen Feb 13, 2024
b76bdaf
Update test_qte_exceptions.py
SvenKlaassen Feb 13, 2024
d033d86
update framework docstrings
SvenKlaassen Feb 13, 2024
d0b4722
remove boot_coef from qte
SvenKlaassen Feb 13, 2024
57720b2
update DoubleMLQTE for Framework
SvenKlaassen Feb 13, 2024
9900656
Fix DoubleMLQTE summary for unfitted model
SvenKlaassen Feb 13, 2024
2864714
update manual bootstrap boot_qte to correct shape
SvenKlaassen Feb 13, 2024
8 changes: 4 additions & 4 deletions .github/workflows/codeql.yml
@@ -24,18 +24,18 @@ jobs:

steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Initialize CodeQL
uses: github/codeql-action/init@v2
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
queries: +security-and-quality

- name: Autobuild
uses: github/codeql-action/autobuild@v2
uses: github/codeql-action/autobuild@v3

- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2
uses: github/codeql-action/analyze@v3
with:
category: "/language:${{ matrix.language }}"
6 changes: 3 additions & 3 deletions .github/workflows/deploy_pkg.yml
@@ -12,12 +12,12 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
persist-credentials: false

- name: Install python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: '3.8'

@@ -32,7 +32,7 @@ jobs:
pip install wheel
python setup.py sdist bdist_wheel

- uses: actions/upload-artifact@v3
- uses: actions/upload-artifact@v4
with:
name: DoubleML-pkg
path: dist/
4 changes: 2 additions & 2 deletions .github/workflows/pytest.yml
@@ -30,11 +30,11 @@ jobs:
- {os: 'ubuntu-latest', python-version: '3.11'}

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
fetch-depth: 2
- name: Set up Python ${{ matrix.config.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.config.python-version }}
- name: Install OpenMP runtime for unit tests with xgboost learners
33 changes: 19 additions & 14 deletions doubleml/__init__.py
@@ -1,32 +1,37 @@
from pkg_resources import get_distribution

from .double_ml_plr import DoubleMLPLR
from .double_ml_pliv import DoubleMLPLIV
from .double_ml_irm import DoubleMLIRM
from .double_ml_iivm import DoubleMLIIVM
from .double_ml_framework import concat
from .double_ml_framework import DoubleMLFramework
from .plm.plr import DoubleMLPLR
from .plm.pliv import DoubleMLPLIV
from .irm.irm import DoubleMLIRM
from .irm.iivm import DoubleMLIIVM
from .double_ml_data import DoubleMLData, DoubleMLClusterData
from .double_ml_blp import DoubleMLBLP
from .double_ml_did import DoubleMLDID
from .double_ml_did_cs import DoubleMLDIDCS
from .double_ml_qte import DoubleMLQTE
from .double_ml_pq import DoubleMLPQ
from .double_ml_lpq import DoubleMLLPQ
from .double_ml_cvar import DoubleMLCVAR
from .double_ml_policytree import DoubleMLPolicyTree
from .did.did import DoubleMLDID
from .did.did_cs import DoubleMLDIDCS
from .irm.qte import DoubleMLQTE
from .irm.pq import DoubleMLPQ
from .irm.lpq import DoubleMLLPQ
from .irm.cvar import DoubleMLCVAR

__all__ = ['DoubleMLPLR',
from .utils.blp import DoubleMLBLP
from .utils.policytree import DoubleMLPolicyTree

__all__ = ['concat',
'DoubleMLFramework',
'DoubleMLPLR',
'DoubleMLPLIV',
'DoubleMLIRM',
'DoubleMLIIVM',
'DoubleMLData',
'DoubleMLClusterData',
'DoubleMLBLP',
'DoubleMLDID',
'DoubleMLDIDCS',
'DoubleMLPQ',
'DoubleMLQTE',
'DoubleMLLPQ',
'DoubleMLCVAR',
'DoubleMLBLP',
'DoubleMLPolicyTree']

__version__ = get_distribution('doubleml').version
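For anyone updating deep imports, the module moves in this file can be restated as a lookup table. The sketch below only copies the old and new module paths visible in the import changes above; the `MOVED_MODULES` dict and the `new_module_path` helper are hypothetical names used for illustration and are not part of DoubleML.

```python
# Old -> new module paths, copied from the import changes in doubleml/__init__.py.
# MOVED_MODULES and new_module_path are illustrative names, not DoubleML API.
MOVED_MODULES = {
    "doubleml.double_ml_plr": "doubleml.plm.plr",
    "doubleml.double_ml_pliv": "doubleml.plm.pliv",
    "doubleml.double_ml_irm": "doubleml.irm.irm",
    "doubleml.double_ml_iivm": "doubleml.irm.iivm",
    "doubleml.double_ml_did": "doubleml.did.did",
    "doubleml.double_ml_did_cs": "doubleml.did.did_cs",
    "doubleml.double_ml_qte": "doubleml.irm.qte",
    "doubleml.double_ml_pq": "doubleml.irm.pq",
    "doubleml.double_ml_lpq": "doubleml.irm.lpq",
    "doubleml.double_ml_cvar": "doubleml.irm.cvar",
    "doubleml.double_ml_blp": "doubleml.utils.blp",
    "doubleml.double_ml_policytree": "doubleml.utils.policytree",
}


def new_module_path(old_path: str) -> str:
    """Return the post-restructure module path, or the input if it did not move."""
    return MOVED_MODULES.get(old_path, old_path)


print(new_module_path("doubleml.double_ml_did"))   # module that moved
print(new_module_path("doubleml.double_ml_data"))  # module that stayed in place
```

Since every class is still listed in `__all__`, top-level imports such as `from doubleml import DoubleMLPLR` should be unaffected; only code that imports from the old module files needs to change.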
11 changes: 11 additions & 0 deletions doubleml/did/__init__.py
@@ -0,0 +1,11 @@
"""
The :mod:`doubleml.did` module implements double machine learning estimates based on difference in differences models.
"""

from .did import DoubleMLDID
from .did_cs import DoubleMLDIDCS

__all__ = [
"DoubleMLDID",
"DoubleMLDIDCS",
]
28 changes: 9 additions & 19 deletions doubleml/double_ml_did.py → doubleml/did/did.py
@@ -3,12 +3,12 @@
from sklearn.utils.multiclass import type_of_target
import warnings

from .double_ml import DoubleML
from .double_ml_data import DoubleMLData
from .double_ml_score_mixins import LinearScoreMixin
from ..double_ml import DoubleML
from ..double_ml_data import DoubleMLData
from ..double_ml_score_mixins import LinearScoreMixin

from ._utils import _dml_cv_predict, _get_cond_smpls, _dml_tune, _trimm
from ._utils_checks import _check_score, _check_trimming, _check_finite_predictions, _check_is_propensity
from ..utils._estimation import _dml_cv_predict, _get_cond_smpls, _dml_tune, _trimm
from ..utils._checks import _check_score, _check_trimming, _check_finite_predictions, _check_is_propensity


class DoubleMLDID(LinearScoreMixin, DoubleML):
@@ -49,10 +49,6 @@ class DoubleMLDID(LinearScoreMixin, DoubleML):
Indicates whether to use a slightly different normalization from Sant'Anna and Zhao (2020).
Default is ``True``.

dml_procedure : str
A str (``'dml1'`` or ``'dml2'``) specifying the double machine learning algorithm.
Default is ``'dml2'``.

trimming_rule : str
A str (``'truncate'`` is the only choice) specifying the trimming approach.
Default is ``'truncate'``.
@@ -65,10 +61,6 @@ class DoubleMLDID(LinearScoreMixin, DoubleML):
Indicates whether the sample splitting should be drawn during initialization of the object.
Default is ``True``.

apply_cross_fitting : bool
Indicates whether cross-fitting should be applied.
Default is ``True``.

Examples
--------
>>> import numpy as np
@@ -93,18 +85,14 @@ def __init__(self,
n_rep=1,
score='observational',
in_sample_normalization=True,
dml_procedure='dml2',
trimming_rule='truncate',
trimming_threshold=1e-2,
draw_sample_splitting=True,
apply_cross_fitting=True):
draw_sample_splitting=True):
super().__init__(obj_dml_data,
n_folds,
n_rep,
score,
dml_procedure,
draw_sample_splitting,
apply_cross_fitting)
draw_sample_splitting)

self._check_data(self._dml_data)
valid_scores = ['observational', 'experimental']
@@ -117,6 +105,8 @@ def __init__(self,

# set stratification for resampling
self._strata = self._dml_data.d
if draw_sample_splitting:
self.draw_sample_splitting()

# check learners
ml_g_is_classifier = self._check_learner(ml_g, 'ml_g', regressor=True, classifier=True)
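Because `dml_procedure` and `apply_cross_fitting` are removed from the constructor signature in this diff, call sites that still pass them will fail. One possible way to migrate keyword dictionaries is sketched below; `REMOVED_KWARGS` and `drop_removed_kwargs` are hypothetical helper names, not part of DoubleML.

```python
import warnings

# Constructor arguments removed in this diff (see the deleted docstring entries).
# Both names below are illustrative helpers, not DoubleML API.
REMOVED_KWARGS = ("dml_procedure", "apply_cross_fitting")


def drop_removed_kwargs(kwargs: dict) -> dict:
    """Return a copy of kwargs without the arguments this PR removes."""
    cleaned = dict(kwargs)
    for name in REMOVED_KWARGS:
        if name in cleaned:
            value = cleaned.pop(name)
            warnings.warn(
                f"'{name}={value!r}' is no longer accepted and was dropped.",
                DeprecationWarning,
                stacklevel=2,
            )
    return cleaned


old_style = {"n_folds": 5, "dml_procedure": "dml2", "apply_cross_fitting": True}
new_style = drop_removed_kwargs(old_style)
```

Silently dropping the arguments is only safe for callers that used the old defaults (`'dml2'` and `True`). Code that relied on `'dml1'` or on `apply_cross_fitting=False` would see different results and should be reviewed by hand.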
28 changes: 9 additions & 19 deletions doubleml/double_ml_did_cs.py → doubleml/did/did_cs.py
@@ -3,12 +3,12 @@
from sklearn.utils.multiclass import type_of_target
import warnings

from .double_ml import DoubleML
from .double_ml_data import DoubleMLData
from .double_ml_score_mixins import LinearScoreMixin
from ..double_ml import DoubleML
from ..double_ml_data import DoubleMLData
from ..double_ml_score_mixins import LinearScoreMixin

from ._utils import _dml_cv_predict, _trimm, _get_cond_smpls_2d, _dml_tune
from ._utils_checks import _check_score, _check_trimming, _check_finite_predictions, _check_is_propensity
from ..utils._estimation import _dml_cv_predict, _trimm, _get_cond_smpls_2d, _dml_tune
from ..utils._checks import _check_score, _check_trimming, _check_finite_predictions, _check_is_propensity


class DoubleMLDIDCS(LinearScoreMixin, DoubleML):
@@ -49,10 +49,6 @@ class DoubleMLDIDCS(LinearScoreMixin, DoubleML):
Indicates whether to use a slightly different normalization from Sant'Anna and Zhao (2020).
Default is ``True``.

dml_procedure : str
A str (``'dml1'`` or ``'dml2'``) specifying the double machine learning algorithm.
Default is ``'dml2'``.

trimming_rule : str
A str (``'truncate'`` is the only choice) specifying the trimming approach.
Default is ``'truncate'``.
@@ -65,10 +61,6 @@ class DoubleMLDIDCS(LinearScoreMixin, DoubleML):
Indicates whether the sample splitting should be drawn during initialization of the object.
Default is ``True``.

apply_cross_fitting : bool
Indicates whether cross-fitting should be applied.
Default is ``True``.

Examples
--------
>>> import numpy as np
@@ -93,18 +85,14 @@ def __init__(self,
n_rep=1,
score='observational',
in_sample_normalization=True,
dml_procedure='dml2',
trimming_rule='truncate',
trimming_threshold=1e-2,
draw_sample_splitting=True,
apply_cross_fitting=True):
draw_sample_splitting=True):
super().__init__(obj_dml_data,
n_folds,
n_rep,
score,
dml_procedure,
draw_sample_splitting,
apply_cross_fitting)
draw_sample_splitting)

self._check_data(self._dml_data)
valid_scores = ['observational', 'experimental']
@@ -117,6 +105,8 @@ def __init__(self,

# set stratification for resampling
self._strata = self._dml_data.d.reshape(-1, 1) + 2 * self._dml_data.t.reshape(-1, 1)
if draw_sample_splitting:
self.draw_sample_splitting()

# check learners
ml_g_is_classifier = self._check_learner(ml_g, 'ml_g', regressor=True, classifier=True)
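The strata expression `d + 2 * t` in the hunk above maps each combination of treatment and time indicator to its own label. A small NumPy sketch, assuming both `d` and `t` are binary 0/1 arrays (the toy data is made up for illustration):

```python
import numpy as np

# Toy binary treatment (d) and time (t) indicators; values are illustrative only.
d = np.array([0, 1, 0, 1, 1, 0])
t = np.array([0, 0, 1, 1, 0, 1])

# Same expression as in the diff: one column vector of strata labels.
strata = d.reshape(-1, 1) + 2 * t.reshape(-1, 1)

# With binary inputs the labels are 0: (d=0,t=0), 1: (d=1,t=0), 2: (d=0,t=1), 3: (d=1,t=1).
labels, counts = np.unique(strata, return_counts=True)
print(strata.ravel())
print(dict(zip(labels.tolist(), counts.tolist())))
```

Stratifying the sample splits on this label keeps all four treatment-by-period groups represented in each fold, which matters when some groups are small.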
Empty file added doubleml/did/tests/__init__.py
Empty file.
@@ -1,12 +1,12 @@
import numpy as np
from sklearn.base import clone

from ._utils import fit_predict, fit_predict_proba, tune_grid_search
from ._utils_did_manual import did_dml1, did_dml2
from ...tests._utils import fit_predict, fit_predict_proba, tune_grid_search
from ._utils_did_manual import did_dml2


def fit_did_cs(y, x, d, t,
learner_g, learner_m, all_smpls, dml_procedure, score, in_sample_normalization,
learner_g, learner_m, all_smpls, score, in_sample_normalization,
n_rep=1, g_d0_t0_params=None, g_d0_t1_params=None,
g_d1_t0_params=None, g_d1_t1_params=None, m_params=None,
trimming_threshold=1e-2):
@@ -57,11 +57,7 @@ def fit_did_cs(y, x, d, t,
all_psi_a.append(psi_a)
all_psi_b.append(psi_b)

if dml_procedure == 'dml1':
thetas[i_rep], ses[i_rep] = did_dml1(psi_a, psi_b, smpls)
else:
assert dml_procedure == 'dml2'
thetas[i_rep], ses[i_rep] = did_dml2(psi_a, psi_b)
thetas[i_rep], ses[i_rep] = did_dml2(psi_a, psi_b)

theta = np.median(thetas)
se = np.sqrt(np.median(np.power(ses, 2) * n_obs + np.power(thetas - theta, 2)) / n_obs)
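The last two lines of this hunk aggregate the per-repetition estimates: the point estimate is the median over repetitions, and the standard error takes the median of the rescaled variances plus the squared deviation of each repetition from that median. A standalone restatement with made-up numbers follows; the function name `aggregate_over_repetitions` is hypothetical.

```python
import numpy as np


def aggregate_over_repetitions(thetas, ses, n_obs):
    """Median aggregation over repetitions, mirroring the two lines in the diff."""
    thetas = np.asarray(thetas, dtype=float)
    ses = np.asarray(ses, dtype=float)
    theta = np.median(thetas)
    # Rescaled variance per repetition plus the spread around the median estimate.
    var_scaled = np.power(ses, 2) * n_obs + np.power(thetas - theta, 2)
    se = np.sqrt(np.median(var_scaled) / n_obs)
    return theta, se


# Illustrative values only: three repetitions on 100 observations.
theta, se = aggregate_over_repetitions([1.0, 1.2, 0.8], [0.1, 0.1, 0.1], n_obs=100)
print(theta, se)
```

The deviation term makes the aggregated standard error grow when the repetitions disagree with each other, so it is never smaller than the median of the individual standard errors.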
@@ -1,12 +1,12 @@
import numpy as np
from sklearn.base import clone

from ._utils_boot import boot_manual, draw_weights
from ._utils import fit_predict, fit_predict_proba, tune_grid_search
from ...tests._utils_boot import boot_manual, draw_weights
from ...tests._utils import fit_predict, fit_predict_proba, tune_grid_search


def fit_did(y, x, d,
learner_g, learner_m, all_smpls, dml_procedure, score, in_sample_normalization,
learner_g, learner_m, all_smpls, score, in_sample_normalization,
n_rep=1, g0_params=None, g1_params=None, m_params=None,
trimming_threshold=1e-2):
n_obs = len(y)
@@ -43,11 +43,7 @@ def fit_did(y, x, d,
all_psi_a.append(psi_a)
all_psi_b.append(psi_b)

if dml_procedure == 'dml1':
thetas[i_rep], ses[i_rep] = did_dml1(psi_a, psi_b, smpls)
else:
assert dml_procedure == 'dml2'
thetas[i_rep], ses[i_rep] = did_dml2(psi_a, psi_b)
thetas[i_rep], ses[i_rep] = did_dml2(psi_a, psi_b)

theta = np.median(thetas)
se = np.sqrt(np.median(np.power(ses, 2) * n_obs + np.power(thetas - theta, 2)) / n_obs)
@@ -107,25 +103,6 @@ def compute_did_residuals(y, g_hat0_list, g_hat1_list, m_hat_list, p_hat_list, s
return resid_d0, g_hat0, g_hat1, m_hat, p_hat


def did_dml1(psi_a, psi_b, smpls):
thetas = np.zeros(len(smpls))
n_obs = len(psi_a)

for idx, (_, test_index) in enumerate(smpls):
thetas[idx] = - np.mean(psi_b[test_index]) / np.mean(psi_a[test_index])
theta_hat = np.mean(thetas)

if len(smpls) > 1:
se = np.sqrt(var_did(theta_hat, psi_a, psi_b, n_obs))
else:
assert len(smpls) == 1
test_index = smpls[0][1]
n_obs = len(test_index)
se = np.sqrt(var_did(theta_hat, psi_a[test_index], psi_b[test_index], n_obs))

return theta_hat, se


def did_dml2(psi_a, psi_b):
n_obs = len(psi_a)
theta_hat = - np.mean(psi_b) / np.mean(psi_a)
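To make the difference between the removed `did_dml1` and the retained `did_dml2` concrete: the first solves the linear score on each test fold and averages the fold-wise solutions, while the second solves it once on the pooled sample. The sketch below reproduces only the point-estimate formulas from the hunk above on made-up score values; the function names `theta_dml1` and `theta_dml2` are hypothetical.

```python
import numpy as np


def theta_dml1(psi_a, psi_b, smpls):
    """Average of fold-wise solutions, as in the removed did_dml1."""
    thetas = [-np.mean(psi_b[test]) / np.mean(psi_a[test]) for _, test in smpls]
    return float(np.mean(thetas))


def theta_dml2(psi_a, psi_b):
    """Single solution on the pooled sample, as in did_dml2."""
    return float(-np.mean(psi_b) / np.mean(psi_a))


# Illustrative score components and a two-fold split of (train, test) indices.
psi_a = np.array([-1.0, -1.0, -2.0, -2.0])
psi_b = np.array([1.0, 3.0, 2.0, 2.0])
smpls = [
    (np.array([2, 3]), np.array([0, 1])),
    (np.array([0, 1]), np.array([2, 3])),
]

print(theta_dml1(psi_a, psi_b, smpls))  # mean of fold-wise estimates 2.0 and 1.0
print(theta_dml2(psi_a, psi_b))         # -2.0 / -1.5 on the pooled scores
```

The two estimates coincide only when `psi_a` has the same mean in every fold, which is why removing `dml1` can change results for callers who had selected it.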
@@ -176,7 +153,6 @@ def var_did(theta, psi_a, psi_b, n_obs):

def boot_did(y, thetas, ses, all_psi_a, all_psi_b,
all_smpls, bootstrap, n_rep_boot, n_rep=1, apply_cross_fitting=True):
all_boot_theta = list()
all_boot_t_stat = list()
for i_rep in range(n_rep):
smpls = all_smpls[i_rep]
@@ -186,16 +162,14 @@ def boot_did(y, thetas, ses, all_psi_a, all_psi_b,
test_index = smpls[0][1]
n_obs = len(test_index)
weights = draw_weights(bootstrap, n_rep_boot, n_obs)
boot_theta, boot_t_stat = boot_did_single_split(
boot_t_stat = boot_did_single_split(
thetas[i_rep], all_psi_a[i_rep], all_psi_b[i_rep], smpls,
ses[i_rep], weights, n_rep_boot, apply_cross_fitting)
all_boot_theta.append(boot_theta)
all_boot_t_stat.append(boot_t_stat)

boot_theta = np.hstack(all_boot_theta)
boot_t_stat = np.hstack(all_boot_t_stat)

return boot_theta, boot_t_stat
return boot_t_stat


def boot_did_single_split(theta, psi_a, psi_b,
@@ -208,9 +182,9 @@ def boot_did_single_split(theta, psi_a, psi_b,
J = np.mean(psi_a[test_index])

psi = np.multiply(psi_a, theta) + psi_b
boot_theta, boot_t_stat = boot_manual(psi, J, smpls, se, weights, n_rep_boot, apply_cross_fitting)
boot_t_stat = boot_manual(psi, J, smpls, se, weights, n_rep_boot, apply_cross_fitting)

return boot_theta, boot_t_stat
return boot_t_stat
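After this change the manual bootstrap helpers return only the bootstrapped t-statistics. The body of `boot_manual` is not shown in this diff, so the following is a generic multiplier-bootstrap sketch under an assumed form, where each draw is the mean of `weights * psi` divided by `se * J`. The function name `boot_t_stat_sketch`, the normal weights, and the toy data are all assumptions for illustration.

```python
import numpy as np


def boot_t_stat_sketch(psi, J, se, weights):
    """One bootstrapped t-statistic per row of weights (assumed normalisation)."""
    psi = np.asarray(psi, dtype=float)
    weights = np.asarray(weights, dtype=float)
    # Mean over observations of w_i * psi_i, scaled by se * J.
    return np.mean(weights * psi, axis=1) / (se * J)


rng = np.random.default_rng(42)
n_obs, n_rep_boot = 200, 500

psi = rng.normal(size=n_obs)                      # toy score values
J = -1.0                                          # toy derivative term, mean of psi_a
se = np.sqrt(np.mean(psi ** 2) / J ** 2 / n_obs)  # plug-in standard error

weights = rng.normal(size=(n_rep_boot, n_obs))    # normal multiplier weights
boot_t_stat = boot_t_stat_sketch(psi, J, se, weights)
print(boot_t_stat.shape)
```

Under this assumed form a bootstrapped coefficient would simply be `boot_t_stat * se`, which would make a separate `boot_theta` array redundant.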


def tune_nuisance_did(y, x, d, ml_g, ml_m, smpls, score, n_folds_tune,