Skip to content

Aligned disable_output_compression for @remote with Estimator #5094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/sagemaker/remote_function/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def remote(
spark_config: SparkConfig = None,
use_spot_instances=False,
max_wait_time_in_seconds=None,
disable_output_compression: bool = False,
use_torchrun: bool = False,
use_mpirun: bool = False,
nproc_per_node: Optional[int] = None,
Expand Down Expand Up @@ -283,13 +284,16 @@ def remote(
After this amount of time Amazon SageMaker will stop waiting for managed spot training
job to complete. Defaults to ``None``.

disable_output_compression (bool): Optional. When set to true, the job output is uploaded to
Amazon S3 without compression after training finishes. Defaults to ``False``.

use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.

use_mpirun (bool): Specifies whether to use mpirun for distributed training.
Defaults to ``False``.

nproc_per_node (Optional int): Specifies the number of processes per node for
nproc_per_node (int): Optional. Specifies the number of processes per node for
distributed training. Defaults to ``None``.
This is automatically configured based on the instance type.
"""
Expand Down Expand Up @@ -324,6 +328,7 @@ def _remote(func):
spark_config=spark_config,
use_spot_instances=use_spot_instances,
max_wait_time_in_seconds=max_wait_time_in_seconds,
disable_output_compression=disable_output_compression,
use_torchrun=use_torchrun,
use_mpirun=use_mpirun,
nproc_per_node=nproc_per_node,
Expand Down Expand Up @@ -543,6 +548,7 @@ def __init__(
spark_config: SparkConfig = None,
use_spot_instances=False,
max_wait_time_in_seconds=None,
disable_output_compression: bool = False,
use_torchrun: bool = False,
use_mpirun: bool = False,
nproc_per_node: Optional[int] = None,
Expand Down Expand Up @@ -736,13 +742,16 @@ def __init__(
After this amount of time Amazon SageMaker will stop waiting for managed spot training
job to complete. Defaults to ``None``.

disable_output_compression (bool): Optional. When set to true, the job output is uploaded to
Amazon S3 without compression after training finishes. Defaults to ``False``.

use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.

use_mpirun (bool): Specifies whether to use mpirun for distributed training.
Defaults to ``False``.

nproc_per_node (Optional int): Specifies the number of processes per node for
nproc_per_node (int): Optional. Specifies the number of processes per node for
distributed training. Defaults to ``None``.
This is automatically configured based on the instance type.
"""
Expand Down Expand Up @@ -790,6 +799,7 @@ def __init__(
spark_config=spark_config,
use_spot_instances=use_spot_instances,
max_wait_time_in_seconds=max_wait_time_in_seconds,
disable_output_compression=disable_output_compression,
use_torchrun=use_torchrun,
use_mpirun=use_mpirun,
nproc_per_node=nproc_per_node,
Expand Down
9 changes: 8 additions & 1 deletion src/sagemaker/remote_function/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def __init__(
spark_config: SparkConfig = None,
use_spot_instances=False,
max_wait_time_in_seconds=None,
disable_output_compression: bool = False,
use_torchrun: bool = False,
use_mpirun: bool = False,
nproc_per_node: Optional[int] = None,
Expand Down Expand Up @@ -558,13 +559,16 @@ def __init__(
After this amount of time Amazon SageMaker will stop waiting for managed spot
training job to complete. Defaults to ``None``.

disable_output_compression (bool): Optional. When set to true, the job output is uploaded to
Amazon S3 without compression after training finishes. Defaults to ``False``.

use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.

use_mpirun (bool): Specifies whether to use mpirun for distributed training.
Defaults to ``False``.

nproc_per_node (Optional int): Specifies the number of processes per node for
nproc_per_node (int): Optional. Specifies the number of processes per node for
distributed training. Defaults to ``None``.
This is automatically configured based on the instance type.
"""
Expand Down Expand Up @@ -725,6 +729,7 @@ def __init__(
tags = format_tags(tags)
self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)

self.disable_output_compression = disable_output_compression
self.use_torchrun = use_torchrun
self.use_mpirun = use_mpirun
self.nproc_per_node = nproc_per_node
Expand Down Expand Up @@ -954,6 +959,8 @@ def compile(
output_config = {"S3OutputPath": s3_base_uri}
if job_settings.s3_kms_key is not None:
output_config["KmsKeyId"] = job_settings.s3_kms_key
if job_settings.disable_output_compression:
output_config["CompressionType"] = "NONE"
request_dict["OutputDataConfig"] = output_config

container_args = ["--s3_base_uri", s3_base_uri]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,7 @@ def test_remote_decorator_fields_consistency(get_execution_role, session):
"use_spot_instances",
"max_wait_time_in_seconds",
"custom_file_filter",
"disable_output_compression",
"use_torchrun",
"use_mpirun",
"nproc_per_node",
Expand Down
1 change: 1 addition & 0 deletions tests/unit/sagemaker/remote_function/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,7 @@ def test_consistency_between_remote_and_step_decorator():
"s3_kms_key",
"s3_root_uri",
"sagemaker_session",
"disable_output_compression",
"use_torchrun",
"use_mpirun",
"nproc_per_node",
Expand Down
50 changes: 47 additions & 3 deletions tests/unit/sagemaker/remote_function/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,24 +291,47 @@ def mock_get_current_run():
return current_run


def describe_training_job_response(job_status, disable_output_compression=False):
    """Build a canned ``describe_training_job`` response for the unit tests.

    Args:
        job_status (str): Value stored under ``TrainingJobStatus``.
        disable_output_compression (bool): When true, ``OutputDataConfig``
            carries ``"CompressionType": "NONE"``. When false the key is
            omitted, so the payload only holds ``S3OutputPath``.
            Defaults to ``False``.

    Returns:
        dict: The fake response payload.
    """
    output_config = {"S3OutputPath": "s3://sagemaker-123/image_uri/output"}
    # Only the uncompressed variant names a compression type; emitting "NONE"
    # unconditionally would make the flag a no-op and both fixture flavours
    # indistinguishable.
    if disable_output_compression:
        output_config["CompressionType"] = "NONE"

    return {
        "TrainingJobArn": TRAINING_JOB_ARN,
        "TrainingJobStatus": job_status,
        "ResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "ml.c4.xlarge",
            "VolumeSizeInGB": 30,
        },
        "OutputDataConfig": output_config,
    }


# Canned describe responses built with the default arguments, one per job status.
COMPLETED_TRAINING_JOB = describe_training_job_response(job_status="Completed")
INPROGRESS_TRAINING_JOB = describe_training_job_response(job_status="InProgress")
CANCELLED_TRAINING_JOB = describe_training_job_response(job_status="Stopped")
FAILED_TRAINING_JOB = describe_training_job_response(job_status="Failed")

# The same four statuses, built with disable_output_compression=True.
COMPLETED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response(
    job_status="Completed", disable_output_compression=True
)
INPROGRESS_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response(
    job_status="InProgress", disable_output_compression=True
)
CANCELLED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response(
    job_status="Stopped", disable_output_compression=True
)
FAILED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response(
    job_status="Failed", disable_output_compression=True
)


def mock_session():
session = Mock()
Expand Down Expand Up @@ -1303,6 +1326,27 @@ def test_describe(session, *args):
session().sagemaker_client.describe_training_job.assert_called_once()


@patch("sagemaker.remote_function.job._prepare_and_upload_runtime_scripts")
@patch("sagemaker.remote_function.job._prepare_and_upload_workspace")
@patch("sagemaker.remote_function.job.StoredFunction")
@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
def test_describe_disable_output_compression(session, *args):
    """describe() returns the uncompressed-output response for such a job.

    The job is started with ``disable_output_compression=True``; describe()
    is then called twice and the mocked client is expected to have been hit
    exactly once.
    """
    sagemaker_client = session().sagemaker_client
    # Pin the canned response explicitly so the equality check below does not
    # rely on whatever default the shared session mock happens to return.
    sagemaker_client.describe_training_job.return_value = (
        COMPLETED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION
    )

    job_settings = _JobSettings(
        image_uri=IMAGE,
        s3_root_uri=S3_URI,
        role=ROLE_ARN,
        instance_type="ml.m5.large",
        disable_output_compression=True,
    )
    job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4})

    job.describe()
    assert job.describe() == COMPLETED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION

    sagemaker_client.describe_training_job.assert_called_once()


@patch("sagemaker.remote_function.job._prepare_and_upload_runtime_scripts")
@patch("sagemaker.remote_function.job._prepare_and_upload_workspace")
@patch("sagemaker.remote_function.job.StoredFunction")
Expand Down