Skip to content

Commit 5cc810c

Browse files
authored
Merge branch 'dev' into Lineage_endpoints_modelPackageGroup
2 parents 605e239 + a8323a9 commit 5cc810c

22 files changed

+677
-33
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,5 @@ venv/
2727
*.swp
2828
.docker/
2929
env/
30-
.vscode/
30+
.vscode/
31+
.python-version

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ Pipeline
8282
.. autoclass:: sagemaker.workflow.pipeline._PipelineExecution
8383
:members:
8484

85+
Parallelism Configuration
86+
-------------------------
87+
88+
.. autoclass:: sagemaker.workflow.parallelism_config.ParallelismConfiguration
89+
:members:
90+
8591
Pipeline Experiment Config
8692
--------------------------
8793

src/sagemaker/clarify.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -290,11 +290,15 @@ def __init__(
290290
probability_threshold (float): An optional value for binary prediction tasks in which
291291
the model returns a probability, to indicate the threshold to convert the
292292
prediction to a boolean value. Default is 0.5.
293-
label_headers (list): List of label values - one for each score of the ``probability``.
293+
label_headers (list[str]): List of headers, each for a predicted score in model output.
294+
For bias analysis, it is used to extract the label value with the highest score as
295+
predicted label. For explainability job, It is used to beautify the analysis report
296+
by replacing placeholders like "label0".
294297
"""
295298
self.label = label
296299
self.probability = probability
297300
self.probability_threshold = probability_threshold
301+
self.label_headers = label_headers
298302
if probability_threshold is not None:
299303
try:
300304
float(probability_threshold)
@@ -1060,10 +1064,10 @@ def run_explainability(
10601064
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
10611065
Config of the specific explainability method or a list of ExplainabilityConfig
10621066
objects. Currently, SHAP and PDP are the two methods supported.
1063-
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
1064-
model output for the predicted scores to be explained. This is not required if the
1065-
model output is a single score. Alternatively, an instance of
1066-
ModelPredictedLabelConfig can be provided.
1067+
model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
1068+
Index or JSONPath to locate the predicted scores in the model output. This is not
1069+
required if the model output is a single score. Alternatively, it can be an instance
1070+
of ModelPredictedLabelConfig to provide more parameters like label_headers.
10671071
wait (bool): Whether the call should wait until the job completes (default: True).
10681072
logs (bool): Whether to show the logs produced by the job.
10691073
Only meaningful when ``wait`` is True (default: True).

src/sagemaker/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2343,6 +2343,7 @@ def _stage_user_code_in_s3(self):
23432343
dependencies=self.dependencies,
23442344
kms_key=kms_key,
23452345
s3_resource=self.sagemaker_session.s3_resource,
2346+
settings=self.sagemaker_session.settings,
23462347
)
23472348

23482349
def _model_source_dir(self):

src/sagemaker/fw_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
import shutil
2020
import tempfile
2121
from collections import namedtuple
22+
from typing import Optional
2223

2324
import sagemaker.image_uris
25+
from sagemaker.session_settings import SessionSettings
2426
import sagemaker.utils
2527

2628
from sagemaker.deprecations import renamed_warning
@@ -203,6 +205,7 @@ def tar_and_upload_dir(
203205
dependencies=None,
204206
kms_key=None,
205207
s3_resource=None,
208+
settings: Optional[SessionSettings] = None,
206209
):
207210
"""Package source files and upload a compress tar file to S3.
208211
@@ -230,6 +233,9 @@ def tar_and_upload_dir(
230233
s3_resource (boto3.resource("s3")): Optional. Pre-instantiated Boto3 Resource
231234
for S3 connections, can be used to customize the configuration,
232235
e.g. set the endpoint URL (default: None).
236+
settings (sagemaker.session_settings.SessionSettings): Optional. The settings
237+
of the SageMaker ``Session``, can be used to override the default encryption
238+
behavior (default: None).
233239
Returns:
234240
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and
235241
script name.
@@ -241,6 +247,7 @@ def tar_and_upload_dir(
241247
dependencies = dependencies or []
242248
key = "%s/sourcedir.tar.gz" % s3_key_prefix
243249
tmp = tempfile.mkdtemp()
250+
encrypt_artifact = True if settings is None else settings.encrypt_repacked_artifacts
244251

245252
try:
246253
source_files = _list_files_to_compress(script, directory) + dependencies
@@ -250,6 +257,10 @@ def tar_and_upload_dir(
250257

251258
if kms_key:
252259
extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
260+
elif encrypt_artifact:
261+
# encrypt the tarball at rest in S3 with the default AWS managed KMS key for S3
262+
# see https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html#API_PutObject_RequestSyntax
263+
extra_args = {"ServerSideEncryption": "aws:kms"}
253264
else:
254265
extra_args = None
255266

src/sagemaker/lineage/query.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585
self._session = sagemaker_session
8686

8787
def to_lineage_object(self):
88-
"""Convert the ``Vertex`` object to its corresponding ``Artifact`` or ``Context`` object."""
88+
"""Convert the ``Vertex`` object to its corresponding Artifact, Action, Context object."""
8989
from sagemaker.lineage.artifact import Artifact, ModelArtifact
9090
from sagemaker.lineage.context import Context, EndpointContext
9191
from sagemaker.lineage.artifact import DatasetArtifact
@@ -213,6 +213,44 @@ def _convert_api_response(self, response) -> LineageQueryResult:
213213

214214
return converted
215215

216+
def _collapse_cross_account_artifacts(self, query_response):
217+
"""Collapse the duplicate vertices and edges for cross-account."""
218+
for edge in query_response.edges:
219+
if (
220+
"artifact" in edge.source_arn
221+
and "artifact" in edge.destination_arn
222+
and edge.source_arn.split("/")[1] == edge.destination_arn.split("/")[1]
223+
and edge.source_arn != edge.destination_arn
224+
):
225+
edge_source_arn = edge.source_arn
226+
edge_destination_arn = edge.destination_arn
227+
self._update_cross_account_edge(
228+
edges=query_response.edges,
229+
arn=edge_source_arn,
230+
duplicate_arn=edge_destination_arn,
231+
)
232+
self._update_cross_account_vertex(
233+
query_response=query_response, duplicate_arn=edge_destination_arn
234+
)
235+
236+
# remove the duplicate edges from cross account
237+
new_edge = [e for e in query_response.edges if not e.source_arn == e.destination_arn]
238+
query_response.edges = new_edge
239+
240+
return query_response
241+
242+
def _update_cross_account_edge(self, edges, arn, duplicate_arn):
243+
"""Replace the duplicate arn with arn in edges list."""
244+
for idx, e in enumerate(edges):
245+
if e.destination_arn == duplicate_arn:
246+
edges[idx].destination_arn = arn
247+
elif e.source_arn == duplicate_arn:
248+
edges[idx].source_arn = arn
249+
250+
def _update_cross_account_vertex(self, query_response, duplicate_arn):
251+
"""Remove the vertex with duplicate arn in the vertices list."""
252+
query_response.vertices = [v for v in query_response.vertices if not v.arn == duplicate_arn]
253+
216254
def query(
217255
self,
218256
start_arns: List[str],
@@ -240,5 +278,7 @@ def query(
240278
Filters=query_filter._to_request_dict() if query_filter else {},
241279
MaxDepth=max_depth,
242280
)
281+
query_response = self._convert_api_response(query_response)
282+
query_response = self._collapse_cross_account_artifacts(query_response)
243283

244-
return self._convert_api_response(query_response)
284+
return query_response

src/sagemaker/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,6 +1131,7 @@ def _upload_code(self, key_prefix, repack=False):
11311131
script=self.entry_point,
11321132
directory=self.source_dir,
11331133
dependencies=self.dependencies,
1134+
settings=self.sagemaker_session.settings,
11341135
)
11351136

11361137
if repack and self.model_data is not None and self.entry_point is not None:

src/sagemaker/model_monitor/clarify_model_monitoring.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sagemaker import image_uris, s3
2727
from sagemaker.session import Session
2828
from sagemaker.utils import name_from_base
29-
from sagemaker.clarify import SageMakerClarifyProcessor
29+
from sagemaker.clarify import SageMakerClarifyProcessor, ModelPredictedLabelConfig
3030

3131
_LOGGER = logging.getLogger(__name__)
3232

@@ -833,9 +833,10 @@ def suggest_baseline(
833833
specific explainability method. Currently, only SHAP is supported.
834834
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
835835
endpoint to be created.
836-
model_scores (int or str): Index or JSONPath location in the model output for the
837-
predicted scores to be explained. This is not required if the model output is
838-
a single score.
836+
model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
837+
Index or JSONPath to locate the predicted scores in the model output. This is not
838+
required if the model output is a single score. Alternatively, it can be an instance
839+
of ModelPredictedLabelConfig to provide more parameters like label_headers.
839840
wait (bool): Whether the call should wait until the job completes (default: False).
840841
logs (bool): Whether to show the logs produced by the job.
841842
Only meaningful when wait is True (default: False).
@@ -865,14 +866,24 @@ def suggest_baseline(
865866
headers = copy.deepcopy(data_config.headers)
866867
if headers and data_config.label in headers:
867868
headers.remove(data_config.label)
869+
if model_scores is None:
870+
inference_attribute = None
871+
label_headers = None
872+
elif isinstance(model_scores, ModelPredictedLabelConfig):
873+
inference_attribute = str(model_scores.label)
874+
label_headers = model_scores.label_headers
875+
else:
876+
inference_attribute = str(model_scores)
877+
label_headers = None
868878
self.latest_baselining_job_config = ClarifyBaseliningConfig(
869879
analysis_config=ExplainabilityAnalysisConfig(
870880
explainability_config=explainability_config,
871881
model_config=model_config,
872882
headers=headers,
883+
label_headers=label_headers,
873884
),
874885
features_attribute=data_config.features,
875-
inference_attribute=model_scores if model_scores is None else str(model_scores),
886+
inference_attribute=inference_attribute,
876887
)
877888
self.latest_baselining_job_name = baselining_job_name
878889
self.latest_baselining_job = ClarifyBaseliningJob(
@@ -1166,7 +1177,7 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
11661177
class ExplainabilityAnalysisConfig:
11671178
"""Analysis configuration for ModelExplainabilityMonitor."""
11681179

1169-
def __init__(self, explainability_config, model_config, headers=None):
1180+
def __init__(self, explainability_config, model_config, headers=None, label_headers=None):
11701181
"""Creates an analysis config dictionary.
11711182
11721183
Args:
@@ -1175,13 +1186,19 @@ def __init__(self, explainability_config, model_config, headers=None):
11751186
model_config (sagemaker.clarify.ModelConfig): Config object related to bias
11761187
configurations.
11771188
headers (list[str]): A list of feature names (without label) of model/endpint input.
1189+
label_headers (list[str]): List of headers, each for a predicted score in model output.
1190+
It is used to beautify the analysis report by replacing placeholders like "label0".
1191+
11781192
"""
1193+
predictor_config = model_config.get_predictor_config()
11791194
self.analysis_config = {
11801195
"methods": explainability_config.get_explainability_config(),
1181-
"predictor": model_config.get_predictor_config(),
1196+
"predictor": predictor_config,
11821197
}
11831198
if headers is not None:
11841199
self.analysis_config["headers"] = headers
1200+
if label_headers is not None:
1201+
predictor_config["label_headers"] = label_headers
11851202

11861203
def _to_dict(self):
11871204
"""Generates a request dictionary using the parameters provided to the class."""

src/sagemaker/session.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
sts_regional_endpoint,
4343
)
4444
from sagemaker import exceptions
45+
from sagemaker.session_settings import SessionSettings
4546

4647
LOGGER = logging.getLogger("sagemaker")
4748

@@ -85,6 +86,7 @@ def __init__(
8586
sagemaker_runtime_client=None,
8687
sagemaker_featurestore_runtime_client=None,
8788
default_bucket=None,
89+
settings=SessionSettings(),
8890
):
8991
"""Initialize a SageMaker ``Session``.
9092
@@ -110,13 +112,16 @@ def __init__(
110112
If not provided, a default bucket will be created based on the following format:
111113
"sagemaker-{region}-{aws-account-id}".
112114
Example: "sagemaker-my-custom-bucket".
115+
settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional
116+
parameters to apply to the session.
113117
"""
114118
self._default_bucket = None
115119
self._default_bucket_name_override = default_bucket
116120
self.s3_resource = None
117121
self.s3_client = None
118122
self.config = None
119123
self.lambda_client = None
124+
self.settings = settings
120125

121126
self._initialize(
122127
boto_session=boto_session,

src/sagemaker/session_settings.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Defines classes to parametrize a SageMaker ``Session``."""
14+
15+
from __future__ import absolute_import
16+
17+
18+
class SessionSettings(object):
19+
"""Optional container class for settings to apply to a SageMaker session."""
20+
21+
def __init__(self, encrypt_repacked_artifacts=True) -> None:
22+
"""Initialize the ``SessionSettings`` of a SageMaker ``Session``.
23+
24+
Args:
25+
encrypt_repacked_artifacts (bool): Flag to indicate whether to encrypt the artifacts
26+
at rest in S3 using the default AWS managed KMS key for S3 when a custom KMS key
27+
is not provided (Default: True).
28+
"""
29+
self._encrypt_repacked_artifacts = encrypt_repacked_artifacts
30+
31+
@property
32+
def encrypt_repacked_artifacts(self) -> bool:
33+
"""Return True if repacked artifacts at rest in S3 should be encrypted by default."""
34+
return self._encrypt_repacked_artifacts

src/sagemaker/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from six.moves.urllib import parse
3030

3131
from sagemaker import deprecations
32+
from sagemaker.session_settings import SessionSettings
3233

3334

3435
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
@@ -429,8 +430,15 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
429430
bucket, key = url.netloc, url.path.lstrip("/")
430431
new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))
431432

433+
settings = (
434+
sagemaker_session.settings if sagemaker_session is not None else SessionSettings()
435+
)
436+
encrypt_artifact = settings.encrypt_repacked_artifacts
437+
432438
if kms_key:
433439
extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
440+
elif encrypt_artifact:
441+
extra_args = {"ServerSideEncryption": "aws:kms"}
434442
else:
435443
extra_args = None
436444
sagemaker_session.boto_session.resource(
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Pipeline Parallelism Configuration"""
14+
from __future__ import absolute_import
15+
from sagemaker.workflow.entities import RequestType
16+
17+
18+
class ParallelismConfiguration:
19+
"""Parallelism config for SageMaker pipeline."""
20+
21+
def __init__(self, max_parallel_execution_steps: int):
22+
"""Create a ParallelismConfiguration
23+
24+
Args:
25+
max_parallel_execution_steps, int:
26+
max number of steps which could be parallelized
27+
"""
28+
self.max_parallel_execution_steps = max_parallel_execution_steps
29+
30+
def to_request(self) -> RequestType:
31+
"""Returns: the request structure."""
32+
return {
33+
"MaxParallelExecutionSteps": self.max_parallel_execution_steps,
34+
}

0 commit comments

Comments
 (0)