15
15
16
16
import os
17
17
import re
18
+ import shutil
18
19
from typing import Dict , List , Tuple
19
20
20
21
from sagemaker .config import config_schema
21
22
from sagemaker .config .config_factory import SageMakerConfigFactory
22
23
from sagemaker .session import get_execution_role , _logs_for_job , Session
23
- from sagemaker .utils import name_from_base
24
- from sagemaker .s3 import s3_path_join
24
+ from sagemaker .utils import name_from_base , _tmpdir
25
+ from sagemaker .s3 import s3_path_join , S3Uploader
25
26
from sagemaker import vpc_utils
26
27
from sagemaker .remote_function .core .stored_function import StoredFunction
27
- from sagemaker .remote_function .core .runtime_environment import RuntimeEnvironmentManager
28
+ from sagemaker .remote_function .runtime_environment .runtime_environment_manager import (
29
+ RuntimeEnvironmentManager ,
30
+ )
28
31
from sagemaker .remote_function import logging_config
29
32
30
33
31
# Names of the runtime scripts shipped to the training container at job start.
BOOTSTRAP_SCRIPT_NAME = "bootstrap_runtime_environment.py"
ENTRYPOINT_SCRIPT_NAME = "job_driver.sh"
RUNTIME_MANAGER_SCRIPT_NAME = "runtime_environment_manager.py"

# Training-job input channel names; each channel is mounted in the container
# under /opt/ml/input/data/<channel_name>.
RUNTIME_SCRIPTS_CHANNEL_NAME = "remote_function_scripts"
USER_DEPENDENCIES_CHANNEL_NAME = "remote_function_dependencies"
SAGEMAKER_WHL_CHANNEL_NAME = "sagemaker_whl_file"

# Temporary wheel to make the SDK installable inside the images during public
# beta; per the comment at the call site, remove before pathways GA.
SAGEMAKER_SDK_WHL_FILE = (
    "s3://sagemaker-pathways/test/pysdk/sagemaker-2.120.1.dev0-py2.py3-none-any.whl"
)

# Container entrypoint: run the shell driver that is uploaded through the
# runtime-scripts channel.
JOBS_CONTAINER_ENTRYPOINT = [
    "/bin/bash",
    f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}",
]

# Shell driver: bootstrap the runtime environment, then invoke the remote
# function — inside the bootstrapped conda env when one was written out
# (preferring mamba over conda when available).
# NOTE: fixed typo in the INFO log line ("Bootstraping" -> "Bootstrapping").
ENTRYPOINT_SCRIPT = r"""
#!/bin/bash

# Entry point for bootstrapping runtime environment and invoking remote function

set -eu

printf "INFO: Bootstrapping runtime environment.\n"
python /opt/ml/input/data/remote_function_scripts/bootstrap_runtime_environment.py "$@"

if [ -f "remote_function_conda_env.txt" ]
then
    conda_env=$(cat remote_function_conda_env.txt)

    if which mamba >/dev/null; then
        conda_exe="mamba"
    else
        conda_exe="conda"
    fi

    printf "INFO: Invoking remote function inside conda environment: $conda_env.\n"
    $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function "$@"
else
    printf "INFO: No conda env provided. Invoking remote function\n"
    python -m sagemaker.remote_function.invoke_function "$@"
fi
"""
33
80
34
81
# Module-level logger obtained from the remote-function logging helper.
logger = logging_config .get_logger ()
35
82
@@ -173,12 +220,20 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs):
173
220
174
221
s3_base_uri = s3_path_join (job_settings .s3_root_uri , job_name )
175
222
176
- runtime_environment_manager = RuntimeEnvironmentManager (
223
+ local_dependencies_path = RuntimeEnvironmentManager ().snapshot (job_settings .dependencies )
224
+
225
+ remote_function_scripts_s3uri = _prepare_and_upload_runtime_scripts (
226
+ s3_base_uri = s3_base_uri ,
227
+ s3_kms_key = job_settings .s3_kms_key ,
228
+ sagemaker_session = job_settings .sagemaker_session ,
229
+ )
230
+
231
+ user_dependencies_s3uri = _prepare_and_upload_dependencies (
232
+ local_dependencies_path = local_dependencies_path ,
177
233
s3_base_uri = s3_base_uri ,
178
234
s3_kms_key = job_settings .s3_kms_key ,
179
235
sagemaker_session = job_settings .sagemaker_session ,
180
236
)
181
- uploaded_dependencies_path = runtime_environment_manager .snapshot (job_settings .dependencies )
182
237
183
238
stored_function = StoredFunction (
184
239
sagemaker_session = job_settings .sagemaker_session ,
@@ -198,6 +253,46 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs):
198
253
RetryStrategy = {"MaximumRetryAttempts" : job_settings .max_retry_attempts },
199
254
)
200
255
256
+ input_data_config = [
257
+ dict (
258
+ ChannelName = RUNTIME_SCRIPTS_CHANNEL_NAME ,
259
+ DataSource = {
260
+ "S3DataSource" : {
261
+ "S3Uri" : remote_function_scripts_s3uri ,
262
+ "S3DataType" : "S3Prefix" ,
263
+ }
264
+ },
265
+ )
266
+ ]
267
+
268
+ if user_dependencies_s3uri :
269
+ input_data_config .append (
270
+ dict (
271
+ ChannelName = USER_DEPENDENCIES_CHANNEL_NAME ,
272
+ DataSource = {
273
+ "S3DataSource" : {
274
+ "S3Uri" : user_dependencies_s3uri ,
275
+ "S3DataType" : "S3Prefix" ,
276
+ }
277
+ },
278
+ )
279
+ )
280
+
281
+ # temporary solution for public beta to make sagemaker installer available
282
+ # in the images, this should be removed before pathways GA.
283
+ input_data_config .append (
284
+ dict (
285
+ ChannelName = SAGEMAKER_WHL_CHANNEL_NAME ,
286
+ DataSource = {
287
+ "S3DataSource" : {
288
+ "S3Uri" : SAGEMAKER_SDK_WHL_FILE ,
289
+ "S3DataType" : "S3Prefix" ,
290
+ }
291
+ },
292
+ )
293
+ )
294
+ request_dict ["InputDataConfig" ] = input_data_config
295
+
201
296
output_config = {"S3OutputPath" : s3_base_uri }
202
297
if job_settings .s3_kms_key is not None :
203
298
output_config ["KmsKeyId" ] = job_settings .s3_kms_key
@@ -207,8 +302,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs):
207
302
container_args .extend (["--region" , job_settings .sagemaker_session .boto_region_name ])
208
303
if job_settings .s3_kms_key :
209
304
container_args .extend (["--s3_kms_key" , job_settings .s3_kms_key ])
210
- if uploaded_dependencies_path :
211
- container_args .extend (["--dependencies" , uploaded_dependencies_path ])
305
+
212
306
if job_settings .job_conda_env :
213
307
container_args .extend (["--job_conda_env" , job_settings .job_conda_env ])
214
308
@@ -282,3 +376,48 @@ def wait(self, timeout: int = None):
282
376
wait = True ,
283
377
timeout = timeout ,
284
378
)
379
+
380
+
381
def _prepare_and_upload_runtime_scripts(
    s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session
):
    """Stage the runtime bootstrap scripts locally and upload them to S3.

    Writes the inline shell entrypoint into a temporary directory, copies the
    bootstrap and runtime-environment-manager scripts next to it, and uploads
    the whole directory under the runtime-scripts channel prefix.

    Args:
        s3_base_uri: S3 prefix under which all artifacts of this job live.
        s3_kms_key: KMS key id used to encrypt the uploaded objects (may be None).
        sagemaker_session: Session used to perform the S3 upload.

    Returns:
        The S3 URI the scripts were uploaded to.
    """

    with _tmpdir() as remote_function_scripts:

        # Materialize the inline shell driver in the staging directory.
        # Note: write(), not writelines() — writelines() on a str iterates it
        # character by character.
        entrypoint_script_path = os.path.join(remote_function_scripts, ENTRYPOINT_SCRIPT_NAME)
        with open(entrypoint_script_path, "w") as file:
            file.write(ENTRYPOINT_SCRIPT)

        # The python runtime scripts ship with the SDK next to this module.
        runtime_env_dir = os.path.join(os.path.dirname(__file__), "runtime_environment")

        # copy runtime scripts to tmpdir
        shutil.copy2(os.path.join(runtime_env_dir, BOOTSTRAP_SCRIPT_NAME), remote_function_scripts)
        shutil.copy2(
            os.path.join(runtime_env_dir, RUNTIME_MANAGER_SCRIPT_NAME), remote_function_scripts
        )

        # Upload while the tmpdir still exists; the context manager cleans up after.
        return S3Uploader.upload(
            remote_function_scripts,
            s3_path_join(s3_base_uri, RUNTIME_SCRIPTS_CHANNEL_NAME),
            s3_kms_key,
            sagemaker_session,
        )
410
+
411
+
412
def _prepare_and_upload_dependencies(
    local_dependencies_path: str, s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session
):
    """Upload the snapshotted dependency file to S3, if one exists.

    Args:
        local_dependencies_path: Local path to the dependencies snapshot, or a
            falsy value when the job has no user dependencies.
        s3_base_uri: S3 prefix under which all artifacts of this job live.
        s3_kms_key: KMS key id used to encrypt the uploaded object (may be None).
        sagemaker_session: Session used to perform the S3 upload.

    Returns:
        The S3 URI of the uploaded file, or None when there was nothing to upload.
    """
    if not local_dependencies_path:
        return None

    upload_destination = s3_path_join(s3_base_uri, USER_DEPENDENCIES_CHANNEL_NAME)
    return S3Uploader.upload(
        local_dependencies_path,
        upload_destination,
        s3_kms_key,
        sagemaker_session,
    )
0 commit comments