Skip to content

Commit 644cf6b

Browse files
committed
move IPW integration test to better test file
1 parent cc62438 commit 644cf6b

File tree

2 files changed

+51
-53
lines changed

2 files changed

+51
-53
lines changed

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import arviz as az
1415
import numpy as np
1516
import pandas as pd
1617
import pytest
@@ -597,6 +598,56 @@ def test_iv_reg():
597598
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
598599

599600

601+
@pytest.mark.integration
602+
def test_inverse_prop():
603+
"""Test the InversePropensityWeighting class."""
604+
df = cp.load_data("nhefs")
605+
sample_kwargs = {
606+
"tune": 100,
607+
"draws": 500,
608+
"chains": 2,
609+
"cores": 2,
610+
"random_seed": 100,
611+
}
612+
result = cp.InversePropensityWeighting(
613+
df,
614+
formula="trt ~ 1 + age + race",
615+
outcome_variable="outcome",
616+
weighting_scheme="robust",
617+
model=cp.pymc_models.PropensityScore(sample_kwargs=sample_kwargs),
618+
)
619+
assert isinstance(result.idata, az.InferenceData)
620+
ps = result.idata.posterior["p"].mean(dim=("chain", "draw"))
621+
w1, w2, _, _ = result.make_doubly_robust_adjustment(ps)
622+
assert isinstance(w1, pd.Series)
623+
assert isinstance(w2, pd.Series)
624+
w1, w2, n1, nw = result.make_raw_adjustments(ps)
625+
assert isinstance(w1, pd.Series)
626+
assert isinstance(w2, pd.Series)
627+
w1, w2, n1, n2 = result.make_robust_adjustments(ps)
628+
assert isinstance(w1, pd.Series)
629+
assert isinstance(w2, pd.Series)
630+
w1, w2, n1, n2 = result.make_overlap_adjustments(ps)
631+
assert isinstance(w1, pd.Series)
632+
assert isinstance(w2, pd.Series)
633+
ate_list = result.get_ate(0, result.idata)
634+
assert isinstance(ate_list, list)
635+
ate_list = result.get_ate(0, result.idata, method="raw")
636+
assert isinstance(ate_list, list)
637+
ate_list = result.get_ate(0, result.idata, method="robust")
638+
assert isinstance(ate_list, list)
639+
ate_list = result.get_ate(0, result.idata, method="overlap")
640+
assert isinstance(ate_list, list)
641+
fig, axs = result.plot_ate(prop_draws=1, ate_draws=10)
642+
assert isinstance(fig, plt.Figure)
643+
assert isinstance(axs, list)
644+
assert all(isinstance(ax, plt.Axes) for ax in axs)
645+
fig, axs = result.plot_balance_ecdf("age")
646+
assert isinstance(fig, plt.Figure)
647+
assert isinstance(axs, list)
648+
assert all(isinstance(ax, plt.Axes) for ax in axs)
649+
650+
600651
# DEPRECATION WARNING TESTS ============================================================
601652

602653

causalpy/tests/test_misc.py

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@
1515
Miscellaneous unit tests
1616
"""
1717

18-
import arviz as az
19-
import pandas as pd
20-
from matplotlib import pyplot as plt
21-
2218
import causalpy as cp
2319

2420
sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
@@ -41,52 +37,3 @@ def test_regression_kink_gradient_change():
4137
assert cp.RegressionKink._eval_gradient_change(0, 0, -2, 1) == -2.0
4238
assert cp.RegressionKink._eval_gradient_change(-1, -1, -2, 1) == -1.0
4339
assert cp.RegressionKink._eval_gradient_change(1, 0, -2, 1) == -1.0
44-
45-
46-
def test_inverse_prop():
47-
"""Test the InversePropensityWeighting class."""
48-
df = cp.load_data("nhefs")
49-
sample_kwargs = {
50-
"tune": 100,
51-
"draws": 500,
52-
"chains": 2,
53-
"cores": 2,
54-
"random_seed": 100,
55-
}
56-
result = cp.InversePropensityWeighting(
57-
df,
58-
formula="trt ~ 1 + age + race",
59-
outcome_variable="outcome",
60-
weighting_scheme="robust",
61-
model=cp.pymc_models.PropensityScore(sample_kwargs=sample_kwargs),
62-
)
63-
assert isinstance(result.idata, az.InferenceData)
64-
ps = result.idata.posterior["p"].mean(dim=("chain", "draw"))
65-
w1, w2, _, _ = result.make_doubly_robust_adjustment(ps)
66-
assert isinstance(w1, pd.Series)
67-
assert isinstance(w2, pd.Series)
68-
w1, w2, n1, nw = result.make_raw_adjustments(ps)
69-
assert isinstance(w1, pd.Series)
70-
assert isinstance(w2, pd.Series)
71-
w1, w2, n1, n2 = result.make_robust_adjustments(ps)
72-
assert isinstance(w1, pd.Series)
73-
assert isinstance(w2, pd.Series)
74-
w1, w2, n1, n2 = result.make_overlap_adjustments(ps)
75-
assert isinstance(w1, pd.Series)
76-
assert isinstance(w2, pd.Series)
77-
ate_list = result.get_ate(0, result.idata)
78-
assert isinstance(ate_list, list)
79-
ate_list = result.get_ate(0, result.idata, method="raw")
80-
assert isinstance(ate_list, list)
81-
ate_list = result.get_ate(0, result.idata, method="robust")
82-
assert isinstance(ate_list, list)
83-
ate_list = result.get_ate(0, result.idata, method="overlap")
84-
assert isinstance(ate_list, list)
85-
fig, axs = result.plot_ate(prop_draws=1, ate_draws=10)
86-
assert isinstance(fig, plt.Figure)
87-
assert isinstance(axs, list)
88-
assert all(isinstance(ax, plt.Axes) for ax in axs)
89-
fig, axs = result.plot_balance_ecdf("age")
90-
assert isinstance(fig, plt.Figure)
91-
assert isinstance(axs, list)
92-
assert all(isinstance(ax, plt.Axes) for ax in axs)

0 commit comments

Comments
 (0)