Skip to content

Commit c8b8f2f

Browse files
authored
Merge branch 'master' into feat/instance-specific-jumpstart-host-requirements
2 parents 324d0c7 + c799d1a commit c8b8f2f

File tree

14 files changed

+341
-7
lines changed

14 files changed

+341
-7
lines changed

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1958,6 +1958,7 @@
19581958
"ap-northeast-2": "763104351884",
19591959
"ap-northeast-3": "364406365360",
19601960
"ap-south-1": "763104351884",
1961+
"ap-south-2": "772153158452",
19611962
"ap-southeast-1": "763104351884",
19621963
"ap-southeast-2": "763104351884",
19631964
"ap-southeast-3": "907027046896",
@@ -1966,11 +1967,13 @@
19661967
"cn-north-1": "727897471807",
19671968
"cn-northwest-1": "727897471807",
19681969
"eu-central-1": "763104351884",
1970+
"eu-central-2": "380420809688",
19691971
"eu-north-1": "763104351884",
19701972
"eu-west-1": "763104351884",
19711973
"eu-west-2": "763104351884",
19721974
"eu-west-3": "763104351884",
19731975
"eu-south-1": "692866216735",
1976+
"eu-south-2": "503227376785",
19741977
"me-south-1": "217643126080",
19751978
"sa-east-1": "763104351884",
19761979
"us-east-1": "763104351884",
@@ -1997,6 +2000,7 @@
19972000
"ap-northeast-2": "763104351884",
19982001
"ap-northeast-3": "364406365360",
19992002
"ap-south-1": "763104351884",
2003+
"ap-south-2": "772153158452",
20002004
"ap-southeast-1": "763104351884",
20012005
"ap-southeast-2": "763104351884",
20022006
"ap-southeast-3": "907027046896",
@@ -2005,11 +2009,13 @@
20052009
"cn-north-1": "727897471807",
20062010
"cn-northwest-1": "727897471807",
20072011
"eu-central-1": "763104351884",
2012+
"eu-central-2": "380420809688",
20082013
"eu-north-1": "763104351884",
20092014
"eu-west-1": "763104351884",
20102015
"eu-west-2": "763104351884",
20112016
"eu-west-3": "763104351884",
20122017
"eu-south-1": "692866216735",
2018+
"eu-south-2": "503227376785",
20132019
"me-south-1": "217643126080",
20142020
"sa-east-1": "763104351884",
20152021
"us-east-1": "763104351884",
@@ -2036,6 +2042,7 @@
20362042
"ap-northeast-2": "763104351884",
20372043
"ap-northeast-3": "364406365360",
20382044
"ap-south-1": "763104351884",
2045+
"ap-south-2": "772153158452",
20392046
"ap-southeast-1": "763104351884",
20402047
"ap-southeast-2": "763104351884",
20412048
"ap-southeast-3": "907027046896",
@@ -2044,11 +2051,13 @@
20442051
"cn-north-1": "727897471807",
20452052
"cn-northwest-1": "727897471807",
20462053
"eu-central-1": "763104351884",
2054+
"eu-central-2": "380420809688",
20472055
"eu-north-1": "763104351884",
20482056
"eu-west-1": "763104351884",
20492057
"eu-west-2": "763104351884",
20502058
"eu-west-3": "763104351884",
20512059
"eu-south-1": "692866216735",
2060+
"eu-south-2": "503227376785",
20522061
"me-south-1": "217643126080",
20532062
"sa-east-1": "763104351884",
20542063
"us-east-1": "763104351884",
@@ -2075,6 +2084,7 @@
20752084
"ap-northeast-2": "763104351884",
20762085
"ap-northeast-3": "364406365360",
20772086
"ap-south-1": "763104351884",
2087+
"ap-south-2": "772153158452",
20782088
"ap-southeast-1": "763104351884",
20792089
"ap-southeast-2": "763104351884",
20802090
"ap-southeast-3": "907027046896",
@@ -2083,11 +2093,13 @@
20832093
"cn-north-1": "727897471807",
20842094
"cn-northwest-1": "727897471807",
20852095
"eu-central-1": "763104351884",
2096+
"eu-central-2": "380420809688",
20862097
"eu-north-1": "763104351884",
20872098
"eu-west-1": "763104351884",
20882099
"eu-west-2": "763104351884",
20892100
"eu-west-3": "763104351884",
20902101
"eu-south-1": "692866216735",
2102+
"eu-south-2": "503227376785",
20912103
"me-south-1": "217643126080",
20922104
"sa-east-1": "763104351884",
20932105
"us-east-1": "763104351884",

src/sagemaker/remote_function/core/serialization.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,12 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any:
141141
return cloudpickle.loads(bytes_to_deserialize)
142142
except Exception as e:
143143
raise DeserializationError(
144-
"Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
144+
"Error when deserializing bytes downloaded from {}: {}. "
145+
"NOTE: this may be caused by inconsistent sagemaker python sdk versions "
146+
"where remote function runs versus the one used on client side. "
147+
"If the sagemaker versions do not match, a warning message would "
148+
"be logged starting with 'Inconsistent sagemaker versions found'. "
149+
"Please check it to validate.".format(s3_uri, repr(e))
145150
) from e
146151

147152

src/sagemaker/remote_function/job.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,12 @@ def compile(
786786
container_args.extend(
787787
["--client_python_version", RuntimeEnvironmentManager()._current_python_version()]
788788
)
789+
container_args.extend(
790+
[
791+
"--client_sagemaker_pysdk_version",
792+
RuntimeEnvironmentManager()._current_sagemaker_pysdk_version(),
793+
]
794+
)
789795
container_args.extend(
790796
[
791797
"--dependency_settings",

src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def main(sys_args=None):
5656
try:
5757
args = _parse_args(sys_args)
5858
client_python_version = args.client_python_version
59+
client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version
5960
job_conda_env = args.job_conda_env
6061
pipeline_execution_id = args.pipeline_execution_id
6162
dependency_settings = _DependencySettings.from_string(args.dependency_settings)
@@ -64,6 +65,9 @@ def main(sys_args=None):
6465
conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV")
6566

6667
RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env)
68+
RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
69+
client_sagemaker_pysdk_version
70+
)
6771

6872
user = getpass.getuser()
6973
if user != "root":
@@ -274,6 +278,7 @@ def _parse_args(sys_args):
274278
parser = argparse.ArgumentParser()
275279
parser.add_argument("--job_conda_env", type=str)
276280
parser.add_argument("--client_python_version", type=str)
281+
parser.add_argument("--client_sagemaker_pysdk_version", type=str, default=None)
277282
parser.add_argument("--pipeline_execution_id", type=str)
278283
parser.add_argument("--dependency_settings", type=str)
279284
parser.add_argument("--func_step_s3_dir", type=str)

src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import dataclasses
2525
import json
2626

27+
import sagemaker
28+
2729

2830
class _UTCFormatter(logging.Formatter):
2931
"""Class that overrides the default local time provider in log formatter."""
@@ -326,6 +328,11 @@ def _current_python_version(self):
326328

327329
return f"{sys.version_info.major}.{sys.version_info.minor}".strip()
328330

331+
def _current_sagemaker_pysdk_version(self):
    """Report the sagemaker Python SDK version of the environment running this code."""
    sdk_version = sagemaker.__version__
    return sdk_version
335+
329336
def _validate_python_version(self, client_python_version: str, conda_env: str = None):
330337
"""Validate the python version
331338
@@ -344,6 +351,29 @@ def _validate_python_version(self, client_python_version: str, conda_env: str =
344351
f"is same as the local python version."
345352
)
346353

354+
def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version):
    """Warn when the job's sagemaker Python SDK version differs from the client's.

    Compares the sagemaker Python SDK version of the environment this code
    runs in against the version reported by the client. A mismatch is only
    logged as a warning; nothing is raised.

    Args:
        client_sagemaker_pysdk_version (str): sagemaker Python SDK version used
            on the client side. A falsy value (e.g. ``None``) skips the check.
    """
    job_sagemaker_pysdk_version = self._current_sagemaker_pysdk_version()
    if (
        client_sagemaker_pysdk_version
        and client_sagemaker_pysdk_version != job_sagemaker_pysdk_version
    ):
        # Keep the leading "Inconsistent sagemaker versions found" text unchanged:
        # the DeserializationError message tells users to look for a warning
        # starting with exactly that phrase.
        logger.warning(
            "Inconsistent sagemaker versions found: "
            "sagemaker pysdk version found in the container is "
            "'%s' which does not match the '%s' on the local client. "
            "Please make sure that the sagemaker version used in the training container "
            "is the same as the local sagemaker version in case of unexpected behaviors.",
            job_sagemaker_pysdk_version,
            client_sagemaker_pysdk_version,
        )
376+
347377

348378
def _run_and_get_output_shell_cmd(cmd: str) -> str:
349379
"""Run and return the output of the given shell command"""

src/sagemaker/transformer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@ def transform_with_monitoring(
337337
wait: bool = True,
338338
pipeline_name: str = None,
339339
role: str = None,
340+
fail_on_violation: bool = True,
340341
):
341342
"""Runs a transform job with monitoring job.
342343
@@ -352,7 +353,6 @@ def transform_with_monitoring(
352353
]): the monitoring configuration used for run model monitoring.
353354
monitoring_resource_config (`sagemaker.workflow.check_job_config.CheckJobConfig`):
354355
the check job (processing job) cluster resource configuration.
355-
transform_step_args (_JobStepArguments): the transform step transform arguments.
356356
data (str): Input data location in S3 for the transform job
357357
data_type (str): What the S3 location defines (default: 'S3Prefix').
358358
Valid values:
@@ -400,8 +400,6 @@ def transform_with_monitoring(
400400
monitor_before_transform (bool): If to run data quality
401401
or model explainability monitoring type,
402402
a true value of this flag indicates running the check step before the transform job.
403-
fail_on_violation (Union[bool, PipelineVariable]): A opt-out flag to not to fail the
404-
check step when a violation is detected.
405403
supplied_baseline_statistics (Union[str, PipelineVariable]): The S3 path
406404
to the supplied statistics object representing the statistics JSON file
407405
which will be used for drift to check (default: None).
@@ -411,6 +409,8 @@ def transform_with_monitoring(
411409
wait (bool): To determine if needed to wait for the pipeline execution to complete
412410
pipeline_name (str): The name of the Pipeline for the monitoring and transform step
413411
role (str): Execution role
412+
fail_on_violation (Union[bool, PipelineVariable]): An opt-out flag to not fail the
413+
check step when a violation is detected.
414414
"""
415415

416416
transformer = self
@@ -454,6 +454,7 @@ def transform_with_monitoring(
454454
monitor_before_transform=monitor_before_transform,
455455
supplied_baseline_constraints=supplied_baseline_constraints,
456456
supplied_baseline_statistics=supplied_baseline_statistics,
457+
fail_on_violation=fail_on_violation,
457458
)
458459

459460
pipeline_name = (

tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def test_gated_model_training_v1(setup):
108108
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
109109
environment={"accept_eula": "true"},
110110
max_run=259200, # avoid exceeding resource limits
111+
tolerate_vulnerable_model=True,
111112
)
112113

113114
# uses ml.g5.12xlarge instance

tests/integ/test_transformer.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,3 +709,67 @@ def test_transformer_and_monitoring_job(
709709
assert execution_step["StepStatus"] == "Succeeded"
710710

711711
xgb_model.delete_model()
712+
713+
714+
def test_transformer_and_monitoring_job_to_pass_with_no_failure_in_violation(
    pipeline_session,
    sagemaker_session,
    role,
    pipeline_name,
    check_job_config,
    data_bias_check_config,
):
    """Run ``transform_with_monitoring`` with ``fail_on_violation=False``.

    Supplies baseline constraints loaded from the ``bad_cases`` data directory
    (presumably chosen to trigger a violation -- confirm against the fixture)
    and expects both pipeline steps to report ``Succeeded``.
    """
    xgb_data_dir = os.path.join(DATA_DIR, "xgboost_abalone")

    xgb_model_data_s3 = pipeline_session.upload_data(
        path=os.path.join(xgb_data_dir, "xgb_model.tar.gz"),
        key_prefix="integ-test-data/xgboost/model",
    )
    data_bias_supplied_baseline_constraints = Constraints.from_file_path(
        constraints_file_path=os.path.join(
            DATA_DIR, "pipeline/clarify_check_step/data_bias/bad_cases/analysis.json"
        ),
        sagemaker_session=sagemaker_session,
    ).file_s3_uri

    xgb_model = XGBoostModel(
        model_data=xgb_model_data_s3,
        framework_version="1.3-1",
        role=role,
        sagemaker_session=sagemaker_session,
        entry_point=os.path.join(xgb_data_dir, "inference.py"),
        enable_network_isolation=True,
    )

    # NOTE(review): the endpoint created by deploy() is not deleted by this
    # test -- confirm whether cleanup happens elsewhere.
    xgb_model.deploy(_INSTANCE_COUNT, _INSTANCE_TYPE)

    # Everything after deploy() runs under try/finally so the model is removed
    # even when the pipeline execution or one of the assertions fails.
    try:
        transform_output = f"s3://{sagemaker_session.default_bucket()}/{pipeline_name}Transform"
        transformer = Transformer(
            model_name=xgb_model.name,
            strategy="SingleRecord",
            instance_type="ml.m5.xlarge",
            instance_count=1,
            output_path=transform_output,
            sagemaker_session=pipeline_session,
        )

        transform_input = pipeline_session.upload_data(
            path=os.path.join(xgb_data_dir, "abalone"),
            key_prefix="integ-test-data/xgboost_abalone/abalone",
        )

        execution = transformer.transform_with_monitoring(
            monitoring_config=data_bias_check_config,
            monitoring_resource_config=check_job_config,
            data=transform_input,
            content_type="text/libsvm",
            supplied_baseline_constraints=data_bias_supplied_baseline_constraints,
            role=role,
            fail_on_violation=False,
        )

        execution_steps = execution.list_steps()
        assert len(execution_steps) == 2

        for execution_step in execution_steps:
            assert execution_step["StepStatus"] == "Succeeded"
    finally:
        xgb_model.delete_model()

tests/unit/sagemaker/jumpstart/test_notebook_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from unittest import TestCase
55
from unittest.mock import Mock, patch
6+
import datetime
67

78
import pytest
89
from sagemaker.jumpstart.constants import (
@@ -207,6 +208,10 @@ def test_list_jumpstart_models_simple_case(
207208
patched_get_manifest.assert_called()
208209
patched_get_model_specs.assert_not_called()
209210

211+
@pytest.mark.skipif(
212+
datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
213+
reason="Contact JumpStart team to fix flaky test.",
214+
)
210215
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
211216
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file")
212217
def test_list_jumpstart_models_script_filter(

tests/unit/sagemaker/remote_function/core/test_serialization.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ def square(x):
198198
with pytest.raises(
199199
DeserializationError,
200200
match=rf"Error when deserializing bytes downloaded from {s3_uri}/payload.pkl: "
201-
+ r"RuntimeError\('some failure when loads'\)",
201+
+ r"RuntimeError\('some failure when loads'\). "
202+
+ r"NOTE: this may be caused by inconsistent sagemaker python sdk versions",
202203
):
203204
deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY)
204205

@@ -397,7 +398,8 @@ def __init__(self, x):
397398
with pytest.raises(
398399
DeserializationError,
399400
match=rf"Error when deserializing bytes downloaded from {s3_uri}/payload.pkl: "
400-
+ r"RuntimeError\('some failure when loads'\)",
401+
+ r"RuntimeError\('some failure when loads'\). "
402+
+ r"NOTE: this may be caused by inconsistent sagemaker python sdk versions",
401403
):
402404
deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY)
403405

0 commit comments

Comments
 (0)