30
30
from sagemaker import s3
31
31
from sagemaker .job import _Job
32
32
from sagemaker .local import LocalSession
33
- from sagemaker .utils import base_name_from_image , name_from_base
33
+ from sagemaker .utils import base_name_from_image , get_config_value , name_from_base
34
34
from sagemaker .session import Session
35
35
from sagemaker .network import NetworkConfig # noqa: F401 # pylint: disable=unused-import
36
36
from sagemaker .workflow .properties import Properties
@@ -1220,7 +1220,7 @@ class FrameworkProcessor(ScriptProcessor):
1220
1220
runproc_sh = """#!/bin/bash
1221
1221
1222
1222
cd /opt/ml/processing/input/code/
1223
- tar -xzf payload/ sourcedir.tar.gz
1223
+ tar -xzf sourcedir.tar.gz
1224
1224
1225
1225
# Exit on any error. SageMaker uses error code to mark failed job.
1226
1226
set -e
@@ -1235,7 +1235,6 @@ def __init__(
1235
1235
self ,
1236
1236
estimator_cls , # New arg
1237
1237
framework_version , # New arg
1238
- s3_prefix , # New arg
1239
1238
role ,
1240
1239
instance_count ,
1241
1240
instance_type ,
@@ -1244,6 +1243,7 @@ def __init__(
1244
1243
volume_size_in_gb = 30 ,
1245
1244
volume_kms_key = None ,
1246
1245
output_kms_key = None ,
1246
+ code_location = None , # New arg
1247
1247
max_runtime_in_seconds = None ,
1248
1248
base_job_name = None ,
1249
1249
sagemaker_session = None ,
@@ -1262,10 +1262,6 @@ def __init__(
1262
1262
estimator
1263
1263
framework_version (str): The version of the framework. Value is ignored when
1264
1264
``image_uri`` is provided.
1265
- s3_prefix (str): The S3 prefix URI where custom code will be
1266
- uploaded - don't include a trailing slash since a string prepended
1267
- with a "/" is appended to ``s3_prefix``. The code file uploaded to S3
1268
- is 's3_prefix/job-name/source/sourcedir.tar.gz'.
1269
1265
role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing uses
1270
1266
this role to access AWS resources, such as data stored in Amazon S3.
1271
1267
instance_count (int): The number of instances to run a processing job with.
@@ -1280,6 +1276,10 @@ def __init__(
1280
1276
to use for storing data during processing (default: 30).
1281
1277
volume_kms_key (str): A KMS key for the processing volume (default: None).
1282
1278
output_kms_key (str): The KMS key ID for processing job outputs (default: None).
1279
+ code_location (str): The S3 prefix URI where custom code will be
1280
+ uploaded (default: None). The code file uploaded to S3 is
1281
+ 'code_location/job-name/source/sourcedir.tar.gz'. If not specified, the
1282
+ default ``code_location`` is 's3://{sagemaker-default-bucket}'.
1283
1283
max_runtime_in_seconds (int): Timeout in seconds (default: None).
1284
1284
After this amount of time, Amazon SageMaker terminates the job,
1285
1285
regardless of its current status. If `max_runtime_in_seconds` is not
@@ -1325,8 +1325,14 @@ def __init__(
1325
1325
tags = tags ,
1326
1326
network_config = network_config ,
1327
1327
)
1328
+ # This subclass uses the "code" input for actual payload and the ScriptProcessor parent's
1329
+ # functionality for uploading just a small entrypoint script to invoke it.
1330
+ self ._CODE_CONTAINER_INPUT_NAME = "entrypoint"
1328
1331
1329
- self .s3_prefix = s3_prefix
1332
+ self .code_location = (
1333
+ code_location [:- 1 ] if (code_location and code_location .endswith ("/" ))
1334
+ else code_location
1335
+ )
1330
1336
1331
1337
def _pre_init_normalization (
1332
1338
self ,
@@ -1474,12 +1480,26 @@ def run( # type: ignore[override]
1474
1480
)
1475
1481
1476
1482
# Upload the bootstrapping code as s3://.../jobname/source/runproc.sh.
1477
- s3_runproc_sh = S3Uploader .upload_string_as_file_body (
1478
- self .runproc_sh .format (entry_point = entry_point ),
1479
- desired_s3_uri = f"{ self .s3_prefix } /{ job_name } /source/runproc.sh" ,
1480
- sagemaker_session = self .sagemaker_session ,
1481
- )
1482
- logger .info ("runproc.sh uploaded to %s" , s3_runproc_sh )
1483
+ local_code = get_config_value ("local.local_code" , self .sagemaker_session .config )
1484
+ if self .sagemaker_session .local_mode and local_code :
1485
+ # TODO: Can we be more prescriptive about how to not trigger this error?
1486
+ # How can the user (or we) force a local-mode `Estimator` to run with `local_code=False`?
1487
+ raise RuntimeError (
1488
+ "Local *code* is not currently supported for SageMaker Processing in Local Mode"
1489
+ )
1490
+ else :
1491
+ # estimator
1492
+ entrypoint_s3_uri = estimator .uploaded_code .s3_prefix .replace (
1493
+ "sourcedir.tar.gz" ,
1494
+ "runproc.sh" ,
1495
+ )
1496
+ script = estimator .uploaded_code .script_name
1497
+ s3_runproc_sh = S3Uploader .upload_string_as_file_body (
1498
+ self .runproc_sh .format (entry_point = script ),
1499
+ desired_s3_uri = entrypoint_s3_uri ,
1500
+ sagemaker_session = self .sagemaker_session ,
1501
+ )
1502
+ logger .info ("runproc.sh uploaded to %s" , s3_runproc_sh )
1483
1503
1484
1504
# Submit a processing job.
1485
1505
super ().run (
@@ -1512,7 +1532,7 @@ def _upload_payload(
1512
1532
git_config = git_config ,
1513
1533
framework_version = self .framework_version ,
1514
1534
py_version = self .py_version ,
1515
- code_location = self .s3_prefix , # Upload to <code_loc>/jobname/output/source.tar.gz
1535
+ code_location = self .code_location , # Upload to <code_loc>/jobname/source/sourcedir.tar.gz
1516
1536
enable_network_isolation = False , # If true, uploads to input channel. Not what we want!
1517
1537
image_uri = self .image_uri , # The image uri is already normalized by this point.
1518
1538
role = self .role ,
@@ -1550,6 +1570,10 @@ def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput
1550
1570
if inputs is None :
1551
1571
inputs = []
1552
1572
inputs .append (
1553
- ProcessingInput (source = s3_payload , destination = "/opt/ml/processing/input/code/payload/" )
1573
+ ProcessingInput (
1574
+ input_name = "code" ,
1575
+ source = s3_payload ,
1576
+ destination = "/opt/ml/processing/input/code/" ,
1577
+ )
1554
1578
)
1555
1579
return inputs
0 commit comments