15
15
Miscellaneous unit tests
16
16
"""
17
17
18
- import arviz as az
19
- import pandas as pd
20
- from matplotlib import pyplot as plt
21
-
22
18
import causalpy as cp
23
19
24
20
sample_kwargs = {"tune" : 20 , "draws" : 20 , "chains" : 2 , "cores" : 2 }
@@ -41,52 +37,3 @@ def test_regression_kink_gradient_change():
41
37
assert cp .RegressionKink ._eval_gradient_change (0 , 0 , - 2 , 1 ) == - 2.0
42
38
assert cp .RegressionKink ._eval_gradient_change (- 1 , - 1 , - 2 , 1 ) == - 1.0
43
39
assert cp .RegressionKink ._eval_gradient_change (1 , 0 , - 2 , 1 ) == - 1.0
44
-
45
-
46
- def test_inverse_prop ():
47
- """Test the InversePropensityWeighting class."""
48
- df = cp .load_data ("nhefs" )
49
- sample_kwargs = {
50
- "tune" : 100 ,
51
- "draws" : 500 ,
52
- "chains" : 2 ,
53
- "cores" : 2 ,
54
- "random_seed" : 100 ,
55
- }
56
- result = cp .InversePropensityWeighting (
57
- df ,
58
- formula = "trt ~ 1 + age + race" ,
59
- outcome_variable = "outcome" ,
60
- weighting_scheme = "robust" ,
61
- model = cp .pymc_models .PropensityScore (sample_kwargs = sample_kwargs ),
62
- )
63
- assert isinstance (result .idata , az .InferenceData )
64
- ps = result .idata .posterior ["p" ].mean (dim = ("chain" , "draw" ))
65
- w1 , w2 , _ , _ = result .make_doubly_robust_adjustment (ps )
66
- assert isinstance (w1 , pd .Series )
67
- assert isinstance (w2 , pd .Series )
68
- w1 , w2 , n1 , nw = result .make_raw_adjustments (ps )
69
- assert isinstance (w1 , pd .Series )
70
- assert isinstance (w2 , pd .Series )
71
- w1 , w2 , n1 , n2 = result .make_robust_adjustments (ps )
72
- assert isinstance (w1 , pd .Series )
73
- assert isinstance (w2 , pd .Series )
74
- w1 , w2 , n1 , n2 = result .make_overlap_adjustments (ps )
75
- assert isinstance (w1 , pd .Series )
76
- assert isinstance (w2 , pd .Series )
77
- ate_list = result .get_ate (0 , result .idata )
78
- assert isinstance (ate_list , list )
79
- ate_list = result .get_ate (0 , result .idata , method = "raw" )
80
- assert isinstance (ate_list , list )
81
- ate_list = result .get_ate (0 , result .idata , method = "robust" )
82
- assert isinstance (ate_list , list )
83
- ate_list = result .get_ate (0 , result .idata , method = "overlap" )
84
- assert isinstance (ate_list , list )
85
- fig , axs = result .plot_ate (prop_draws = 1 , ate_draws = 10 )
86
- assert isinstance (fig , plt .Figure )
87
- assert isinstance (axs , list )
88
- assert all (isinstance (ax , plt .Axes ) for ax in axs )
89
- fig , axs = result .plot_balance_ecdf ("age" )
90
- assert isinstance (fig , plt .Figure )
91
- assert isinstance (axs , list )
92
- assert all (isinstance (ax , plt .Axes ) for ax in axs )
0 commit comments