Skip to content

Commit 6b6ccd1

Browse files
authored
Merge branch 'master' into feat/jumpstart-instance-type-variants
2 parents e1d8d30 + d6f58b3 commit 6b6ccd1

File tree

6 files changed

+71
-2
lines changed

6 files changed

+71
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def __init__(
173173
instance_groups: Optional[List[InstanceGroup]] = None,
174174
training_repository_access_mode: Optional[Union[str, PipelineVariable]] = None,
175175
training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None,
176+
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
176177
container_entry_point: Optional[List[str]] = None,
177178
container_arguments: Optional[List[str]] = None,
178179
disable_output_compression: bool = False,
@@ -536,6 +537,8 @@ def __init__(
536537
a training job.
537538
disable_output_compression (bool): Optional. When set to true, Model is uploaded
538539
to Amazon S3 without compression after training finishes.
540+
enable_infra_check (bool or PipelineVariable): Optional.
541+
Specifies whether to run SageMaker built-in infra check jobs.
539542
"""
540543
instance_count = renamed_kwargs(
541544
"train_instance_count", "instance_count", instance_count, kwargs
@@ -665,6 +668,7 @@ def __init__(
665668
training_repository_credentials_provider_arn
666669
)
667670

671+
self.enable_infra_check = enable_infra_check
668672
# container entry point / arguments configs
669673
self.container_entry_point = container_entry_point
670674
self.container_arguments = container_arguments
@@ -1904,6 +1908,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
19041908
"EnableInterContainerTrafficEncryption"
19051909
]
19061910

1911+
if "InfraCheckConfig" in job_details:
1912+
init_params["enable_infra_check"] = job_details["InfraCheckConfig"].get(
1913+
"EnableInfraCheck"
1914+
)
1915+
19071916
subnets, security_group_ids = vpc_utils.from_dict(job_details.get(vpc_utils.VPC_CONFIG_KEY))
19081917
if subnets:
19091918
init_params["subnets"] = subnets
@@ -2446,6 +2455,10 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
24462455
] = estimator.training_repository_credentials_provider_arn
24472456
train_args["training_image_config"] = training_image_config
24482457

2458+
if estimator.enable_infra_check is not None:
2459+
infra_check_config = {"EnableInfraCheck": estimator.enable_infra_check}
2460+
train_args["infra_check_config"] = infra_check_config
2461+
24492462
if estimator.container_entry_point is not None:
24502463
train_args["container_entry_point"] = estimator.container_entry_point
24512464

@@ -2661,6 +2674,7 @@ def __init__(
26612674
container_entry_point: Optional[List[str]] = None,
26622675
container_arguments: Optional[List[str]] = None,
26632676
disable_output_compression: bool = False,
2677+
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
26642678
**kwargs,
26652679
):
26662680
"""Initialize an ``Estimator`` instance.
@@ -3020,6 +3034,8 @@ def __init__(
30203034
a training job.
30213035
disable_output_compression (bool): Optional. When set to true, Model is uploaded
30223036
to Amazon S3 without compression after training finishes.
3037+
enable_infra_check (bool or PipelineVariable): Optional.
3038+
Specifies whether to run SageMaker built-in infra check jobs.
30233039
"""
30243040
self.image_uri = image_uri
30253041
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}

src/sagemaker/fw_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,11 @@
151151
"1.12.1",
152152
"1.13.1",
153153
"2.0.0",
154+
"2.0.1",
154155
]
155156

156157

157-
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1", "2.0.0"]
158+
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1", "2.0.0", "2.0.1"]
158159

159160
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
160161
TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [

src/sagemaker/session.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,7 @@ def train( # noqa: C901
674674
enable_network_isolation=None,
675675
image_uri=None,
676676
training_image_config=None,
677+
infra_check_config=None,
677678
container_entry_point=None,
678679
container_arguments=None,
679680
algorithm_arn=None,
@@ -803,6 +804,15 @@ def train( # noqa: C901
803804
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
804805
* max_retry_attempts (int): Number of times a job should be retried.
805806
The key in RetryStrategy is 'MaxRetryAttempts'.
807+
infra_check_config(dict): Infra check configuration.
808+
Optionally, the dict can contain 'EnableInfraCheck'(bool).
809+
For example,
810+
811+
.. code:: python
812+
813+
infra_check_config = {
814+
"EnableInfraCheck": True,
815+
}
806816
Returns:
807817
str: ARN of the training job, if it is created.
808818
"""
@@ -866,6 +876,7 @@ def train( # noqa: C901
866876
enable_network_isolation=enable_network_isolation,
867877
image_uri=image_uri,
868878
training_image_config=training_image_config,
879+
infra_check_config=infra_check_config,
869880
container_entry_point=container_entry_point,
870881
container_arguments=container_arguments,
871882
algorithm_arn=algorithm_arn,
@@ -907,6 +918,7 @@ def _get_train_request( # noqa: C901
907918
enable_network_isolation=False,
908919
image_uri=None,
909920
training_image_config=None,
921+
infra_check_config=None,
910922
container_entry_point=None,
911923
container_arguments=None,
912924
algorithm_arn=None,
@@ -1063,6 +1075,9 @@ def _get_train_request( # noqa: C901
10631075
if training_image_config is not None:
10641076
train_request["AlgorithmSpecification"]["TrainingImageConfig"] = training_image_config
10651077

1078+
if infra_check_config is not None:
1079+
train_request["InfraCheckConfig"] = infra_check_config
1080+
10661081
if container_entry_point is not None:
10671082
train_request["AlgorithmSpecification"]["ContainerEntrypoint"] = container_entry_point
10681083

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
677677
Please add the new argument to the skip set below,
678678
and reach out to JumpStart team."""
679679

680-
init_args_to_skip: Set[str] = set(["kwargs"])
680+
init_args_to_skip: Set[str] = set(["kwargs", "enable_infra_check"])
681681
fit_args_to_skip: Set[str] = set()
682682
deploy_args_to_skip: Set[str] = set(["kwargs"])
683683

tests/unit/test_estimator.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
REPO_DIR = "/tmp/repo_dir"
104104
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}
105105
TRAINING_REPOSITORY_ACCESS_MODE = "VPC"
106+
ENABLE_INFRA_CHECK = True
106107
TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN = "arn:aws:lambda:us-west-2:1234567890:function:test"
107108
CONTAINER_ENTRY_POINT = ["entry_point1", "entry_point2"]
108109
CONTAINER_ARGUMENTS = ["container_arg1", "container_arg2"]
@@ -769,6 +770,39 @@ def test_framework_without_training_repository_config(sagemaker_session):
769770
assert args.get("training_image_config") is None
770771

771772

773+
def test_framework_without_infra_check_config(sagemaker_session):
    """Check that no infra check config reaches ``train`` when it is not requested.

    The framework is built without ``enable_infra_check``, so the keyword
    arguments passed to ``sagemaker_session.train`` must not carry an
    ``infra_check_config`` entry.

    Args:
        sagemaker_session: Session fixture whose ``train`` call is inspected.
    """
    f = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_groups=[
            InstanceGroup("group1", "ml.c4.xlarge", 1),
            InstanceGroup("group2", "ml.m4.xlarge", 2),
        ],
    )
    f.fit("s3://mydata")
    sagemaker_session.train.assert_called_once()
    _, args = sagemaker_session.train.call_args
    # The estimator forwards this setting under the "infra_check_config" key;
    # asserting on any other key would pass even if the config were sent.
    assert args.get("infra_check_config") is None
787+
788+
789+
def test_framework_with_infra_check_config(sagemaker_session):
    """Check that ``enable_infra_check`` is forwarded to ``train``.

    The framework is built with ``enable_infra_check=ENABLE_INFRA_CHECK``; the
    keyword arguments passed to ``sagemaker_session.train`` must then contain
    an ``infra_check_config`` dict whose ``EnableInfraCheck`` entry equals it.

    Args:
        sagemaker_session: Session fixture whose ``train`` call is inspected.
    """
    instance_groups = [
        InstanceGroup("group1", "ml.c4.xlarge", 1),
        InstanceGroup("group2", "ml.m4.xlarge", 2),
    ]
    framework = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_groups=instance_groups,
        enable_infra_check=ENABLE_INFRA_CHECK,
    )
    framework.fit("s3://mydata")

    sagemaker_session.train.assert_called_once()
    # call_args is a (positional, keyword) pair; only the keywords matter here.
    train_kwargs = sagemaker_session.train.call_args[1]
    infra_check_config = train_kwargs["infra_check_config"]
    assert infra_check_config["EnableInfraCheck"] == ENABLE_INFRA_CHECK
804+
805+
772806
def test_framework_with_container_entry_point(sagemaker_session):
773807
f = DummyFramework(
774808
entry_point=SCRIPT_PATH,

tests/unit/test_session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1899,6 +1899,7 @@ def test_train_with_sagemaker_config_injection(sagemaker_session):
18991899
"TrainingRepositoryCredentialsProviderArn": "arn:aws:lambda:us-west-2:1234567897:function:test"
19001900
},
19011901
}
1902+
INFRA_CHECK_CONFIG = {"EnableInfraCheck": True}
19021903
CONTAINER_ENTRY_POINT = ["bin/bash", "test.sh"]
19031904
CONTAINER_ARGUMENTS = ["--arg1", "value1", "--arg2", "value2"]
19041905

@@ -1920,6 +1921,7 @@ def test_train_with_sagemaker_config_injection(sagemaker_session):
19201921
training_image_config=TRAINING_IMAGE_CONFIG,
19211922
container_entry_point=CONTAINER_ENTRY_POINT,
19221923
container_arguments=CONTAINER_ARGUMENTS,
1924+
infra_check_config=INFRA_CHECK_CONFIG,
19231925
)
19241926

19251927
_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
@@ -1966,6 +1968,7 @@ def test_train_with_sagemaker_config_injection(sagemaker_session):
19661968
actual_train_args["AlgorithmSpecification"]["ContainerEntrypoint"] == CONTAINER_ENTRY_POINT
19671969
)
19681970
assert actual_train_args["AlgorithmSpecification"]["ContainerArguments"] == CONTAINER_ARGUMENTS
1971+
assert actual_train_args["InfraCheckConfig"] == INFRA_CHECK_CONFIG
19691972
assert actual_train_args["RoleArn"] == expected_role_arn
19701973
assert actual_train_args["ResourceConfig"] == {
19711974
"InstanceCount": INSTANCE_COUNT,

0 commit comments

Comments
 (0)