Skip to content

Commit 7dc43dc

Browse files
committed
change(processing): refactor s3_prefix & payload
Swap the mandatory 's3_prefix' param in FrameworkProcessors for an optional 'code_location' param, in keeping with Framework estimators. Separate the 'code' and 'entrypoint' processing inputs to avoid having one input channel nested inside another in the container (input/payload). Update existing unit & integration tests to work with the new code.
1 parent 34f95bd commit 7dc43dc

File tree

8 files changed

+101
-67
lines changed

8 files changed

+101
-67
lines changed

src/sagemaker/mxnet/processing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ class MXNetProcessor(FrameworkProcessor):
2929
def __init__(
3030
self,
3131
framework_version, # New arg
32-
s3_prefix, # New arg
3332
role,
3433
instance_count,
3534
instance_type,
@@ -38,6 +37,7 @@ def __init__(
3837
volume_size_in_gb=30,
3938
volume_kms_key=None,
4039
output_kms_key=None,
40+
code_location=None, # New arg
4141
max_runtime_in_seconds=None,
4242
base_job_name=None,
4343
sagemaker_session=None,
@@ -61,7 +61,6 @@ def __init__(
6161
super().__init__(
6262
self.estimator_cls,
6363
framework_version,
64-
s3_prefix,
6564
role,
6665
instance_count,
6766
instance_type,
@@ -70,6 +69,7 @@ def __init__(
7069
volume_size_in_gb,
7170
volume_kms_key,
7271
output_kms_key,
72+
code_location,
7373
max_runtime_in_seconds,
7474
base_job_name,
7575
sagemaker_session,

src/sagemaker/processing.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from sagemaker import s3
3131
from sagemaker.job import _Job
3232
from sagemaker.local import LocalSession
33-
from sagemaker.utils import base_name_from_image, name_from_base
33+
from sagemaker.utils import base_name_from_image, get_config_value, name_from_base
3434
from sagemaker.session import Session
3535
from sagemaker.network import NetworkConfig # noqa: F401 # pylint: disable=unused-import
3636
from sagemaker.workflow.properties import Properties
@@ -1220,7 +1220,7 @@ class FrameworkProcessor(ScriptProcessor):
12201220
runproc_sh = """#!/bin/bash
12211221
12221222
cd /opt/ml/processing/input/code/
1223-
tar -xzf payload/sourcedir.tar.gz
1223+
tar -xzf sourcedir.tar.gz
12241224
12251225
# Exit on any error. SageMaker uses error code to mark failed job.
12261226
set -e
@@ -1235,7 +1235,6 @@ def __init__(
12351235
self,
12361236
estimator_cls, # New arg
12371237
framework_version, # New arg
1238-
s3_prefix, # New arg
12391238
role,
12401239
instance_count,
12411240
instance_type,
@@ -1244,6 +1243,7 @@ def __init__(
12441243
volume_size_in_gb=30,
12451244
volume_kms_key=None,
12461245
output_kms_key=None,
1246+
code_location=None, # New arg
12471247
max_runtime_in_seconds=None,
12481248
base_job_name=None,
12491249
sagemaker_session=None,
@@ -1262,10 +1262,6 @@ def __init__(
12621262
estimator
12631263
framework_version (str): The version of the framework. Value is ignored when
12641264
``image_uri`` is provided.
1265-
s3_prefix (str): The S3 prefix URI where custom code will be
1266-
uploaded - don't include a trailing slash since a string prepended
1267-
with a "/" is appended to ``s3_prefix``. The code file uploaded to S3
1268-
is 's3_prefix/job-name/source/sourcedir.tar.gz'.
12691265
role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing uses
12701266
this role to access AWS resources, such as data stored in Amazon S3.
12711267
instance_count (int): The number of instances to run a processing job with.
@@ -1280,6 +1276,10 @@ def __init__(
12801276
to use for storing data during processing (default: 30).
12811277
volume_kms_key (str): A KMS key for the processing volume (default: None).
12821278
output_kms_key (str): The KMS key ID for processing job outputs (default: None).
1279+
code_location (str): The S3 prefix URI where custom code will be
1280+
uploaded (default: None). The code file uploaded to S3 is
1281+
'code_location/job-name/source/sourcedir.tar.gz'. If not specified, the
1282+
default ``code_location`` is 's3://{sagemaker-default-bucket}'.
12831283
max_runtime_in_seconds (int): Timeout in seconds (default: None).
12841284
After this amount of time, Amazon SageMaker terminates the job,
12851285
regardless of its current status. If `max_runtime_in_seconds` is not
@@ -1325,8 +1325,14 @@ def __init__(
13251325
tags=tags,
13261326
network_config=network_config,
13271327
)
1328+
# This subclass uses the "code" input for actual payload and the ScriptProcessor parent's
1329+
# functionality for uploading just a small entrypoint script to invoke it.
1330+
self._CODE_CONTAINER_INPUT_NAME = "entrypoint"
13281331

1329-
self.s3_prefix = s3_prefix
1332+
self.code_location = (
1333+
code_location[:-1] if (code_location and code_location.endswith("/"))
1334+
else code_location
1335+
)
13301336

13311337
def _pre_init_normalization(
13321338
self,
@@ -1474,12 +1480,26 @@ def run( # type: ignore[override]
14741480
)
14751481

14761482
# Upload the bootstrapping code as s3://.../jobname/source/runproc.sh.
1477-
s3_runproc_sh = S3Uploader.upload_string_as_file_body(
1478-
self.runproc_sh.format(entry_point=entry_point),
1479-
desired_s3_uri=f"{self.s3_prefix}/{job_name}/source/runproc.sh",
1480-
sagemaker_session=self.sagemaker_session,
1481-
)
1482-
logger.info("runproc.sh uploaded to %s", s3_runproc_sh)
1483+
local_code = get_config_value("local.local_code", self.sagemaker_session.config)
1484+
if self.sagemaker_session.local_mode and local_code:
1485+
# TODO: Can we be more prescriptive about how to not trigger this error?
1486+
# How can user or us force a local mode `Estimator` to run with `local_code=False`?
1487+
raise RuntimeError(
1488+
"Local *code* is not currently supported for SageMaker Processing in Local Mode"
1489+
)
1490+
else:
1491+
# estimator
1492+
entrypoint_s3_uri = estimator.uploaded_code.s3_prefix.replace(
1493+
"sourcedir.tar.gz",
1494+
"runproc.sh",
1495+
)
1496+
script = estimator.uploaded_code.script_name
1497+
s3_runproc_sh = S3Uploader.upload_string_as_file_body(
1498+
self.runproc_sh.format(entry_point=script),
1499+
desired_s3_uri=entrypoint_s3_uri,
1500+
sagemaker_session=self.sagemaker_session,
1501+
)
1502+
logger.info("runproc.sh uploaded to %s", s3_runproc_sh)
14831503

14841504
# Submit a processing job.
14851505
super().run(
@@ -1512,7 +1532,7 @@ def _upload_payload(
15121532
git_config=git_config,
15131533
framework_version=self.framework_version,
15141534
py_version=self.py_version,
1515-
code_location=self.s3_prefix, # Upload to <code_loc>/jobname/output/source.tar.gz
1535+
code_location=self.code_location, # Upload to <code_loc>/jobname/output/source.tar.gz
15161536
enable_network_isolation=False, # If true, uploads to input channel. Not what we want!
15171537
image_uri=self.image_uri, # The image uri is already normalized by this point.
15181538
role=self.role,
@@ -1550,6 +1570,10 @@ def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput
15501570
if inputs is None:
15511571
inputs = []
15521572
inputs.append(
1553-
ProcessingInput(source=s3_payload, destination="/opt/ml/processing/input/code/payload/")
1573+
ProcessingInput(
1574+
input_name="code",
1575+
source=s3_payload,
1576+
destination="/opt/ml/processing/input/code/",
1577+
)
15541578
)
15551579
return inputs

src/sagemaker/pytorch/processing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ class PyTorchProcessor(FrameworkProcessor):
2929
def __init__(
3030
self,
3131
framework_version, # New arg
32-
s3_prefix, # New arg
3332
role,
3433
instance_count,
3534
instance_type,
@@ -38,6 +37,7 @@ def __init__(
3837
volume_size_in_gb=30,
3938
volume_kms_key=None,
4039
output_kms_key=None,
40+
code_location=None, # New arg
4141
max_runtime_in_seconds=None,
4242
base_job_name=None,
4343
sagemaker_session=None,
@@ -61,7 +61,6 @@ def __init__(
6161
super().__init__(
6262
self.estimator_cls,
6363
framework_version,
64-
s3_prefix,
6564
role,
6665
instance_count,
6766
instance_type,
@@ -70,6 +69,7 @@ def __init__(
7069
volume_size_in_gb,
7170
volume_kms_key,
7271
output_kms_key,
72+
code_location,
7373
max_runtime_in_seconds,
7474
base_job_name,
7575
sagemaker_session,

src/sagemaker/sklearn/processing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ class SKLearnProcessor(FrameworkProcessor):
4343
def __init__(
4444
self,
4545
framework_version, # New arg
46-
s3_prefix, # New arg
4746
role,
4847
instance_count,
4948
instance_type,
@@ -52,6 +51,7 @@ def __init__(
5251
volume_size_in_gb=30,
5352
volume_kms_key=None,
5453
output_kms_key=None,
54+
code_location=None, # New arg
5555
max_runtime_in_seconds=None,
5656
base_job_name=None,
5757
sagemaker_session=None,
@@ -63,7 +63,6 @@ def __init__(
6363
super().__init__(
6464
self.estimator_cls,
6565
framework_version,
66-
s3_prefix,
6766
role,
6867
instance_count,
6968
instance_type,
@@ -72,6 +71,7 @@ def __init__(
7271
volume_size_in_gb,
7372
volume_kms_key,
7473
output_kms_key,
74+
code_location,
7575
max_runtime_in_seconds,
7676
base_job_name,
7777
sagemaker_session,

src/sagemaker/tensorflow/processing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ class TensorFlowProcessor(FrameworkProcessor):
2929
def __init__(
3030
self,
3131
framework_version, # New arg
32-
s3_prefix, # New arg
3332
role,
3433
instance_count,
3534
instance_type,
@@ -38,6 +37,7 @@ def __init__(
3837
volume_size_in_gb=30,
3938
volume_kms_key=None,
4039
output_kms_key=None,
40+
code_location=None, # New arg
4141
max_runtime_in_seconds=None,
4242
base_job_name=None,
4343
sagemaker_session=None,
@@ -61,7 +61,6 @@ def __init__(
6161
super().__init__(
6262
self.estimator_cls,
6363
framework_version,
64-
s3_prefix,
6564
role,
6665
instance_count,
6766
instance_type,
@@ -70,6 +69,7 @@ def __init__(
7069
volume_size_in_gb,
7170
volume_kms_key,
7271
output_kms_key,
72+
code_location,
7373
max_runtime_in_seconds,
7474
base_job_name,
7575
sagemaker_session,

src/sagemaker/xgboost/processing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ class XGBoostEstimator(FrameworkProcessor):
2929
def __init__(
3030
self,
3131
framework_version, # New arg
32-
s3_prefix, # New arg
3332
role,
3433
instance_count,
3534
instance_type,
@@ -38,6 +37,7 @@ def __init__(
3837
volume_size_in_gb=30,
3938
volume_kms_key=None,
4039
output_kms_key=None,
40+
code_location=None, # New arg
4141
max_runtime_in_seconds=None,
4242
base_job_name=None,
4343
sagemaker_session=None,
@@ -61,7 +61,6 @@ def __init__(
6161
super().__init__(
6262
self.estimator_cls,
6363
framework_version,
64-
s3_prefix,
6564
role,
6665
instance_count,
6766
instance_type,
@@ -70,6 +69,7 @@ def __init__(
7069
volume_size_in_gb,
7170
volume_kms_key,
7271
output_kms_key,
72+
code_location,
7373
max_runtime_in_seconds,
7474
base_job_name,
7575
sagemaker_session,

0 commit comments

Comments
 (0)