|
162 | 162 | fi
|
163 | 163 | """
|
164 | 164 |
|
# Shell entry point used when the remote function is launched via torchrun
# (distributed PyTorch). It bootstraps the runtime environment, copies the
# conda-env marker into the uploaded workspace (if any), cd's into that
# workspace, and finally invokes the function through
# ``torchrun -m sagemaker.remote_function.invoke_function``.
# NOTE: the literal ``$NPROC_PER_NODE`` token is substituted with the
# configured per-node process count by ``_prepare_and_upload_runtime_scripts``
# before the script is written out — it is not a shell variable at runtime.
ENTRYPOINT_TORCHRUN_SCRIPT = f"""
#!/bin/bash

# Entry point for bootstrapping runtime environment and invoking remote function with torchrun

set -eu

PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}}
export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs
printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n"
export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"


printf "INFO: Bootstrapping runtime environment.\\n"
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"

if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
then
    if [ -f "remote_function_conda_env.txt" ]
    then
        cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt
    fi
    printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n"
    cd {JOB_REMOTE_FUNCTION_WORKSPACE}
fi

if [ -f "remote_function_conda_env.txt" ]
then
    conda_env=$(cat remote_function_conda_env.txt)

    if which mamba >/dev/null; then
        conda_exe="mamba"
    else
        conda_exe="conda"
    fi

    printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
    $conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
else
    printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
    torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
fi
"""

165 | 210 | SPARK_ENTRYPOINT_SCRIPT = f"""
|
166 | 211 | #!/bin/bash
|
167 | 212 |
|
@@ -216,6 +261,8 @@ def __init__(
|
216 | 261 | spark_config: SparkConfig = None,
|
217 | 262 | use_spot_instances=False,
|
218 | 263 | max_wait_time_in_seconds=None,
|
| 264 | + use_torchrun=False, |
| 265 | + nproc_per_node=1, |
219 | 266 | ):
|
220 | 267 | """Initialize a _JobSettings instance which configures the remote job.
|
221 | 268 |
|
@@ -555,6 +602,9 @@ def __init__(
|
555 | 602 | tags = format_tags(tags)
|
556 | 603 | self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)
|
557 | 604 |
|
| 605 | + self.use_torchrun = use_torchrun |
| 606 | + self.nproc_per_node = nproc_per_node |
| 607 | + |
558 | 608 | @staticmethod
|
559 | 609 | def _get_default_image(session):
|
560 | 610 | """Return Studio notebook image, if in Studio env. Else, base python.
|
@@ -951,7 +1001,7 @@ def _get_job_name(job_settings, func):
|
951 | 1001 |
|
952 | 1002 |
|
953 | 1003 | def _prepare_and_upload_runtime_scripts(
|
954 |
| - spark_config: SparkConfig, s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session |
| 1004 | + spark_config: SparkConfig, s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session, use_torchrun: bool = False, nproc_per_node: int = 1 |
955 | 1005 | ):
|
956 | 1006 | """Copy runtime scripts to a folder and upload to S3.
|
957 | 1007 |
|
@@ -988,6 +1038,10 @@ def _prepare_and_upload_runtime_scripts(
|
988 | 1038 | )
|
989 | 1039 | shutil.copy2(spark_script_path, bootstrap_scripts)
|
990 | 1040 |
|
| 1041 | + if use_torchrun: |
| 1042 | + entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT |
| 1043 | + entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node)) |
| 1044 | + |
991 | 1045 | with open(entrypoint_script_path, "w", newline="\n") as file:
|
992 | 1046 | file.writelines(entry_point_script)
|
993 | 1047 |
|
@@ -1025,6 +1079,8 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
|
1025 | 1079 | s3_base_uri=s3_base_uri,
|
1026 | 1080 | s3_kms_key=job_settings.s3_kms_key,
|
1027 | 1081 | sagemaker_session=job_settings.sagemaker_session,
|
| 1082 | + use_torchrun=job_settings.use_torchrun, |
| 1083 | + nproc_per_node=job_settings.nproc_per_node, |
1028 | 1084 | )
|
1029 | 1085 |
|
1030 | 1086 | input_data_config = [
|
|
0 commit comments