Skip to content

Commit 8c82136

Browse files
Rohan GujarathiNamrata Madan
authored andcommitted
pathways: standalone job entrypoint scripts
1 parent 63ea434 commit 8c82136

File tree

14 files changed

+792
-559
lines changed

14 files changed

+792
-559
lines changed

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ def read_requirements(filename):
103103
],
104104
install_requires=required_packages,
105105
extras_require=extras,
106-
scripts=["src/sagemaker/remote_function/job_driver.sh"],
107106
entry_points={
108107
"console_scripts": [
109108
"sagemaker-upgrade-v2=sagemaker.cli.compatibility.v2.sagemaker_upgrade_v2:main",

src/sagemaker/remote_function/bootstrap.py

Lines changed: 0 additions & 114 deletions
This file was deleted.

src/sagemaker/remote_function/errors.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,6 @@ class ServiceError(RemoteFunctionError):
3838
"""Raised when errors encountered during interaction with SageMaker, S3 service APIs"""
3939

4040

41-
@pickling_support.install
42-
class RuntimeEnvironmentError(RemoteFunctionError):
43-
"""Raised when errors encountered during remote function runtime environment setup"""
44-
45-
4641
@pickling_support.install
4742
class SerializationError(RemoteFunctionError):
4843
"""Raised when errors encountered during serialization of remote function objects"""

src/sagemaker/remote_function/job.py

Lines changed: 147 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,68 @@
1515

1616
import os
1717
import re
18+
import shutil
1819
from typing import Dict, List, Tuple
1920

2021
from sagemaker.config import config_schema
2122
from sagemaker.config.config_factory import SageMakerConfigFactory
2223
from sagemaker.session import get_execution_role, _logs_for_job, Session
23-
from sagemaker.utils import name_from_base
24-
from sagemaker.s3 import s3_path_join
24+
from sagemaker.utils import name_from_base, _tmpdir
25+
from sagemaker.s3 import s3_path_join, S3Uploader
2526
from sagemaker import vpc_utils
2627
from sagemaker.remote_function.core.stored_function import StoredFunction
27-
from sagemaker.remote_function.core.runtime_environment import RuntimeEnvironmentManager
28+
from sagemaker.remote_function.runtime_environment.runtime_environment_manager import (
29+
RuntimeEnvironmentManager,
30+
)
2831
from sagemaker.remote_function import logging_config
2932

3033

31-
JOBS_CONTAINER_ENTRYPOINT = ["/bin/bash", "job_driver.sh"]
34+
# runtime script names
35+
BOOTSTRAP_SCRIPT_NAME = "bootstrap_runtime_environment.py"
36+
ENTRYPOINT_SCRIPT_NAME = "job_driver.sh"
37+
RUNTIME_MANAGER_SCRIPT_NAME = "runtime_environment_manager.py"
3238

39+
# training channel names
40+
RUNTIME_SCRIPTS_CHANNEL_NAME = "remote_function_scripts"
41+
USER_DEPENDENCIES_CHANNEL_NAME = "remote_function_dependencies"
42+
SAGEMAKER_WHL_CHANNEL_NAME = "sagemaker_whl_file"
43+
44+
SAGEMAKER_SDK_WHL_FILE = (
45+
"s3://sagemaker-pathways/test/pysdk/sagemaker-2.120.1.dev0-py2.py3-none-any.whl"
46+
)
47+
48+
JOBS_CONTAINER_ENTRYPOINT = [
49+
"/bin/bash",
50+
f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}",
51+
]
52+
53+
ENTRYPOINT_SCRIPT = r"""
54+
#!/bin/bash
55+
56+
# Entry point for bootstrapping runtime environment and invoking remote function
57+
58+
set -eu
59+
60+
printf "INFO: Bootstraping runtime environment.\n"
61+
python /opt/ml/input/data/remote_function_scripts/bootstrap_runtime_environment.py "$@"
62+
63+
if [ -f "remote_function_conda_env.txt" ]
64+
then
65+
conda_env=$(cat remote_function_conda_env.txt)
66+
67+
if which mamba >/dev/null; then
68+
conda_exe="mamba"
69+
else
70+
conda_exe="conda"
71+
fi
72+
73+
printf "INFO: Invoking remote function inside conda environment: $conda_env.\n"
74+
$conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function "$@"
75+
else
76+
printf "INFO: No conda env provided. Invoking remote function\n"
77+
python -m sagemaker.remote_function.invoke_function "$@"
78+
fi
79+
"""
3380

3481
logger = logging_config.get_logger()
3582

@@ -173,12 +220,20 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs):
173220

174221
s3_base_uri = s3_path_join(job_settings.s3_root_uri, job_name)
175222

176-
runtime_environment_manager = RuntimeEnvironmentManager(
223+
local_dependencies_path = RuntimeEnvironmentManager().snapshot(job_settings.dependencies)
224+
225+
remote_function_scripts_s3uri = _prepare_and_upload_runtime_scripts(
226+
s3_base_uri=s3_base_uri,
227+
s3_kms_key=job_settings.s3_kms_key,
228+
sagemaker_session=job_settings.sagemaker_session,
229+
)
230+
231+
user_dependencies_s3uri = _prepare_and_upload_dependencies(
232+
local_dependencies_path=local_dependencies_path,
177233
s3_base_uri=s3_base_uri,
178234
s3_kms_key=job_settings.s3_kms_key,
179235
sagemaker_session=job_settings.sagemaker_session,
180236
)
181-
uploaded_dependencies_path = runtime_environment_manager.snapshot(job_settings.dependencies)
182237

183238
stored_function = StoredFunction(
184239
sagemaker_session=job_settings.sagemaker_session,
@@ -198,6 +253,46 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs):
198253
RetryStrategy={"MaximumRetryAttempts": job_settings.max_retry_attempts},
199254
)
200255

256+
input_data_config = [
257+
dict(
258+
ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME,
259+
DataSource={
260+
"S3DataSource": {
261+
"S3Uri": remote_function_scripts_s3uri,
262+
"S3DataType": "S3Prefix",
263+
}
264+
},
265+
)
266+
]
267+
268+
if user_dependencies_s3uri:
269+
input_data_config.append(
270+
dict(
271+
ChannelName=USER_DEPENDENCIES_CHANNEL_NAME,
272+
DataSource={
273+
"S3DataSource": {
274+
"S3Uri": user_dependencies_s3uri,
275+
"S3DataType": "S3Prefix",
276+
}
277+
},
278+
)
279+
)
280+
281+
# temporary solution for public beta to make sagemaker installer available
282+
# in the images, this should be removed before pathways GA.
283+
input_data_config.append(
284+
dict(
285+
ChannelName=SAGEMAKER_WHL_CHANNEL_NAME,
286+
DataSource={
287+
"S3DataSource": {
288+
"S3Uri": SAGEMAKER_SDK_WHL_FILE,
289+
"S3DataType": "S3Prefix",
290+
}
291+
},
292+
)
293+
)
294+
request_dict["InputDataConfig"] = input_data_config
295+
201296
output_config = {"S3OutputPath": s3_base_uri}
202297
if job_settings.s3_kms_key is not None:
203298
output_config["KmsKeyId"] = job_settings.s3_kms_key
@@ -207,8 +302,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs):
207302
container_args.extend(["--region", job_settings.sagemaker_session.boto_region_name])
208303
if job_settings.s3_kms_key:
209304
container_args.extend(["--s3_kms_key", job_settings.s3_kms_key])
210-
if uploaded_dependencies_path:
211-
container_args.extend(["--dependencies", uploaded_dependencies_path])
305+
212306
if job_settings.job_conda_env:
213307
container_args.extend(["--job_conda_env", job_settings.job_conda_env])
214308

@@ -282,3 +376,48 @@ def wait(self, timeout: int = None):
282376
wait=True,
283377
timeout=timeout,
284378
)
379+
380+
381+
def _prepare_and_upload_runtime_scripts(
382+
s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session
383+
):
384+
"""Copy runtime scripts to a folder and upload to S3"""
385+
386+
with _tmpdir() as remote_function_scripts:
387+
388+
# write entrypoint script to tmpdir
389+
entrypoint_script_path = os.path.join(remote_function_scripts, ENTRYPOINT_SCRIPT_NAME)
390+
with open(entrypoint_script_path, "w") as file:
391+
file.writelines(ENTRYPOINT_SCRIPT)
392+
393+
bootstrap_script_path = os.path.join(
394+
os.path.dirname(__file__), "runtime_environment", BOOTSTRAP_SCRIPT_NAME
395+
)
396+
runtime_manager_script_path = os.path.join(
397+
os.path.dirname(__file__), "runtime_environment", RUNTIME_MANAGER_SCRIPT_NAME
398+
)
399+
400+
# copy runtime scripts to tmpdir
401+
shutil.copy2(bootstrap_script_path, remote_function_scripts)
402+
shutil.copy2(runtime_manager_script_path, remote_function_scripts)
403+
404+
return S3Uploader.upload(
405+
remote_function_scripts,
406+
s3_path_join(s3_base_uri, RUNTIME_SCRIPTS_CHANNEL_NAME),
407+
s3_kms_key,
408+
sagemaker_session,
409+
)
410+
411+
412+
def _prepare_and_upload_dependencies(
413+
local_dependencies_path: str, s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session
414+
):
415+
"""Upload dependency file to S3 if present"""
416+
if local_dependencies_path:
417+
return S3Uploader.upload(
418+
local_dependencies_path,
419+
s3_path_join(s3_base_uri, USER_DEPENDENCIES_CHANNEL_NAME),
420+
s3_kms_key,
421+
sagemaker_session,
422+
)
423+
return None

src/sagemaker/remote_function/job_driver.sh

Lines changed: 0 additions & 26 deletions
This file was deleted.

0 commit comments

Comments
 (0)