Skip to content

Added torchrun compatibility for distributed training across multiple GPUs in a single node (single instance) #4766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
fdbf6ba
Added torchrun execution for remote jobs
brunopistone Jul 2, 2024
c253d0b
added integration tests
brunopistone Jul 2, 2024
6815adb
docstring for use_torchrun and nproc_per_node
brunopistone Jul 25, 2024
05b2c61
Merge branch 'master' into master
brunopistone Jul 26, 2024
f1b99a4
Merge branch 'master' into master
sage-maker Aug 7, 2024
fb3015f
Merge branch 'master' into master
sage-maker Aug 7, 2024
73a1a62
Merge branch 'master' into master
sage-maker Aug 8, 2024
f6840d1
code formatting
brunopistone Aug 8, 2024
60a421d
Merge branch 'master' of https://github.com/brunopistone/sagemaker-py…
brunopistone Aug 8, 2024
a61d042
added telemetry tracking
brunopistone Aug 8, 2024
e747737
indentation fixed
brunopistone Aug 8, 2024
6fce4d6
Merge branch 'master' into master
sage-maker Aug 8, 2024
a508ebf
runned linter
brunopistone Aug 8, 2024
020f29b
Merge branch 'master' of https://github.com/brunopistone/sagemaker-py…
brunopistone Aug 8, 2024
634b8f6
fixed string length, sagemaker_remote/job.py
brunopistone Aug 8, 2024
ef92bcf
Merge branch 'master' into master
sage-maker Aug 8, 2024
9681d91
Merge branch 'master' into master
sage-maker Aug 8, 2024
7a31831
Merge branch 'master' into master
sage-maker Aug 8, 2024
5ce1fbd
Merge branch 'master' into master
sage-maker Aug 8, 2024
fb38454
fixed test cases for remote function, run test cases for all sagemake…
Aug 8, 2024
1ad86dd
Merge branch 'master' of https://github.com/brunopistone/sagemaker-py…
brunopistone Aug 8, 2024
97c172e
reduced length of docstring
brunopistone Aug 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/sagemaker/remote_function/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def remote(
spark_config: SparkConfig = None,
use_spot_instances=False,
max_wait_time_in_seconds=None,
use_torchrun=False,
nproc_per_node=1,
):
"""Decorator for running the annotated function as a SageMaker training job.

Expand Down Expand Up @@ -278,6 +280,12 @@ def remote(
max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
After this amount of time Amazon SageMaker will stop waiting for managed spot training
job to complete. Defaults to ``None``.

use_torchrun (bool): Specifies whether to use torchrun for distributed training.
Defaults to ``False``.

nproc_per_node (int): Specifies the number of processes per node for distributed training.
Defaults to ``1``.
"""

def _remote(func):
Expand Down Expand Up @@ -310,6 +318,8 @@ def _remote(func):
spark_config=spark_config,
use_spot_instances=use_spot_instances,
max_wait_time_in_seconds=max_wait_time_in_seconds,
use_torchrun=use_torchrun,
nproc_per_node=nproc_per_node,
)

@functools.wraps(func)
Expand Down
63 changes: 62 additions & 1 deletion src/sagemaker/remote_function/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,51 @@
fi
"""

# Shell entry point used when use_torchrun=True: bootstraps the runtime
# environment, then launches the pickled remote function under torchrun.
# The $NPROC_PER_NODE placeholder is substituted with the configured
# nproc_per_node value in _prepare_and_upload_runtime_scripts before upload.
ENTRYPOINT_TORCHRUN_SCRIPT = f"""
#!/bin/bash

# Entry point for bootstrapping runtime environment and invoking remote function with torchrun

set -eu

PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}}
export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs
printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n"
export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"


printf "INFO: Bootstraping runtime environment.\\n"
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"

if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
then
    if [ -f "remote_function_conda_env.txt" ]
    then
        cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt
    fi
    printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n"
    cd {JOB_REMOTE_FUNCTION_WORKSPACE}
fi

if [ -f "remote_function_conda_env.txt" ]
then
    conda_env=$(cat remote_function_conda_env.txt)

    if which mamba >/dev/null; then
        conda_exe="mamba"
    else
        conda_exe="conda"
    fi

    printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
    $conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
else
    printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
    torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
fi
"""

SPARK_ENTRYPOINT_SCRIPT = f"""
#!/bin/bash

Expand Down Expand Up @@ -216,6 +261,8 @@ def __init__(
spark_config: SparkConfig = None,
use_spot_instances=False,
max_wait_time_in_seconds=None,
use_torchrun=False,
nproc_per_node=1,
):
"""Initialize a _JobSettings instance which configures the remote job.

Expand Down Expand Up @@ -555,6 +602,9 @@ def __init__(
tags = format_tags(tags)
self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)

self.use_torchrun = use_torchrun
self.nproc_per_node = nproc_per_node

@staticmethod
def _get_default_image(session):
"""Return Studio notebook image, if in Studio env. Else, base python.
Expand Down Expand Up @@ -951,7 +1001,12 @@ def _get_job_name(job_settings, func):


def _prepare_and_upload_runtime_scripts(
spark_config: SparkConfig, s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session
spark_config: SparkConfig,
s3_base_uri: str,
s3_kms_key: str,
sagemaker_session: Session,
use_torchrun: bool = False,
nproc_per_node: int = 1,
):
"""Copy runtime scripts to a folder and upload to S3.

Expand Down Expand Up @@ -988,6 +1043,10 @@ def _prepare_and_upload_runtime_scripts(
)
shutil.copy2(spark_script_path, bootstrap_scripts)

if use_torchrun:
entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node))

with open(entrypoint_script_path, "w", newline="\n") as file:
file.writelines(entry_point_script)

Expand Down Expand Up @@ -1025,6 +1084,8 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
s3_base_uri=s3_base_uri,
s3_kms_key=job_settings.s3_kms_key,
sagemaker_session=job_settings.sagemaker_session,
use_torchrun=job_settings.use_torchrun,
nproc_per_node=job_settings.nproc_per_node,
)

input_data_config = [
Expand Down
23 changes: 23 additions & 0 deletions tests/integ/sagemaker/remote_function/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,3 +818,26 @@ def test_decorator_auto_capture(sagemaker_session, auto_capture_test_container):
f"--rm {auto_capture_test_container}"
)
subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT).decode("utf-8")

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and a new line here


def test_decorator_torchrun(
    sagemaker_session,
    dummy_container_without_error,
    gpu_instance_type,
    use_torchrun=False,
    nproc_per_node=1,
):
    """Integration test: run a @remote-decorated function with torchrun settings.

    NOTE(review): pytest only injects fixtures for parameters *without*
    defaults, so ``use_torchrun`` and ``nproc_per_node`` always keep their
    defaults here — the test runs with ``use_torchrun=False`` and therefore
    never actually exercises the torchrun launch path. Consider
    ``@pytest.mark.parametrize`` over these two arguments instead; confirm
    the intended coverage with the PR author.
    """

    @remote(
        role=ROLE,
        image_uri=dummy_container_without_error,
        instance_type=gpu_instance_type,
        sagemaker_session=sagemaker_session,
        keep_alive_period_in_seconds=60,
        use_torchrun=use_torchrun,
        nproc_per_node=nproc_per_node,
    )
    def divide(x, y):
        # Trivial workload executed remotely inside the SageMaker training job.
        return x / y

    # Each call submits a remote job and returns the function's result.
    assert divide(10, 2) == 5
    assert divide(20, 2) == 10
Loading