-
Notifications
You must be signed in to change notification settings - Fork 77
Major code refactor to unify quasi experiment classes #381
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
72 commits
Select commit
Hold shift + click to select a range
df1989f
Very initial work on the big refactor
drbenvincent d36e289
update uml diagram
drbenvincent 0945fc4
change from dict to dataclass
drbenvincent d0d3bc3
Friday evening progress
drbenvincent f87577b
bayesian anova sorted + made plot methods static methods
drbenvincent e82325d
regression discontinuity sorted
drbenvincent 4ccea2c
regression discontinuity sorted + add summary methods
drbenvincent 119f749
rename classes
drbenvincent 841cca4
add inverse propensity weighting
drbenvincent e5a07d9
add docstrings back in
drbenvincent 7b1e600
make doctests pass
drbenvincent 48cf9b5
make lots of tests pass
drbenvincent 5c2b103
make test pass - Line2D
drbenvincent b23b373
fix import
drbenvincent 0435456
add docstrings to make interrogate pre-commit check pass
drbenvincent ac2389d
Update expt_prepostnegd.py
drbenvincent c6ae453
fix _causal_impact_summary_stat
drbenvincent ff5122f
tidy up plotting
drbenvincent 291dc47
fix score
drbenvincent 4d10175
add convert_to_string + fix typo
drbenvincent 4442d5b
ensure fig, ax returned from plot method
drbenvincent 3757199
add causal impact arrow back in to DID plot
drbenvincent 6c4e43c
zero failing tests
drbenvincent ae7c405
add test coverage for plot method to integration tests
drbenvincent a85ccfb
fix errors in tests + make some tests pass
drbenvincent 00c1290
make a test pass
drbenvincent 5b5ccd2
fix docstring + add type hint
drbenvincent c3df3eb
make another plot test pass
drbenvincent 2fb344a
fix plotting issue with some scikit-learn models
drbenvincent fd74658
fix problem with GaussianProcessRegressor. Back to ZERO failing tests
drbenvincent 41bf080
re-run most notebooks
drbenvincent c77de98
fix plotting of control units for synthetic control + add test coverage
drbenvincent ea2b859
update UML
drbenvincent b0b4539
tweaks
drbenvincent b0dc8c6
remove commented code
drbenvincent 0af9bfb
pre-commit autoupdate
drbenvincent 28f3b07
improve manual type checking sections in the experiment modules
drbenvincent e84c199
add/improve module level docstrings
drbenvincent b0eabff
add missing docstrings back in
drbenvincent 1b26499
fix failing doctest
drbenvincent 133ee3b
Merge branch 'main' into refactor
juanitorduz 5c07e4d
change data validation mixins to class methods
drbenvincent 224ec84
remove commented imports
drbenvincent 100784f
_input_validation -> input_validation
drbenvincent d6e058c
change all asserts outside of tests into checks which raise exceptions
drbenvincent dab5824
experiment submodule
drbenvincent 4b8141d
remove setting attributes to None in ExperimentDesign base class
drbenvincent 47d4479
plot_ATE -> plot_ate
drbenvincent 70e58f8
change experiment.py::ExperimentalDesign to base.py::BaseExperiment
drbenvincent c080fa9
remove PlotComponent abstract base class
drbenvincent f915f77
add return type hints for plot methods + fix up test assertions
drbenvincent 6822c61
Handle deprecation like a grown up
drbenvincent 1cdf7c2
convert the deprecation wrapper classes into functions
drbenvincent fbc4c94
move plotting into experiment classes, removing PlotComponent entirely
drbenvincent fa640b8
create new helper function, create_causalpy_compatible_class
drbenvincent 4edc6d5
fix typo
drbenvincent 38b0e68
Update causalpy/experiments/diff_in_diff.py
drbenvincent 6a9214b
remove **kwargs being passed to BaseExperiment.__init__
drbenvincent 0cf4f46
remove unnecessary kwargs in did experiments
drbenvincent 60cfb2a
update uml
drbenvincent 121fe46
remove old API pages
drbenvincent cc62438
better model/experiment compatability
drbenvincent 644cf6b
move IPW integration test to better test file
drbenvincent dede64a
add warning to NotImplementedError
drbenvincent 02dacb2
convert scikitlearn models behind the scenes
drbenvincent 3f9763a
fix import in one of the tests
drbenvincent 01ce582
improve deprecation warnings
drbenvincent 6f6fade
fix typo
drbenvincent 66f22aa
update deprecation warnings for the scikit-learn experiment classes
drbenvincent 77b0b2b
Merge branch 'main' into refactor
drbenvincent 9bc3d25
depricated -> deprecated
drbenvincent e0b0847
remove redundant sample_kwargs definition in test
drbenvincent File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2024 The PyMC Labs Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Copyright 2024 The PyMC Labs Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
Base class for quasi experimental designs. | ||
""" | ||
|
||
from abc import abstractmethod | ||
|
||
from sklearn.base import RegressorMixin | ||
|
||
from causalpy.pymc_models import PyMCModel | ||
from causalpy.skl_models import create_causalpy_compatible_class | ||
|
||
|
||
class BaseExperiment: | ||
"""Base class for quasi experimental designs.""" | ||
|
||
supports_bayes: bool | ||
supports_ols: bool | ||
|
||
def __init__(self, model=None): | ||
# Ensure we've made any provided Scikit Learn model (as identified as being type | ||
# RegressorMixin) compatible with CausalPy by appending our custom methods. | ||
if isinstance(model, RegressorMixin): | ||
model = create_causalpy_compatible_class(model) | ||
|
||
if model is not None: | ||
self.model = model | ||
|
||
if isinstance(self.model, PyMCModel) and not self.supports_bayes: | ||
raise ValueError("Bayesian models not supported.") | ||
|
||
if isinstance(self.model, RegressorMixin) and not self.supports_ols: | ||
raise ValueError("OLS models not supported.") | ||
|
||
if self.model is None: | ||
raise ValueError("model not set or passed.") | ||
|
||
@property | ||
def idata(self): | ||
"""Return the InferenceData object of the model. Only relevant for PyMC models.""" | ||
return self.model.idata | ||
|
||
def print_coefficients(self, round_to=None): | ||
"""Ask the model to print its coefficients.""" | ||
self.model.print_coefficients(self.labels, round_to) | ||
|
||
def plot(self, *args, **kwargs) -> tuple: | ||
"""Plot the model. | ||
|
||
Internally, this function dispatches to either `bayesian_plot` or `ols_plot` | ||
depending on the model type. | ||
""" | ||
if isinstance(self.model, PyMCModel): | ||
return self.bayesian_plot(*args, **kwargs) | ||
elif isinstance(self.model, RegressorMixin): | ||
return self.ols_plot(*args, **kwargs) | ||
else: | ||
raise ValueError("Unsupported model type") | ||
|
||
@abstractmethod | ||
def bayesian_plot(self, *args, **kwargs): | ||
"""Abstract method for plotting the model.""" | ||
raise NotImplementedError("bayesian_plot method not yet implemented") | ||
|
||
@abstractmethod | ||
def ols_plot(self, *args, **kwargs): | ||
"""Abstract method for plotting the model.""" | ||
raise NotImplementedError("ols_plot method not yet implemented") | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.