Skip to content

Commit fdbf6ba

Browse files
committed
Added torchrun execution for remote jobs
1 parent 0e3769e commit fdbf6ba

File tree

2 files changed

+63
-5
lines changed

2 files changed

+63
-5
lines changed

src/sagemaker/remote_function/client.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,12 @@
3535

3636
from sagemaker.session import Session
3737
from sagemaker.s3 import s3_path_join
38-
from sagemaker.remote_function.job import _JobSettings, _Job, _RunInfo
38+
from sagemaker.remote_function.job import _Job, _JobSettings
39+
from sagemaker.remote_function.job import _RunInfo
3940
from sagemaker.remote_function import logging_config
4041
from sagemaker.utils import name_from_base, base_from_name
4142
from sagemaker.remote_function.spark_config import SparkConfig
4243
from sagemaker.remote_function.custom_file_filter import CustomFileFilter
43-
from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
44-
from sagemaker.telemetry.constants import Feature
4544

4645
_API_CALL_LIMIT = {
4746
"SubmittingIntervalInSecs": 1,
@@ -59,7 +58,6 @@
5958
logger = logging_config.get_logger()
6059

6160

62-
@_telemetry_emitter(feature=Feature.REMOTE_FUNCTION, func_name="remote_function.remote")
6361
def remote(
6462
_func=None,
6563
*,
@@ -90,6 +88,8 @@ def remote(
9088
spark_config: SparkConfig = None,
9189
use_spot_instances=False,
9290
max_wait_time_in_seconds=None,
91+
use_torchrun=False,
92+
nproc_per_node=1,
9393
):
9494
"""Decorator for running the annotated function as a SageMaker training job.
9595
@@ -310,6 +310,8 @@ def _remote(func):
310310
spark_config=spark_config,
311311
use_spot_instances=use_spot_instances,
312312
max_wait_time_in_seconds=max_wait_time_in_seconds,
313+
use_torchrun=use_torchrun,
314+
nproc_per_node=nproc_per_node,
313315
)
314316

315317
@functools.wraps(func)

src/sagemaker/remote_function/job.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,51 @@
162162
fi
163163
"""
164164

165+
ENTRYPOINT_TORCHRUN_SCRIPT = f"""
166+
#!/bin/bash
167+
168+
# Entry point for bootstrapping runtime environment and invoking remote function with torchrun
169+
170+
set -eu
171+
172+
PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}}
173+
export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs
174+
printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n"
175+
export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
176+
printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"
177+
178+
179+
printf "INFO: Bootstrapping runtime environment.\\n"
180+
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
181+
182+
if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
183+
then
184+
if [ -f "remote_function_conda_env.txt" ]
185+
then
186+
cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt
187+
fi
188+
printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n"
189+
cd {JOB_REMOTE_FUNCTION_WORKSPACE}
190+
fi
191+
192+
if [ -f "remote_function_conda_env.txt" ]
193+
then
194+
conda_env=$(cat remote_function_conda_env.txt)
195+
196+
if which mamba >/dev/null; then
197+
conda_exe="mamba"
198+
else
199+
conda_exe="conda"
200+
fi
201+
202+
printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
203+
$conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
204+
else
205+
printf "INFO: No conda env provided. Invoking remote function with torchrun.\\n"
206+
torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
207+
fi
208+
"""
209+
165210
SPARK_ENTRYPOINT_SCRIPT = f"""
166211
#!/bin/bash
167212
@@ -216,6 +261,8 @@ def __init__(
216261
spark_config: SparkConfig = None,
217262
use_spot_instances=False,
218263
max_wait_time_in_seconds=None,
264+
use_torchrun=False,
265+
nproc_per_node=1,
219266
):
220267
"""Initialize a _JobSettings instance which configures the remote job.
221268
@@ -555,6 +602,9 @@ def __init__(
555602
tags = format_tags(tags)
556603
self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)
557604

605+
self.use_torchrun = use_torchrun
606+
self.nproc_per_node = nproc_per_node
607+
558608
@staticmethod
559609
def _get_default_image(session):
560610
"""Return Studio notebook image, if in Studio env. Else, base python.
@@ -951,7 +1001,7 @@ def _get_job_name(job_settings, func):
9511001

9521002

9531003
def _prepare_and_upload_runtime_scripts(
954-
spark_config: SparkConfig, s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session
1004+
spark_config: SparkConfig, s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session, use_torchrun: bool = False, nproc_per_node: int = 1
9551005
):
9561006
"""Copy runtime scripts to a folder and upload to S3.
9571007
@@ -988,6 +1038,10 @@ def _prepare_and_upload_runtime_scripts(
9881038
)
9891039
shutil.copy2(spark_script_path, bootstrap_scripts)
9901040

1041+
if use_torchrun:
1042+
entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
1043+
entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node))
1044+
9911045
with open(entrypoint_script_path, "w", newline="\n") as file:
9921046
file.writelines(entry_point_script)
9931047

@@ -1025,6 +1079,8 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
10251079
s3_base_uri=s3_base_uri,
10261080
s3_kms_key=job_settings.s3_kms_key,
10271081
sagemaker_session=job_settings.sagemaker_session,
1082+
use_torchrun=job_settings.use_torchrun,
1083+
nproc_per_node=job_settings.nproc_per_node,
10281084
)
10291085

10301086
input_data_config = [

0 commit comments

Comments
 (0)