
Commit cbd2ed9

brunopistone, sage-maker, and EC2 Default User authored
Added torchrun compatibility for distributed training across multiple GPUs in a single node (single instance) (#4766)
* Added torchrun execution for remote jobs
* Added integration tests
* Docstring for use_torchrun and nproc_per_node
* Code formatting
* Added telemetry tracking
* Fixed indentation
* Ran linter
* Fixed string length in sagemaker_remote/job.py
* Fixed test cases for remote function, ran test cases for all sagemaker items
* Reduced length of docstring

---------

Co-authored-by: sage-maker <[email protected]>
Co-authored-by: EC2 Default User <[email protected]>
1 parent 30a5478 commit cbd2ed9

File tree: 7 files changed (+142, -1 lines changed)
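Taken together, the changes below let a function decorated with @remote run under torchrun, with one worker process per GPU on a single instance. A minimal usage sketch, assuming a hypothetical role ARN and a 4-GPU instance type (neither value comes from this commit):

    from sagemaker.remote_function import remote

    @remote(
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role ARN
        instance_type="ml.g5.12xlarge",  # placeholder 4-GPU instance type
        use_torchrun=True,   # launch the function via torchrun instead of plain python
        nproc_per_node=4,    # spawn one worker process per GPU
    )
    def train():
        import os
        # torchrun exports RANK/LOCAL_RANK/WORLD_SIZE to each worker it spawns
        return int(os.environ["LOCAL_RANK"])

    train()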

src/sagemaker/remote_function/client.py

Lines changed: 20 additions & 0 deletions

@@ -90,6 +90,8 @@ def remote(
     spark_config: SparkConfig = None,
     use_spot_instances=False,
     max_wait_time_in_seconds=None,
+    use_torchrun=False,
+    nproc_per_node=1,
 ):
     """Decorator for running the annotated function as a SageMaker training job.

@@ -278,6 +280,12 @@ def remote(
         max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
             After this amount of time Amazon SageMaker will stop waiting for managed spot training
             job to complete. Defaults to ``None``.
+
+        use_torchrun (bool): Specifies whether to use torchrun for distributed training.
+            Defaults to ``False``.
+
+        nproc_per_node (int): Specifies the number of processes per node for distributed training.
+            Defaults to ``1``.
     """

     def _remote(func):

@@ -310,6 +318,8 @@ def _remote(func):
             spark_config=spark_config,
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
+            use_torchrun=use_torchrun,
+            nproc_per_node=nproc_per_node,
         )

         @functools.wraps(func)

@@ -521,6 +531,8 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
+        use_torchrun=False,
+        nproc_per_node=1,
     ):
         """Constructor for RemoteExecutor

@@ -709,6 +721,12 @@ def __init__(
            max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                After this amount of time Amazon SageMaker will stop waiting for managed spot training
                job to complete. Defaults to ``None``.
+
+           use_torchrun (bool): Specifies whether to use torchrun for distributed training.
+               Defaults to ``False``.
+
+           nproc_per_node (int): Specifies the number of processes per node.
+               Defaults to ``1``.
        """
        self.max_parallel_jobs = max_parallel_jobs

@@ -749,6 +767,8 @@ def __init__(
            spark_config=spark_config,
            use_spot_instances=use_spot_instances,
            max_wait_time_in_seconds=max_wait_time_in_seconds,
+           use_torchrun=use_torchrun,
+           nproc_per_node=nproc_per_node,
        )

        self._state_condition = threading.Condition()
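RemoteExecutor gains the same two settings. A hedged sketch of how they might be passed (the role ARN and instance type are placeholders, as above):

    from sagemaker.remote_function import RemoteExecutor

    def square(x):
        return x * x

    with RemoteExecutor(
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder
        instance_type="ml.g5.12xlarge",                       # placeholder
        max_parallel_jobs=1,
        use_torchrun=True,
        nproc_per_node=4,
    ) as executor:
        future = executor.submit(square, 10)
        print(future.result())  # 100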

src/sagemaker/remote_function/core/stored_function.py

Lines changed: 6 additions & 0 deletions

@@ -55,6 +55,8 @@ def __init__(
         hmac_key: str,
         s3_kms_key: str = None,
         context: Context = Context(),
+        use_torchrun: bool = False,
+        nproc_per_node: int = 1,
     ):
         """Construct a StoredFunction object.

@@ -65,12 +67,16 @@
             s3_kms_key: KMS key used to encrypt artifacts uploaded to S3.
             hmac_key: Key used to encrypt serialized and deserialized function and arguments.
             context: Build or run context of a pipeline step.
+            use_torchrun: Whether to use torchrun for distributed training.
+            nproc_per_node: Number of processes per node for distributed training.
         """
         self.sagemaker_session = sagemaker_session
         self.s3_base_uri = s3_base_uri
         self.s3_kms_key = s3_kms_key
         self.hmac_key = hmac_key
         self.context = context
+        self.use_torchrun = use_torchrun
+        self.nproc_per_node = nproc_per_node

         self.func_upload_path = s3_path_join(
             s3_base_uri, context.step_name, context.func_step_s3_dir

src/sagemaker/remote_function/job.py

Lines changed: 71 additions & 1 deletion

@@ -162,6 +162,52 @@
 fi
 """

+ENTRYPOINT_TORCHRUN_SCRIPT = f"""
+#!/bin/bash
+
+# Entry point for bootstrapping runtime environment and invoking remote function with torchrun
+
+set -eu
+
+PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}}
+export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs
+printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n"
+export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
+printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"
+
+
+printf "INFO: Bootstraping runtime environment.\\n"
+python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+
+if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
+then
+    if [ -f "remote_function_conda_env.txt" ]
+    then
+        cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt
+    fi
+    printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n"
+    cd {JOB_REMOTE_FUNCTION_WORKSPACE}
+fi
+
+if [ -f "remote_function_conda_env.txt" ]
+then
+    conda_env=$(cat remote_function_conda_env.txt)
+
+    if which mamba >/dev/null; then
+        conda_exe="mamba"
+    else
+        conda_exe="conda"
+    fi
+
+    printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
+    $conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE \
+    -m sagemaker.remote_function.invoke_function "$@"
+else
+    printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
+    torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
+fi
+"""

 SPARK_ENTRYPOINT_SCRIPT = f"""
 #!/bin/bash
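For context on what this entrypoint enables: torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE to every process it spawns, so the invoked function can join a distributed process group. A minimal sketch of the standard torch.distributed setup a user function might perform (this helper is illustrative, not part of the commit):

    import os

    import torch
    import torch.distributed as dist

    def init_distributed() -> int:
        # Join the process group created by torchrun; NCCL is the usual
        # backend for GPU training.
        dist.init_process_group(backend="nccl")
        # Pin this worker to the GPU matching its local rank.
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        return local_rank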
@@ -216,6 +262,8 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
+        use_torchrun=False,
+        nproc_per_node=1,
     ):
         """Initialize a _JobSettings instance which configures the remote job.

@@ -555,6 +603,9 @@ def __init__(
         tags = format_tags(tags)
         self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)

+        self.use_torchrun = use_torchrun
+        self.nproc_per_node = nproc_per_node
+
     @staticmethod
     def _get_default_image(session):
         """Return Studio notebook image, if in Studio env. Else, base python.

@@ -725,6 +776,8 @@ def compile(
             s3_base_uri=s3_base_uri,
             hmac_key=hmac_key,
             s3_kms_key=job_settings.s3_kms_key,
+            use_torchrun=job_settings.use_torchrun,
+            nproc_per_node=job_settings.nproc_per_node,
         )
         stored_function.save(func, *func_args, **func_kwargs)
     else:

@@ -737,6 +790,8 @@ def compile(
                 step_name=step_compilation_context.step_name,
                 func_step_s3_dir=step_compilation_context.pipeline_build_time,
             ),
+            use_torchrun=job_settings.use_torchrun,
+            nproc_per_node=job_settings.nproc_per_node,
         )

         stored_function.save_pipeline_step_function(serialized_data)

@@ -951,7 +1006,12 @@ def _get_job_name(job_settings, func):


 def _prepare_and_upload_runtime_scripts(
-    spark_config: SparkConfig, s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session
+    spark_config: SparkConfig,
+    s3_base_uri: str,
+    s3_kms_key: str,
+    sagemaker_session: Session,
+    use_torchrun: bool = False,
+    nproc_per_node: int = 1,
 ):
     """Copy runtime scripts to a folder and upload to S3.

@@ -967,6 +1027,10 @@ def _prepare_and_upload_runtime_scripts(
         s3_kms_key (str): kms key used to encrypt the files uploaded to S3.

         sagemaker_session (str): SageMaker boto client session.
+
+        use_torchrun (bool): Whether to use torchrun or not.
+
+        nproc_per_node (int): Number of processes per node.
     """

     from sagemaker.workflow.utilities import load_step_compilation_context

@@ -988,6 +1052,10 @@ def _prepare_and_upload_runtime_scripts(
         )
         shutil.copy2(spark_script_path, bootstrap_scripts)

+    if use_torchrun:
+        entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
+        entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node))
+
     with open(entrypoint_script_path, "w", newline="\n") as file:
         file.writelines(entry_point_script)

@@ -1025,6 +1093,8 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
         s3_base_uri=s3_base_uri,
         s3_kms_key=job_settings.s3_kms_key,
         sagemaker_session=job_settings.sagemaker_session,
+        use_torchrun=job_settings.use_torchrun,
+        nproc_per_node=job_settings.nproc_per_node,
     )

     input_data_config = [
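Note the design choice in the hunk above: $NPROC_PER_NODE looks like a shell variable, but it is filled in client-side with plain string substitution before the entrypoint is uploaded, so the training container never needs to export it. A self-contained illustration of the same mechanic (TEMPLATE is a stand-in for ENTRYPOINT_TORCHRUN_SCRIPT):

    # Stand-in for the real entrypoint template.
    TEMPLATE = "torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function"

    def render_entrypoint(nproc_per_node: int = 1) -> str:
        # Same call used in _prepare_and_upload_runtime_scripts.
        return TEMPLATE.replace("$NPROC_PER_NODE", str(nproc_per_node))

    print(render_entrypoint(4))
    # torchrun --nproc_per_node 4 -m sagemaker.remote_function.invoke_function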

tests/integ/sagemaker/remote_function/test_decorator.py

Lines changed: 23 additions & 0 deletions

@@ -818,3 +818,26 @@ def test_decorator_auto_capture(sagemaker_session, auto_capture_test_container):
         f"--rm {auto_capture_test_container}"
     )
     subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT).decode("utf-8")
+
+
+def test_decorator_torchrun(
+    sagemaker_session,
+    dummy_container_without_error,
+    gpu_instance_type,
+    use_torchrun=False,
+    nproc_per_node=1,
+):
+    @remote(
+        role=ROLE,
+        image_uri=dummy_container_without_error,
+        instance_type=gpu_instance_type,
+        sagemaker_session=sagemaker_session,
+        keep_alive_period_in_seconds=60,
+        use_torchrun=use_torchrun,
+        nproc_per_node=nproc_per_node,
+    )
+    def divide(x, y):
+        return x / y
+
+    assert divide(10, 2) == 5
+    assert divide(20, 2) == 10
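As written, the test's default arguments (use_torchrun=False, nproc_per_node=1) exercise the non-torchrun path unless a caller overrides them. A hedged sketch of how both paths could be exercised with pytest parametrization (the parametrize marker and the (True, 2) case are assumptions, not part of the commit):

    import pytest

    @pytest.mark.parametrize("use_torchrun,nproc_per_node", [(False, 1), (True, 2)])
    def test_decorator_torchrun_variants(
        sagemaker_session,
        dummy_container_without_error,
        gpu_instance_type,
        use_torchrun,
        nproc_per_node,
    ):
        @remote(
            role=ROLE,
            image_uri=dummy_container_without_error,
            instance_type=gpu_instance_type,
            sagemaker_session=sagemaker_session,
            keep_alive_period_in_seconds=60,
            use_torchrun=use_torchrun,
            nproc_per_node=nproc_per_node,
        )
        def divide(x, y):
            return x / y

        assert divide(10, 2) == 5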

tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py

Lines changed: 2 additions & 0 deletions

@@ -907,6 +907,8 @@ def test_remote_decorator_fields_consistency(get_execution_role, session):
         "use_spot_instances",
         "max_wait_time_in_seconds",
         "custom_file_filter",
+        "use_torchrun",
+        "nproc_per_node",
     }

     job_settings = _JobSettings(

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 2 additions & 0 deletions

@@ -1504,6 +1504,8 @@ def test_consistency_between_remote_and_step_decorator():
         "s3_kms_key",
         "s3_root_uri",
         "sagemaker_session",
+        "use_torchrun",
+        "nproc_per_node",
     ]

     step_args_to_ignore = ["_step", "name", "display_name", "description", "retry_policies"]

tests/unit/sagemaker/remote_function/test_job.py

Lines changed: 18 additions & 0 deletions

@@ -376,6 +376,8 @@ def test_start(
         s3_base_uri=f"{S3_URI}/{job.job_name}",
         hmac_key=HMAC_KEY,
         s3_kms_key=None,
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4})

@@ -389,6 +391,8 @@ def test_start(
         s3_base_uri=f"{S3_URI}/{job.job_name}",
         s3_kms_key=None,
         sagemaker_session=session(),
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     mock_dependency_upload.assert_called_once_with(

@@ -506,6 +510,8 @@ def test_start_with_checkpoint_location(
         s3_base_uri=f"{S3_URI}/{job.job_name}",
         hmac_key=HMAC_KEY,
         s3_kms_key=None,
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     mock_stored_function().save.assert_called_once_with(

@@ -659,6 +665,8 @@ def test_start_with_complete_job_settings(
         s3_base_uri=f"{S3_URI}/{job.job_name}",
         hmac_key=HMAC_KEY,
         s3_kms_key=KMS_KEY_ARN,
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     local_dependencies_path = mock_runtime_manager().snapshot()

@@ -670,6 +678,8 @@ def test_start_with_complete_job_settings(
         s3_base_uri=f"{S3_URI}/{job.job_name}",
         s3_kms_key=job_settings.s3_kms_key,
         sagemaker_session=session(),
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     mock_user_workspace_upload.assert_called_once_with(

@@ -828,6 +838,8 @@ def test_get_train_args_under_pipeline_context(
             step_name=MOCKED_PIPELINE_CONFIG.step_name,
             func_step_s3_dir=MOCKED_PIPELINE_CONFIG.pipeline_build_time,
         ),
+        use_torchrun=False,
+        nproc_per_node=1,
     )
     mock_stored_function.save_pipeline_step_function.assert_called_once_with(mocked_serialized_data)

@@ -840,6 +852,8 @@ def test_get_train_args_under_pipeline_context(
         s3_base_uri=s3_base_uri,
         s3_kms_key=job_settings.s3_kms_key,
         sagemaker_session=session(),
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     mock_user_workspace_upload.assert_called_once_with(

@@ -1014,6 +1028,8 @@ def test_start_with_spark(
         s3_base_uri=f"{S3_URI}/{job.job_name}",
         s3_kms_key=None,
         sagemaker_session=session(),
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     session().sagemaker_client.create_training_job.assert_called_once_with(

@@ -1168,6 +1184,8 @@ def test_prepare_and_upload_runtime_scripts(session, mock_copy, mock_s3_upload):
         s3_base_uri=S3_URI,
         s3_kms_key=KMS_KEY_ARN,
         sagemaker_session=session(),
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     assert s3_path == mock_s3_upload.return_value
