Skip to content

Commit 02dacb2

Browse files
committed
convert scikitlearn models behind the scenes
1 parent dede64a commit 02dacb2

16 files changed

+994
-1027
lines changed

causalpy/experiments/base.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818
from abc import abstractmethod
1919

20+
from sklearn.base import RegressorMixin
21+
2022
from causalpy.pymc_models import PyMCModel
21-
from causalpy.skl_models import ScikitLearnModel
23+
from causalpy.skl_models import create_causalpy_compatible_class
2224

2325

2426
class BaseExperiment:
@@ -28,13 +30,18 @@ class BaseExperiment:
2830
supports_ols: bool
2931

3032
def __init__(self, model=None):
33+
# Make any provided scikit-learn model (identified as an instance of
34+
# RegressorMixin) compatible with CausalPy by attaching our custom methods.
35+
if isinstance(model, RegressorMixin):
36+
model = create_causalpy_compatible_class(model)
37+
3138
if model is not None:
3239
self.model = model
3340

3441
if isinstance(self.model, PyMCModel) and not self.supports_bayes:
3542
raise ValueError("Bayesian models not supported.")
3643

37-
if isinstance(self.model, ScikitLearnModel) and not self.supports_ols:
44+
if isinstance(self.model, RegressorMixin) and not self.supports_ols:
3845
raise ValueError("OLS models not supported.")
3946

4047
if self.model is None:
@@ -57,7 +64,7 @@ def plot(self, *args, **kwargs) -> tuple:
5764
"""
5865
if isinstance(self.model, PyMCModel):
5966
return self.bayesian_plot(*args, **kwargs)
60-
elif isinstance(self.model, ScikitLearnModel):
67+
elif isinstance(self.model, RegressorMixin):
6168
return self.ols_plot(*args, **kwargs)
6269
else:
6370
raise ValueError("Unsupported model type")

causalpy/experiments/diff_in_diff.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
import seaborn as sns
2222
from matplotlib import pyplot as plt
2323
from patsy import build_design_matrices, dmatrices
24+
from sklearn.base import RegressorMixin
2425

2526
from causalpy.custom_exceptions import (
2627
DataException,
2728
FormulaException,
2829
)
2930
from causalpy.plot_utils import plot_xY
3031
from causalpy.pymc_models import PyMCModel
31-
from causalpy.skl_models import ScikitLearnModel
3232
from causalpy.utils import _is_variable_dummy_coded, convert_to_string, round_num
3333

3434
from .base import BaseExperiment
@@ -106,7 +106,7 @@ def __init__(
106106
if isinstance(self.model, PyMCModel):
107107
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
108108
self.model.fit(X=self.X, y=self.y, coords=COORDS)
109-
elif isinstance(self.model, ScikitLearnModel):
109+
elif isinstance(self.model, RegressorMixin):
110110
self.model.fit(X=self.X, y=self.y)
111111
else:
112112
raise ValueError("Model type not recognized")
@@ -181,7 +181,7 @@ def __init__(
181181
self.causal_impact = self.model.idata.posterior["beta"].isel(
182182
{"coeffs": i}
183183
)
184-
elif isinstance(self.model, ScikitLearnModel):
184+
elif isinstance(self.model, RegressorMixin):
185185
# This is the coefficient on the interaction term
186186
# TODO: THIS IS NOT YET CORRECT ?????
187187
self.causal_impact = (

causalpy/experiments/prepostfit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
import pandas as pd
2323
from matplotlib import pyplot as plt
2424
from patsy import build_design_matrices, dmatrices
25+
from sklearn.base import RegressorMixin
2526

2627
from causalpy.custom_exceptions import BadIndexException
2728
from causalpy.plot_utils import plot_xY
2829
from causalpy.pymc_models import PyMCModel
29-
from causalpy.skl_models import ScikitLearnModel
3030
from causalpy.utils import round_num
3131

3232
from .base import BaseExperiment
@@ -77,7 +77,7 @@ def __init__(
7777
if isinstance(self.model, PyMCModel):
7878
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.pre_X.shape[0])}
7979
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
80-
elif isinstance(self.model, ScikitLearnModel):
80+
elif isinstance(self.model, RegressorMixin):
8181
self.model.fit(X=self.pre_X, y=self.pre_y)
8282
else:
8383
raise ValueError("Model type not recognized")

causalpy/experiments/prepostnegd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@
2323
import seaborn as sns
2424
from matplotlib import pyplot as plt
2525
from patsy import build_design_matrices, dmatrices
26+
from sklearn.base import RegressorMixin
2627

2728
from causalpy.custom_exceptions import (
2829
DataException,
2930
)
3031
from causalpy.plot_utils import plot_xY
3132
from causalpy.pymc_models import PyMCModel
32-
from causalpy.skl_models import ScikitLearnModel
3333
from causalpy.utils import _is_variable_dummy_coded, round_num
3434

3535
from .base import BaseExperiment
@@ -115,7 +115,7 @@ def __init__(
115115
if isinstance(self.model, PyMCModel):
116116
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
117117
self.model.fit(X=self.X, y=self.y, coords=COORDS)
118-
elif isinstance(self.model, ScikitLearnModel):
118+
elif isinstance(self.model, RegressorMixin):
119119
raise NotImplementedError("Not implemented for OLS model")
120120
else:
121121
raise ValueError("Model type not recognized")

causalpy/experiments/regression_discontinuity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222
import seaborn as sns
2323
from matplotlib import pyplot as plt
2424
from patsy import build_design_matrices, dmatrices
25+
from sklearn.base import RegressorMixin
2526

2627
from causalpy.custom_exceptions import (
2728
DataException,
2829
FormulaException,
2930
)
3031
from causalpy.plot_utils import plot_xY
3132
from causalpy.pymc_models import PyMCModel
32-
from causalpy.skl_models import ScikitLearnModel
3333
from causalpy.utils import _is_variable_dummy_coded, convert_to_string, round_num
3434

3535
from .base import BaseExperiment
@@ -126,7 +126,7 @@ def __init__(
126126
# fit the model to the observed (pre-intervention) data
127127
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
128128
self.model.fit(X=self.X, y=self.y, coords=COORDS)
129-
elif isinstance(self.model, ScikitLearnModel):
129+
elif isinstance(self.model, RegressorMixin):
130130
self.model.fit(X=self.X, y=self.y)
131131
else:
132132
raise ValueError("Model type not recognized")

causalpy/skl_models.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from causalpy.utils import round_num
2424

2525

26-
class ScikitLearnModel:
26+
class ScikitLearnAdaptor:
2727
"""Base class for scikit-learn models that can be used for causal inference."""
2828

2929
def calculate_impact(self, y_true, y_pred):
@@ -53,7 +53,7 @@ def get_coeffs(self):
5353
return np.squeeze(self.coef_)
5454

5555

56-
class WeightedProportion(ScikitLearnModel, LinearModel, RegressorMixin):
56+
class WeightedProportion(ScikitLearnAdaptor, LinearModel, RegressorMixin):
5757
"""Weighted proportion model for causal inference. Used, for example, for
5858
synthetic control methods."""
5959

@@ -82,11 +82,19 @@ def predict(self, X):
8282

8383
def create_causalpy_compatible_class(
8484
estimator: type[RegressorMixin],
85-
) -> type[ScikitLearnModel]:
85+
) -> type[RegressorMixin]:
8686
"""This function takes a scikit-learn estimator, attaches the CausalPy adaptor
8787
methods to it, and returns that same estimator object."""
88-
89-
class Model(ScikitLearnModel, estimator):
90-
pass
91-
92-
return Model
88+
_add_mixin_methods(estimator, ScikitLearnAdaptor)
89+
return estimator
90+
91+
92+
def _add_mixin_methods(model_instance, mixin_class):
93+
"""Utility function to bind mixin methods to an existing model instance."""
94+
for attr_name in dir(mixin_class):
95+
attr = getattr(mixin_class, attr_name)
96+
if callable(attr) and not attr_name.startswith("__"):
97+
# Bind the method to the instance
98+
method = attr.__get__(model_instance, model_instance.__class__)
99+
setattr(model_instance, attr_name, method)
100+
return model_instance

causalpy/tests/test_input_validation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
from sklearn.linear_model import LinearRegression
2525

26-
CustomLinearRegression = cp.create_causalpy_compatible_class(LinearRegression)
2726

2827
sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
2928

@@ -254,7 +253,7 @@ def test_rd_validation_treated_in_formula():
254253
_ = cp.RegressionDiscontinuity(
255254
df,
256255
formula="y ~ 1 + x",
257-
model=CustomLinearRegression(),
256+
model=LinearRegression(),
258257
treatment_threshold=0.5,
259258
)
260259

@@ -281,7 +280,7 @@ def test_rd_validation_treated_is_dummy():
281280
_ = cp.RegressionDiscontinuity(
282281
df,
283282
formula="y ~ 1 + x + treated",
284-
model=CustomLinearRegression(),
283+
model=LinearRegression(),
285284
treatment_threshold=0.5,
286285
)
287286

causalpy/tests/test_integration_skl_examples.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
from sklearn.linear_model import LinearRegression
2121

2222
import causalpy as cp
23-
from causalpy.skl_models import ScikitLearnModel
24-
25-
CustomLinearRegression = cp.create_causalpy_compatible_class(LinearRegression)
2623

2724

2825
@pytest.mark.integration
@@ -42,7 +39,7 @@ def test_did():
4239
group_variable_name="group",
4340
treated=1,
4441
untreated=0,
45-
model=CustomLinearRegression(),
42+
model=LinearRegression(),
4643
)
4744
assert isinstance(data, pd.DataFrame)
4845
assert isinstance(result, cp.DifferenceInDifferences)
@@ -71,7 +68,7 @@ def test_rd_drinking():
7168
df,
7269
formula="all ~ 1 + age + treated",
7370
running_variable_name="age",
74-
model=CustomLinearRegression(),
71+
model=LinearRegression(),
7572
treatment_threshold=21,
7673
epsilon=0.001,
7774
)
@@ -103,7 +100,7 @@ def test_its():
103100
df,
104101
treatment_time,
105102
formula="y ~ 1 + t + C(month)",
106-
model=CustomLinearRegression(),
103+
model=LinearRegression(),
107104
)
108105
assert isinstance(df, pd.DataFrame)
109106
assert isinstance(result, cp.InterruptedTimeSeries)
@@ -165,7 +162,7 @@ def test_rd_linear_main_effects():
165162
result = cp.RegressionDiscontinuity(
166163
data,
167164
formula="y ~ 1 + x + treated",
168-
model=CustomLinearRegression(),
165+
model=LinearRegression(),
169166
treatment_threshold=0.5,
170167
epsilon=0.001,
171168
)
@@ -191,7 +188,7 @@ def test_rd_linear_main_effects_bandwidth():
191188
result = cp.skl_experiments.RegressionDiscontinuity(
192189
data,
193190
formula="y ~ 1 + x + treated",
194-
model=CustomLinearRegression(),
191+
model=LinearRegression(),
195192
treatment_threshold=0.5,
196193
epsilon=0.001,
197194
bandwidth=0.3,
@@ -217,7 +214,7 @@ def test_rd_linear_with_interaction():
217214
result = cp.RegressionDiscontinuity(
218215
data,
219216
formula="y ~ 1 + x + treated + x:treated",
220-
model=CustomLinearRegression(),
217+
model=LinearRegression(),
221218
treatment_threshold=0.5,
222219
epsilon=0.001,
223220
)
@@ -238,18 +235,13 @@ def test_rd_linear_with_gaussian_process():
238235
1. data is a dataframe
239236
2. skl_experiments.RegressionDiscontinuity returns correct type
240237
"""
241-
242-
# create a custom GaussianProcessRegressor class by subclassing
243-
# GaussianProcessRegressor and adding the ScikitLearnModel mixin
244-
class CustomGaussianProcessRegressor(GaussianProcessRegressor, ScikitLearnModel):
245-
pass
246-
247238
data = cp.load_data("rd")
248239
kernel = 1.0 * ExpSineSquared(1.0, 5.0) + WhiteKernel(1e-1)
249240
result = cp.RegressionDiscontinuity(
250241
data,
251242
formula="y ~ 1 + x + treated",
252-
model=CustomGaussianProcessRegressor(kernel=kernel),
243+
model=GaussianProcessRegressor(kernel=kernel),
244+
model_kwargs={"kernel": kernel},
253245
treatment_threshold=0.5,
254246
epsilon=0.001,
255247
)
@@ -275,7 +267,7 @@ def test_did_deprecation_warning():
275267
group_variable_name="group",
276268
treated=1,
277269
untreated=0,
278-
model=CustomLinearRegression(),
270+
model=LinearRegression(),
279271
)
280272
assert isinstance(result, cp.DifferenceInDifferences)
281273

@@ -294,7 +286,7 @@ def test_its_deprecation_warning():
294286
df,
295287
treatment_time,
296288
formula="y ~ 1 + t + C(month)",
297-
model=CustomLinearRegression(),
289+
model=LinearRegression(),
298290
)
299291
assert isinstance(result, cp.InterruptedTimeSeries)
300292

@@ -322,7 +314,7 @@ def test_rd_deprecation_warning():
322314
result = cp.skl_experiments.RegressionDiscontinuity(
323315
data,
324316
formula="y ~ 1 + x + treated",
325-
model=CustomLinearRegression(),
317+
model=LinearRegression(),
326318
treatment_threshold=0.5,
327319
epsilon=0.001,
328320
)

0 commit comments

Comments
 (0)