Skip to content

Commit 51464dd

Browse files
committed
fix: restore 'command' for FrameworkProcessors
FrameworkProcessor now accepts a 'command' param like the ScriptProcessor class did previously, and tries to maintain similar API by forwarding it to the parent class... But patching the container run command generation from the parent to override. Also add initial unit tests for multi-instance and overridden commands.
1 parent 7aeffc5 commit 51464dd

File tree

7 files changed

+81
-18
lines changed

7 files changed

+81
-18
lines changed

src/sagemaker/mxnet/processing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
instance_type,
3535
py_version="py3", # New kwarg
3636
image_uri=None,
37+
command=["python"],
3738
volume_size_in_gb=30,
3839
volume_kms_key=None,
3940
output_kms_key=None,
@@ -66,6 +67,7 @@ def __init__(
6667
instance_type,
6768
py_version,
6869
image_uri,
70+
command,
6971
volume_size_in_gb,
7072
volume_kms_key,
7173
output_kms_key,

src/sagemaker/processing.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import os
2222
import pathlib
2323
import logging
24+
from textwrap import dedent
2425
from typing import Dict, List, Optional, Tuple
2526
import attr
2627

@@ -1217,18 +1218,7 @@ class FeatureStoreOutput(ApiObject):
12171218
class FrameworkProcessor(ScriptProcessor):
12181219
"""Handles Amazon SageMaker processing tasks for jobs using a machine learning framework."""
12191220

1220-
runproc_sh = """#!/bin/bash
1221-
1222-
cd /opt/ml/processing/input/code/
1223-
tar -xzf sourcedir.tar.gz
1224-
1225-
# Exit on any error. SageMaker uses error code to mark failed job.
1226-
set -e
1227-
1228-
[[ -f 'requirements.txt' ]] && pip install -r requirements.txt
1229-
1230-
python {entry_point} "$@"
1231-
"""
1221+
framework_entrypoint_command = ["/bin/bash"]
12321222

12331223
# Added new (kw)args for estimator. The rest are from ScriptProcessor with same defaults.
12341224
def __init__(
@@ -1240,6 +1230,7 @@ def __init__(
12401230
instance_type,
12411231
py_version="py3", # New kwarg
12421232
image_uri=None,
1233+
command=["python"],
12431234
volume_size_in_gb=30,
12441235
volume_kms_key=None,
12451236
output_kms_key=None,
@@ -1272,6 +1263,8 @@ def __init__(
12721263
is ignored when ``image_uri`` is provided.
12731264
image_uri (str): The URI of the Docker image to use for the
12741265
processing jobs (default: None).
1266+
command ([str]): The command to run, along with any command-line flags
1267+
to *precede* the ```entry_point script``` (default: ['python']).
12751268
volume_size_in_gb (int): Size in GB of the EBS volume
12761269
to use for storing data during processing (default: 30).
12771270
volume_kms_key (str): A KMS key for the processing volume (default: None).
@@ -1312,7 +1305,7 @@ def __init__(
13121305
super().__init__(
13131306
role=role,
13141307
image_uri=image_uri,
1315-
command=["/bin/bash"],
1308+
command=command,
13161309
instance_count=instance_count,
13171310
instance_type=instance_type,
13181311
volume_size_in_gb=volume_size_in_gb,
@@ -1493,7 +1486,7 @@ def run( # type: ignore[override]
14931486
)
14941487
script = estimator.uploaded_code.script_name
14951488
s3_runproc_sh = S3Uploader.upload_string_as_file_body(
1496-
self.runproc_sh.format(entry_point=script),
1489+
self._generate_framework_script(script),
14971490
desired_s3_uri=entrypoint_s3_uri,
14981491
sagemaker_session=self.sagemaker_session,
14991492
)
@@ -1512,6 +1505,35 @@ def run( # type: ignore[override]
15121505
kms_key=kms_key,
15131506
)
15141507

1508+
def _generate_framework_script(self, user_script: str) -> str:
1509+
"""Generate the framework entrypoint file (as text) for a processing job.
1510+
1511+
This script implements the "framework" functionality for setting up your code:
1512+
Untar-ing the sourcedir bundle in the ```code``` input; installing extra
1513+
runtime dependencies if specified; and then invoking the ```command``` and
1514+
```entry_point``` configured for the job.
1515+
1516+
Args:
1517+
user_script (str): Relative path to ```entry_point``` in the source bundle
1518+
- e.g. 'process.py'.
1519+
"""
1520+
return dedent("""\
1521+
#!/bin/bash
1522+
1523+
cd /opt/ml/processing/input/code/
1524+
tar -xzf sourcedir.tar.gz
1525+
1526+
# Exit on any error. SageMaker uses error code to mark failed job.
1527+
set -e
1528+
1529+
[[ -f 'requirements.txt' ]] && pip install -r requirements.txt
1530+
1531+
{entry_point_command} {entry_point} "$@"
1532+
""").format(
1533+
entry_point_command=" ".join(self.command),
1534+
entry_point=user_script,
1535+
)
1536+
15151537
def _upload_payload(
15161538
self,
15171539
entry_point: str,
@@ -1575,3 +1597,18 @@ def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput
15751597
)
15761598
)
15771599
return inputs
1600+
1601+
def _set_entrypoint(self, command, user_script_name):
1602+
"""FrameworkProcessor override for setting processing job entrypoint.
1603+
1604+
Args:
1605+
command ([str]): Ignored in favor of self.framework_entrypoint_command
1606+
user_script_name (str): A filename with an extension.
1607+
"""
1608+
1609+
user_script_location = str(
1610+
pathlib.PurePosixPath(
1611+
self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME, user_script_name
1612+
)
1613+
)
1614+
self.entrypoint = self.framework_entrypoint_command + [user_script_location]

src/sagemaker/pytorch/processing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
instance_type,
3535
py_version="py3", # New kwarg
3636
image_uri=None,
37+
command=["python"],
3738
volume_size_in_gb=30,
3839
volume_kms_key=None,
3940
output_kms_key=None,
@@ -66,6 +67,7 @@ def __init__(
6667
instance_type,
6768
py_version,
6869
image_uri,
70+
command,
6971
volume_size_in_gb,
7072
volume_kms_key,
7173
output_kms_key,

src/sagemaker/sklearn/processing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
instance_type,
4949
py_version="py3", # New kwarg
5050
image_uri=None,
51+
command=["python"],
5152
volume_size_in_gb=30,
5253
volume_kms_key=None,
5354
output_kms_key=None,
@@ -68,6 +69,7 @@ def __init__(
6869
instance_type,
6970
py_version,
7071
image_uri,
72+
command,
7173
volume_size_in_gb,
7274
volume_kms_key,
7375
output_kms_key,

src/sagemaker/tensorflow/processing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
instance_type,
3535
py_version="py3", # New kwarg
3636
image_uri=None,
37+
command=["python"],
3738
volume_size_in_gb=30,
3839
volume_kms_key=None,
3940
output_kms_key=None,
@@ -66,6 +67,7 @@ def __init__(
6667
instance_type,
6768
py_version,
6869
image_uri,
70+
command,
6971
volume_size_in_gb,
7072
volume_kms_key,
7173
output_kms_key,

src/sagemaker/xgboost/processing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
instance_type,
3535
py_version="py3", # New kwarg
3636
image_uri=None,
37+
command=["python"],
3738
volume_size_in_gb=30,
3839
volume_kms_key=None,
3940
output_kms_key=None,
@@ -66,6 +67,7 @@ def __init__(
6667
instance_type,
6768
py_version,
6869
image_uri,
70+
command,
6971
volume_size_in_gb,
7072
volume_kms_key,
7173
output_kms_key,

tests/unit/test_processing.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def test_sklearn_with_all_parameters(
106106
processor = SKLearnProcessor(
107107
role=ROLE,
108108
framework_version=sklearn_version,
109+
command=["Rscript"],
109110
instance_type="ml.m4.xlarge",
110111
instance_count=1,
111112
volume_size_in_gb=100,
@@ -153,12 +154,14 @@ def test_sklearn_with_all_parameters_via_run_args(
153154
exists_mock, isfile_mock, botocore_resolver, sklearn_version, sagemaker_session
154155
):
155156
botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}
157+
custom_command = ["Rscript"]
156158

157159
processor = SKLearnProcessor(
158160
role=ROLE,
159161
framework_version=sklearn_version,
162+
command=custom_command,
160163
instance_type="ml.m4.xlarge",
161-
instance_count=1,
164+
instance_count=2,
162165
volume_size_in_gb=100,
163166
volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
164167
output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
@@ -195,14 +198,27 @@ def test_sklearn_with_all_parameters_via_run_args(
195198
experiment_config={"ExperimentName": "AnExperiment"},
196199
)
197200

198-
expected_args = _get_expected_args_all_parameters_modular_code(processor._current_job_name)
201+
expected_args = _get_expected_args_all_parameters_modular_code(
202+
processor._current_job_name,
203+
instance_count=2,
204+
)
199205
sklearn_image_uri = (
200206
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-py3"
201207
).format(sklearn_version)
202208
expected_args["app_specification"]["ImageUri"] = sklearn_image_uri
203209

204210
sagemaker_session.process.assert_called_with(**expected_args)
205211

212+
# Verify the alternate command was applied successfully:
213+
framework_script = processor._generate_framework_script("processing_code.py")
214+
expected_invocation = f"{' '.join(custom_command)} processing_code.py"
215+
assert f"\n{expected_invocation}" in framework_script, (
216+
"Framework script should contain customized invocation:\n{}\n\nGot:\n{}".format(
217+
expected_invocation,
218+
framework_script,
219+
)
220+
)
221+
206222

207223
@patch("sagemaker.utils._botocore_resolver")
208224
@patch("os.path.exists", return_value=True)
@@ -811,7 +827,7 @@ def _get_data_outputs_all_parameters():
811827
]
812828

813829

814-
def _get_expected_args_all_parameters_modular_code(job_name, code_s3_uri=MOCKED_S3_URI):
830+
def _get_expected_args_all_parameters_modular_code(job_name, code_s3_uri=MOCKED_S3_URI, instance_count=1):
815831
# Add something to inputs
816832
return {
817833
"inputs": [
@@ -927,7 +943,7 @@ def _get_expected_args_all_parameters_modular_code(job_name, code_s3_uri=MOCKED_
927943
"resources": {
928944
"ClusterConfig": {
929945
"InstanceType": "ml.m4.xlarge",
930-
"InstanceCount": 1,
946+
"InstanceCount": instance_count,
931947
"VolumeSizeInGB": 100,
932948
"VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
933949
}

0 commit comments

Comments
 (0)