-
-
Notifications
You must be signed in to change notification settings - Fork 26k
[MRG] Common Private Loss Module with tempita #20567
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 44 commits
Commits
Show all changes
133 commits
Select commit
Hold shift + click to select a range
f8362d7
ENH add common link function submodule
lorentzenchr afdb67e
ENH add common loss function submodule
lorentzenchr 830b814
CLN replace deprecated np.int by int
lorentzenchr 9504c89
DOC document default=1 for n_threads
lorentzenchr fb3bce2
CLN comments and line wrapping
lorentzenchr 2c86bf4
CLN comments and doc
lorentzenchr d68c07e
BUG remove useless line of code
lorentzenchr 3d9c800
CLN remove line that was commented out
lorentzenchr aba1b67
CLN nitpicks in comments and docstrings
lorentzenchr 022e418
ENH set NPY_NO_DEPRECATED_API
lorentzenchr 49bb402
MNT change NPY_1_13_API_VERSION to NPY_1_7_API_VERSION
lorentzenchr 6d77090
MNT comment out NPY_NO_DEPRECATED_API
lorentzenchr ceda673
TST restructure domain test cases
lorentzenchr c73e3fa
DOC add losses to API reference
lorentzenchr e650522
MNT add classes to __init__
lorentzenchr a31d8fb
CLN fix import
lorentzenchr e5b6266
DOC minor docstring changes
lorentzenchr 3492383
TST prefer docstring over comment
lorentzenchr 9d86d82
ENH define loss.is_multiclass
lorentzenchr cc90e4d
DOC fix typos
lorentzenchr d0b48ac
CLN address review comments
lorentzenchr 7794617
DOC small docstring improvements
lorentzenchr 35b7423
TST test more losses in test_specific_fit_intercept_only
lorentzenchr b390002
FIX test_loss_boundary
lorentzenchr 12b4634
ENH Tempita for losses
lorentzenchr 061a41b
MNT apply black
lorentzenchr 98f8877
TST replace np.quantile by np.percentile
lorentzenchr 3f8ffe9
ENH make Interval a dataclass
lorentzenchr b5e61d2
DOC improve docstrings in link.py
lorentzenchr cfdd67c
MNT use numpy dtype instead of Python type
lorentzenchr 73311e9
TST add negative intervals
lorentzenchr 3692280
ENH add __post_init__ to class Interval
lorentzenchr 4cb4d3d
MNT rename cython losses
lorentzenchr c13e3b1
TST loss.predict_proba
lorentzenchr 390ff19
TST predict_proba and gradient_proba
lorentzenchr ebd9f40
MNT use is_multiclass in tests instead of n_classes <= 2
lorentzenchr 60e0fc4
DOC docstring predict_proba and more
lorentzenchr 386aae9
Merge branch 'main' into loss_module_tempita
lorentzenchr 70199a4
MNT remove top_path from gen_from_templates
lorentzenchr 5721077
CI add --allow-releaseinfo-change in circleci
lorentzenchr f53e3f8
TST test graceful squeezing
lorentzenchr df072e4
Merge branch 'main' into loss_module_tempita
lorentzenchr a5c1d3c
CLN no extra methods for HalfSquaredError
lorentzenchr 7bee26b
TST remove testing if approx_hessian=True
lorentzenchr 15b7c99
DOC remove loss module for classes.rst
lorentzenchr 67a3069
TST that losses can be pickled
lorentzenchr 696d18f
TST add test_loss_on_specific_values
lorentzenchr cef6e24
FIX make cython inheritance happy and losses pickable
lorentzenchr 130306c
ENH support const memoryviews by ReadonlyWrapper
lorentzenchr 27f3dea
address review comments
lorentzenchr 86d659c
FEA add ReadonlyWrapper
lorentzenchr 7d30abb
ENH add common link function submodule
lorentzenchr 53a4774
ENH add common loss function submodule
lorentzenchr 25012bc
CLN replace deprecated np.int by int
lorentzenchr d6c9307
DOC document default=1 for n_threads
lorentzenchr 5682a72
CLN comments and line wrapping
lorentzenchr 6ace462
CLN comments and doc
lorentzenchr 1a5ae1c
BUG remove useless line of code
lorentzenchr af93e7b
CLN remove line that was commented out
lorentzenchr 9ed2096
CLN nitpicks in comments and docstrings
lorentzenchr b5e8224
ENH set NPY_NO_DEPRECATED_API
lorentzenchr c5c0d55
MNT change NPY_1_13_API_VERSION to NPY_1_7_API_VERSION
lorentzenchr d7b105e
MNT comment out NPY_NO_DEPRECATED_API
lorentzenchr 4b48294
TST restructure domain test cases
lorentzenchr a4572a4
DOC add losses to API reference
lorentzenchr 64964d6
MNT add classes to __init__
lorentzenchr 43b6269
CLN fix import
lorentzenchr f90049b
DOC minor docstring changes
lorentzenchr 47691ff
TST prefer docstring over comment
lorentzenchr b2e0856
ENH define loss.is_multiclass
lorentzenchr 35ce8d2
DOC fix typos
lorentzenchr 5904c8b
CLN address review comments
lorentzenchr 3a7122b
DOC small docstring improvements
lorentzenchr c3b7658
TST test more losses in test_specific_fit_intercept_only
lorentzenchr eae2def
FIX test_loss_boundary
lorentzenchr ec5fd02
ENH Tempita for losses
lorentzenchr a9d93b4
MNT apply black
lorentzenchr f8a024a
TST replace np.quantile by np.percentile
lorentzenchr dfcd078
ENH make Interval a dataclass
lorentzenchr b5c5bf5
DOC improve docstrings in link.py
lorentzenchr 3d6a477
MNT use numpy dtype instead of Python type
lorentzenchr c2f0f8e
TST add negative intervals
lorentzenchr 33cabc4
ENH add __post_init__ to class Interval
lorentzenchr 4cd2826
MNT rename cython losses
lorentzenchr 5f61a90
TST loss.predict_proba
lorentzenchr 7e1af6c
TST predict_proba and gradient_proba
lorentzenchr fb1ab5a
MNT use is_multiclass in tests instead of n_classes <= 2
lorentzenchr bdb6d18
DOC docstring predict_proba and more
lorentzenchr 0750300
MNT remove top_path from gen_from_templates
lorentzenchr 04df8a9
TST test graceful squeezing
lorentzenchr 86b9293
CLN no extra methods for HalfSquaredError
lorentzenchr 330b98e
TST remove testing if approx_hessian=True
lorentzenchr d9b6bc8
DOC remove loss module for classes.rst
lorentzenchr 96ab3ba
TST that losses can be pickled
lorentzenchr 6c8136e
TST add test_loss_on_specific_values
lorentzenchr fa15691
FIX make cython inheritance happy and losses pickable
lorentzenchr 202953b
ENH support const memoryviews by ReadonlyWrapper
lorentzenchr 9a89ff6
address review comments
lorentzenchr 5a54fbe
CLN nitpick
lorentzenchr 2820be5
CLN import ReadonlyWrapper from utils
lorentzenchr 73054cd
Merge
lorentzenchr 310eca7
Merge branch 'main' into loss_module_tempita
lorentzenchr 4525085
MNT replace ReadonlyWrapper by ReadonlyArrayWrapper
lorentzenchr 439ee83
trigger CI
lorentzenchr a201fd0
MNT rename out parameters
lorentzenchr 0c7c68b
CLN address review comments
lorentzenchr 5f3b0a5
TST increase maxiter
lorentzenchr d6d1516
Merge branch 'main' into loss_module_tempita
lorentzenchr 12f529e
Revert "TST increase maxiter"
lorentzenchr 643ad05
MNT composition instead of inheritance
lorentzenchr cb3cacc
MNT interval_raw_prediction never used
lorentzenchr 2147c61
CLN closs and link as args in __init__
lorentzenchr 48fc50b
trigger CI
lorentzenchr 1a031ff
trigger CI
lorentzenchr 18e0c4e
Merge branch 'main' into loss_module_tempita
lorentzenchr 0068c68
DEBUG print infos
lorentzenchr 903be0b
DEBUG remove readonly_memmap test
lorentzenchr 4492c58
Revert "DEBUG remove readonly_memmap test"
lorentzenchr ef3b9e7
TST skip test if data not aligned
lorentzenchr 942c142
Revert "DEBUG print infos"
lorentzenchr 0b1d618
TST zeros instead of empty initial guess
lorentzenchr c30535f
Revert "TST skip test if data not aligned"
lorentzenchr e9c551c
DEBUG set boundscheck=True
lorentzenchr 88e543c
Revert "DEBUG set boundscheck=True"
lorentzenchr a0b3b86
CLN setup.py
lorentzenchr 00328b2
CLN rename binary and categorical cross entropy to binomial and multi…
lorentzenchr 54f78fc
Merge branch 'main' into loss_module_tempita
lorentzenchr d7d25b6
Merge branch 'main' into loss_module_tempita
lorentzenchr 98d4790
trigger CI
lorentzenchr 9d8b0e9
Merge branch 'main' into loss_module_tempita
lorentzenchr 3b9403f
TST aligned create_memmap_backed_data
lorentzenchr f5949d3
FIX replace CyBinaryCrossEntropy by CyHalfBinomialLoss
lorentzenchr d967740
MNT remove Cython compiler directives due to #21512
lorentzenchr File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
""" | ||
The :mod:`sklearn._loss` module includes loss function classes suitable for | ||
fitting classification and regression tasks. | ||
""" | ||
|
||
from .loss import ( | ||
HalfSquaredError, | ||
AbsoluteError, | ||
PinballLoss, | ||
HalfPoissonLoss, | ||
HalfGammaLoss, | ||
HalfTweedieLoss, | ||
BinaryCrossEntropy, | ||
CategoricalCrossEntropy, | ||
) | ||
|
||
|
||
__all__ = [ | ||
"HalfSquaredError", | ||
"AbsoluteError", | ||
"PinballLoss", | ||
"HalfPoissonLoss", | ||
"HalfGammaLoss", | ||
"HalfTweedieLoss", | ||
"BinaryCrossEntropy", | ||
"CategoricalCrossEntropy", | ||
] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# cython: language_level=3 | ||
|
||
import numpy as np | ||
cimport numpy as np | ||
|
||
np.import_array() | ||
|
||
|
||
# Fused types for y_true, y_pred, raw_prediction | ||
ctypedef fused Y_DTYPE_C: | ||
np.npy_float64 | ||
np.npy_float32 | ||
|
||
|
||
# Fused types for gradient and hessian | ||
ctypedef fused G_DTYPE_C: | ||
np.npy_float64 | ||
np.npy_float32 | ||
|
||
|
||
# Struct to return 2 doubles | ||
ctypedef struct double_pair: | ||
double val1 | ||
double val2 | ||
|
||
|
||
# C base class for loss functions | ||
cdef class CyLossFunction: | ||
cdef double cy_loss(self, double y_true, double raw_prediction) nogil | ||
cdef double cy_gradient(self, double y_true, double raw_prediction) nogil | ||
cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil | ||
|
||
|
||
cdef class CyHalfSquaredError(CyLossFunction): | ||
cdef double cy_loss(self, double y_true, double raw_prediction) nogil | ||
cdef double cy_gradient(self, double y_true, double raw_prediction) nogil | ||
cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil | ||
|
||
|
||
cdef class CyAbsoluteError(CyLossFunction): | ||
cdef double cy_loss(self, double y_true, double raw_prediction) nogil | ||
cdef double cy_gradient(self, double y_true, double raw_prediction) nogil | ||
cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil | ||
|
||
|
||
cdef class CyPinballLoss(CyLossFunction): | ||
cdef readonly double quantile # readonly makes it inherited by children | ||
lorentzenchr marked this conversation as resolved.
Show resolved
Hide resolved
|
||
cdef double cy_loss(self, double y_true, double raw_prediction) nogil | ||
cdef double cy_gradient(self, double y_true, double raw_prediction) nogil | ||
cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil | ||
|
||
|
||
cdef class CyHalfPoissonLoss(CyLossFunction): | ||
cdef double cy_loss(self, double y_true, double raw_prediction) nogil | ||
cdef double cy_gradient(self, double y_true, double raw_prediction) nogil | ||
cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil | ||
|
||
|
||
cdef class CyHalfGammaLoss(CyLossFunction): | ||
cdef double cy_loss(self, double y_true, double raw_prediction) nogil | ||
cdef double cy_gradient(self, double y_true, double raw_prediction) nogil | ||
cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil | ||
|
||
|
||
cdef class CyHalfTweedieLoss(CyLossFunction): | ||
cdef readonly double power # readonly makes it inherited by children | ||
cdef double cy_loss(self, double y_true, double raw_prediction) nogil | ||
cdef double cy_gradient(self, double y_true, double raw_prediction) nogil | ||
cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil | ||
|
||
|
||
cdef class CyBinaryCrossEntropy(CyLossFunction): | ||
cdef double cy_loss(self, double y_true, double raw_prediction) nogil | ||
cdef double cy_gradient(self, double y_true, double raw_prediction) nogil | ||
cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.