Skip to content

Commit 4166ceb

Browse files
authored
Merge branch 'master' into fix-localmode-subprocess-termination
2 parents dd7f66c + 9be4c8a commit 4166ceb

25 files changed

+268
-1597
lines changed

buildspec-deploy.yml

Lines changed: 0 additions & 29 deletions
This file was deleted.

buildspec-localmodetests.yml

Lines changed: 0 additions & 15 deletions
This file was deleted.

buildspec-notebooktests.yml

Lines changed: 0 additions & 10 deletions
This file was deleted.

buildspec-release.yml

Lines changed: 0 additions & 21 deletions
This file was deleted.

buildspec-slowtests.yml

Lines changed: 0 additions & 15 deletions
This file was deleted.

buildspec-unittests.yml

Lines changed: 0 additions & 22 deletions
This file was deleted.

buildspec.yml

Lines changed: 0 additions & 30 deletions
This file was deleted.

doc/api/inference/model.rst

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,3 @@ Model
1616
:undoc-members:
1717
:show-inheritance:
1818

19-
.. autoclass:: sagemaker.serverless.model.LambdaModel
20-
:members:
21-
:undoc-members:
22-
:show-inheritance:

doc/api/inference/predictors.rst

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,3 @@ Make real-time predictions against SageMaker endpoints with Python objects
77
:members:
88
:undoc-members:
99
:show-inheritance:
10-
11-
.. autoclass:: sagemaker.serverless.predictor.LambdaPredictor
12-
:members:
13-
:undoc-members:
14-
:show-inheritance:

doc/overview.rst

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,50 +1063,6 @@ You can also find these notebooks in the **Advanced Functionality** section of t
10631063
For information about using sample notebooks in a SageMaker notebook instance, see `Use Example Notebooks <https://docs.aws.amazon.com/sagemaker/latest/dg/howitworks-nbexamples.html>`__
10641064
in the AWS documentation.
10651065
1066-
********************
1067-
Serverless Inference
1068-
********************
1069-
1070-
You can use the SageMaker Python SDK to perform serverless inference on Lambda.
1071-
1072-
To deploy models to Lambda, you must complete the following prerequisites:
1073-
1074-
- `Package your model and inference code as a container image. <https://docs.aws.amazon.com/lambda/latest/dg/images-create.html>`_
1075-
- `Create a role that lists Lambda as a trusted entity. <https://docs.aws.amazon.com/lambda/latest/dg/lambda-intro-execution-role.html#permissions-executionrole-console>`_
1076-
1077-
After completing the prerequisites, you can deploy your model to Lambda using
1078-
the `LambdaModel`_ class.
1079-
1080-
.. code:: python
1081-
1082-
from sagemaker.serverless import LambdaModel
1083-
1084-
image_uri = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-lambda-repository:latest"
1085-
role = "arn:aws:iam::123456789012:role/MyLambdaExecutionRole"
1086-
1087-
model = LambdaModel(image_uri=image_uri, role=role)
1088-
predictor = model.deploy("my-lambda-function", timeout=20, memory_size=4092)
1089-
1090-
The ``deploy`` method returns a `LambdaPredictor`_ instance. Use the
1091-
`LambdaPredictor`_ ``predict`` method to perform inference on Lambda.
1092-
1093-
.. code:: python
1094-
1095-
url = "https://example.com/cat.jpeg"
1096-
predictor.predict({"url": url}) # {'class': 'tabby'}
1097-
1098-
Once you are done performing inference on Lambda, free the `LambdaModel`_ and
1099-
`LambdaPredictor`_ resources using the ``delete_model`` and ``delete_predictor``
1100-
methods.
1101-
1102-
.. code:: python
1103-
1104-
model.delete_model()
1105-
predictor.delete_predictor()
1106-
1107-
.. _LambdaModel : https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.serverless.model.LambdaModel
1108-
.. _LambdaPredictor : https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html#sagemaker.serverless.predictor.LambdaPredictor
1109-
11101066
******************
11111067
SageMaker Workflow
11121068
******************

src/sagemaker/clarify.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
dataset_type="text/csv",
3939
s3_data_distribution_type="FullyReplicated",
4040
s3_compression_type="None",
41+
joinsource=None,
4142
):
4243
"""Initializes a configuration of both input and output datasets.
4344
@@ -57,6 +58,11 @@ def __init__(
5758
s3_data_distribution_type (str): Valid options are "FullyReplicated" or
5859
"ShardedByS3Key".
5960
s3_compression_type (str): Valid options are "None" or "Gzip".
61+
joinsource (str): The name or index of the column in the dataset that acts an
62+
identifier column (for instance, while performing a join). This column is only
63+
used as an identifier, and not used for any other computations. This is an
64+
optional field in all cases except when the dataset contains more than one file,
65+
and `save_local_shap_values` is set to true in SHAPConfig.
6066
"""
6167
if dataset_type not in ["text/csv", "application/jsonlines", "application/x-parquet"]:
6268
raise ValueError(
@@ -77,6 +83,7 @@ def __init__(
7783
_set(features, "features", self.analysis_config)
7884
_set(headers, "headers", self.analysis_config)
7985
_set(label, "label", self.analysis_config)
86+
_set(joinsource, "joinsource_name_or_index", self.analysis_config)
8087

8188
def get_config(self):
8289
"""Returns part of an analysis config dictionary."""
@@ -300,6 +307,37 @@ def get_explainability_config(self):
300307
return None
301308

302309

310+
class PDPConfig(ExplainabilityConfig):
311+
"""Config class for Partial Dependence Plots (PDP).
312+
313+
If PDP is requested, the Partial Dependence Plots will be included in the report, and the
314+
corresponding values will be included in the analysis output.
315+
"""
316+
317+
def __init__(self, features=None, grid_resolution=15, top_k_features=10):
318+
"""Initializes config for PDP.
319+
320+
Args:
321+
features (None or list): List of features names or indices for which partial dependence
322+
plots must be computed and plotted. When ShapConfig is provided, this parameter is
323+
optional as Clarify will try to compute the partial dependence plots for top
324+
feature based on SHAP attributions. When ShapConfig is not provided, 'features'
325+
must be provided.
326+
grid_resolution (int): In case of numerical features, this number represents that
327+
number of buckets that range of values must be divided into. This decides the
328+
granularity of the grid in which the PDP are plotted.
329+
top_k_features (int): Set the number of top SHAP attributes to be selected to compute
330+
partial dependence plots.
331+
"""
332+
self.pdp_config = {"grid_resolution": grid_resolution, "top_k_features": top_k_features}
333+
if features is not None:
334+
self.pdp_config["features"] = features
335+
336+
def get_explainability_config(self):
337+
"""Returns config."""
338+
return copy.deepcopy({"pdp": self.pdp_config})
339+
340+
303341
class SHAPConfig(ExplainabilityConfig):
304342
"""Config class of SHAP."""
305343

@@ -792,8 +830,9 @@ def run_explainability(
792830
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
793831
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
794832
endpoint to be created.
795-
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
796-
specific explainability method. Currently, only SHAP is supported.
833+
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
834+
Config of the specific explainability method or a list of ExplainabilityConfig
835+
objects. Currently, SHAP and PDP are the two methods supported.
797836
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
798837
model output for the predicted scores to be explained. This is not required if the
799838
model output is a single score. Alternatively, an instance of
@@ -827,7 +866,30 @@ def run_explainability(
827866
predictor_config.update(predicted_label_config)
828867
else:
829868
_set(model_scores, "label", predictor_config)
830-
analysis_config["methods"] = explainability_config.get_explainability_config()
869+
870+
explainability_methods = {}
871+
if isinstance(explainability_config, list):
872+
if len(explainability_config) == 0:
873+
raise ValueError("Please provide at least one explainability config.")
874+
for config in explainability_config:
875+
explain_config = config.get_explainability_config()
876+
explainability_methods.update(explain_config)
877+
if not len(explainability_methods.keys()) == len(explainability_config):
878+
raise ValueError("Duplicate explainability configs are provided")
879+
if (
880+
"shap" not in explainability_methods
881+
and explainability_methods["pdp"].get("features", None) is None
882+
):
883+
raise ValueError("PDP features must be provided when ShapConfig is not provided")
884+
else:
885+
if (
886+
isinstance(explainability_config, PDPConfig)
887+
and explainability_config.get_explainability_config()["pdp"].get("features", None)
888+
is None
889+
):
890+
raise ValueError("PDP features must be provided when ShapConfig is not provided")
891+
explainability_methods = explainability_config.get_explainability_config()
892+
analysis_config["methods"] = explainability_methods
831893
analysis_config["predictor"] = predictor_config
832894
if job_name is None:
833895
if self.job_name_prefix:

src/sagemaker/deprecations.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,27 @@
2121
V2_URL = "https://sagemaker.readthedocs.io/en/stable/v2.html"
2222

2323

24-
def _warn(msg):
24+
def _warn(msg, sdk_version=None):
2525
"""Generic warning raiser referencing V2
2626
2727
Args:
2828
phrase: The phrase to include in the warning.
29+
sdk_version: the sdk version of removal of support.
2930
"""
30-
full_msg = f"{msg} in sagemaker>=2.\nSee: {V2_URL} for details."
31+
_sdk_version = sdk_version if sdk_version is not None else "2"
32+
full_msg = f"{msg} in sagemaker>={_sdk_version}.\nSee: {V2_URL} for details."
3133
warnings.warn(full_msg, DeprecationWarning, stacklevel=2)
3234
logger.warning(full_msg)
3335

3436

35-
def removed_warning(phrase):
37+
def removed_warning(phrase, sdk_version=None):
3638
"""Raise a warning for a no-op in sagemaker>=2
3739
3840
Args:
3941
phrase: the prefix phrase of the warning message.
42+
sdk_version: the sdk version of removal of support.
4043
"""
41-
_warn(f"{phrase} is a no-op")
44+
_warn(f"{phrase} is a no-op", sdk_version)
4245

4346

4447
def renamed_warning(phrase):
@@ -146,26 +149,32 @@ def func(*args, **kwargs): # pylint: disable=W0613
146149
return func
147150

148151

149-
def deprecated(obj):
152+
def deprecated(sdk_version=None):
150153
"""Decorator for raising deprecated warning for a feature in sagemaker>=2
151154
155+
Args:
156+
sdk_version (str): the sdk version of removal of support.
157+
152158
Usage:
153-
@deprecated
159+
@deprecated()
154160
def sample_function():
155161
print("xxxx....")
156162
157-
@deprecated
163+
@deprecated(sdk_version="2.66")
158164
class SampleClass():
159165
def __init__(self):
160166
print("xxxx....")
161167
162168
"""
163169

164-
def wrapper(*args, **kwargs):
165-
removed_warning(obj.__name__)
166-
return obj(*args, **kwargs)
170+
def deprecate(obj):
171+
def wrapper(*args, **kwargs):
172+
removed_warning(obj.__name__, sdk_version)
173+
return obj(*args, **kwargs)
174+
175+
return wrapper
167176

168-
return wrapper
177+
return deprecate
169178

170179

171180
def deprecated_function(func, name):

0 commit comments

Comments
 (0)