Skip to content

Commit 464e19d

Browse files
brunopistone and mollyheamazon
authored and committed
Aligned disable_output_compression for @Remote with Estimator (#5094)
1 parent b50c293 commit 464e19d

File tree

5 files changed

+69
-6
lines changed

5 files changed

+69
-6
lines changed

src/sagemaker/remote_function/client.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def remote(
9090
spark_config: SparkConfig = None,
9191
use_spot_instances=False,
9292
max_wait_time_in_seconds=None,
93+
disable_output_compression: bool = False,
9394
use_torchrun: bool = False,
9495
use_mpirun: bool = False,
9596
nproc_per_node: Optional[int] = None,
@@ -283,13 +284,16 @@ def remote(
283284
After this amount of time Amazon SageMaker will stop waiting for managed spot training
284285
job to complete. Defaults to ``None``.
285286
287+
disable_output_compression (bool): Optional. When set to ``True``, the job's output
288+
artifacts are uploaded to Amazon S3 without compression. Defaults to ``False``.
289+
286290
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
287291
Defaults to ``False``.
288292
289293
use_mpirun (bool): Specifies whether to use mpirun for distributed training.
290294
Defaults to ``False``.
291295
292-
nproc_per_node (Optional int): Specifies the number of processes per node for
296+
nproc_per_node (int): Optional. Specifies the number of processes per node for
293297
distributed training. Defaults to ``None``.
294298
This is automatically configured based on the instance type.
295299
"""
@@ -324,6 +328,7 @@ def _remote(func):
324328
spark_config=spark_config,
325329
use_spot_instances=use_spot_instances,
326330
max_wait_time_in_seconds=max_wait_time_in_seconds,
331+
disable_output_compression=disable_output_compression,
327332
use_torchrun=use_torchrun,
328333
use_mpirun=use_mpirun,
329334
nproc_per_node=nproc_per_node,
@@ -543,6 +548,7 @@ def __init__(
543548
spark_config: SparkConfig = None,
544549
use_spot_instances=False,
545550
max_wait_time_in_seconds=None,
551+
disable_output_compression: bool = False,
546552
use_torchrun: bool = False,
547553
use_mpirun: bool = False,
548554
nproc_per_node: Optional[int] = None,
@@ -736,13 +742,16 @@ def __init__(
736742
After this amount of time Amazon SageMaker will stop waiting for managed spot training
737743
job to complete. Defaults to ``None``.
738744
745+
disable_output_compression (bool): Optional. When set to ``True``, the job's output
746+
artifacts are uploaded to Amazon S3 without compression. Defaults to ``False``.
747+
739748
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
740749
Defaults to ``False``.
741750
742751
use_mpirun (bool): Specifies whether to use mpirun for distributed training.
743752
Defaults to ``False``.
744753
745-
nproc_per_node (Optional int): Specifies the number of processes per node for
754+
nproc_per_node (int): Optional. Specifies the number of processes per node for
746755
distributed training. Defaults to ``None``.
747756
This is automatically configured based on the instance type.
748757
"""
@@ -790,6 +799,7 @@ def __init__(
790799
spark_config=spark_config,
791800
use_spot_instances=use_spot_instances,
792801
max_wait_time_in_seconds=max_wait_time_in_seconds,
802+
disable_output_compression=disable_output_compression,
793803
use_torchrun=use_torchrun,
794804
use_mpirun=use_mpirun,
795805
nproc_per_node=nproc_per_node,

src/sagemaker/remote_function/job.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ def __init__(
373373
spark_config: SparkConfig = None,
374374
use_spot_instances=False,
375375
max_wait_time_in_seconds=None,
376+
disable_output_compression: bool = False,
376377
use_torchrun: bool = False,
377378
use_mpirun: bool = False,
378379
nproc_per_node: Optional[int] = None,
@@ -558,13 +559,16 @@ def __init__(
558559
After this amount of time Amazon SageMaker will stop waiting for managed spot
559560
training job to complete. Defaults to ``None``.
560561
562+
disable_output_compression (bool): Optional. When set to ``True``, the job's output
563+
artifacts are uploaded to Amazon S3 without compression. Defaults to ``False``.
564+
561565
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
562566
Defaults to ``False``.
563567
564568
use_mpirun (bool): Specifies whether to use mpirun for distributed training.
565569
Defaults to ``False``.
566570
567-
nproc_per_node (Optional int): Specifies the number of processes per node for
571+
nproc_per_node (int): Optional. Specifies the number of processes per node for
568572
distributed training. Defaults to ``None``.
569573
This is automatically configured based on the instance type.
570574
"""
@@ -725,6 +729,7 @@ def __init__(
725729
tags = format_tags(tags)
726730
self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)
727731

732+
self.disable_output_compression = disable_output_compression
728733
self.use_torchrun = use_torchrun
729734
self.use_mpirun = use_mpirun
730735
self.nproc_per_node = nproc_per_node
@@ -954,6 +959,8 @@ def compile(
954959
output_config = {"S3OutputPath": s3_base_uri}
955960
if job_settings.s3_kms_key is not None:
956961
output_config["KmsKeyId"] = job_settings.s3_kms_key
962+
if job_settings.disable_output_compression:
963+
output_config["CompressionType"] = "NONE"
957964
request_dict["OutputDataConfig"] = output_config
958965

959966
container_args = ["--s3_base_uri", s3_base_uri]

tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,7 @@ def test_remote_decorator_fields_consistency(get_execution_role, session):
907907
"use_spot_instances",
908908
"max_wait_time_in_seconds",
909909
"custom_file_filter",
910+
"disable_output_compression",
910911
"use_torchrun",
911912
"use_mpirun",
912913
"nproc_per_node",

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,6 +1504,7 @@ def test_consistency_between_remote_and_step_decorator():
15041504
"s3_kms_key",
15051505
"s3_root_uri",
15061506
"sagemaker_session",
1507+
"disable_output_compression",
15071508
"use_torchrun",
15081509
"use_mpirun",
15091510
"nproc_per_node",

tests/unit/sagemaker/remote_function/test_job.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,24 +291,47 @@ def mock_get_current_run():
291291
return current_run
292292

293293

def describe_training_job_response(job_status, disable_output_compression=False):
    """Build a canned ``DescribeTrainingJob`` response for the unit tests.

    Args:
        job_status (str): Value stored under ``TrainingJobStatus``.
        disable_output_compression (bool): When true, ``OutputDataConfig``
            carries ``"CompressionType": "NONE"``; otherwise the key is
            omitted. Defaults to ``False``.

    Returns:
        dict: The fake response payload.
    """
    output_config = {"S3OutputPath": "s3://sagemaker-123/image_uri/output"}
    if disable_output_compression:
        # Only report uncompressed output when it was asked for, so the two
        # flavours of fixture are actually distinguishable from each other.
        output_config["CompressionType"] = "NONE"

    return {
        "TrainingJobArn": TRAINING_JOB_ARN,
        "TrainingJobStatus": job_status,
        "ResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "ml.c4.xlarge",
            "VolumeSizeInGB": 30,
        },
        "OutputDataConfig": output_config,
    }
319+
306320

307321
# Canned DescribeTrainingJob responses, one per job status, built with the
# helper's default output settings.
COMPLETED_TRAINING_JOB = describe_training_job_response(job_status="Completed")
INPROGRESS_TRAINING_JOB = describe_training_job_response(job_status="InProgress")
CANCELLED_TRAINING_JOB = describe_training_job_response(job_status="Stopped")
FAILED_TRAINING_JOB = describe_training_job_response(job_status="Failed")

# The same four statuses, built with disable_output_compression switched on.
COMPLETED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response(
    job_status="Completed", disable_output_compression=True
)
INPROGRESS_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response(
    job_status="InProgress", disable_output_compression=True
)
CANCELLED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response(
    job_status="Stopped", disable_output_compression=True
)
FAILED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response(
    job_status="Failed", disable_output_compression=True
)
334+
312335

313336
def mock_session():
314337
session = Mock()
@@ -1303,6 +1326,27 @@ def test_describe(session, *args):
13031326
session().sagemaker_client.describe_training_job.assert_called_once()
13041327

13051328

1329+
@patch("sagemaker.remote_function.job._prepare_and_upload_runtime_scripts")
@patch("sagemaker.remote_function.job._prepare_and_upload_workspace")
@patch("sagemaker.remote_function.job.StoredFunction")
@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
def test_describe_disable_output_compression(session, *args):
    """Check a job started with ``disable_output_compression=True``.

    Verifies that the flag reaches the training job request and that
    ``describe()`` hands back the mocked uncompressed-output response while
    calling the SageMaker client only once across two ``describe()`` calls.
    """
    sagemaker_client = session().sagemaker_client
    # Pin the mocked response explicitly so the assertion below does not rely
    # on whatever default payload the shared session mock happens to return.
    sagemaker_client.describe_training_job.return_value = (
        COMPLETED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION
    )

    job_settings = _JobSettings(
        image_uri=IMAGE,
        s3_root_uri=S3_URI,
        role=ROLE_ARN,
        instance_type="ml.m5.large",
        disable_output_compression=True,
    )
    job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4})

    # The setting must show up in the request sent to create the job.
    # NOTE(review): assumes the request is passed as keyword arguments — confirm.
    _, request_kwargs = sagemaker_client.create_training_job.call_args
    assert request_kwargs["OutputDataConfig"]["CompressionType"] == "NONE"

    # describe() is called twice on purpose; the client must be hit only once.
    job.describe()
    assert job.describe() == COMPLETED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION

    sagemaker_client.describe_training_job.assert_called_once()
1348+
1349+
13061350
@patch("sagemaker.remote_function.job._prepare_and_upload_runtime_scripts")
13071351
@patch("sagemaker.remote_function.job._prepare_and_upload_workspace")
13081352
@patch("sagemaker.remote_function.job.StoredFunction")

0 commit comments

Comments
 (0)