Skip to content

Commit ca8490a

Browse files
authored
Merge pull request aws#6 from athewsey/feat/fw-processor
Restore 'command' for FrameworkProcessors
2 parents 7aeffc5 + 51464dd commit ca8490a

File tree

7 files changed

+81
-18
lines changed

7 files changed

+81
-18
lines changed

src/sagemaker/mxnet/processing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
instance_type,
3535
py_version="py3", # New kwarg
3636
image_uri=None,
37+
command=["python"],
3738
volume_size_in_gb=30,
3839
volume_kms_key=None,
3940
output_kms_key=None,
@@ -66,6 +67,7 @@ def __init__(
6667
instance_type,
6768
py_version,
6869
image_uri,
70+
command,
6971
volume_size_in_gb,
7072
volume_kms_key,
7173
output_kms_key,

src/sagemaker/processing.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import os
2222
import pathlib
2323
import logging
24+
from textwrap import dedent
2425
from typing import Dict, List, Optional, Tuple
2526
import attr
2627

@@ -1217,18 +1218,7 @@ class FeatureStoreOutput(ApiObject):
12171218
class FrameworkProcessor(ScriptProcessor):
12181219
"""Handles Amazon SageMaker processing tasks for jobs using a machine learning framework."""
12191220

1220-
runproc_sh = """#!/bin/bash
1221-
1222-
cd /opt/ml/processing/input/code/
1223-
tar -xzf sourcedir.tar.gz
1224-
1225-
# Exit on any error. SageMaker uses error code to mark failed job.
1226-
set -e
1227-
1228-
[[ -f 'requirements.txt' ]] && pip install -r requirements.txt
1229-
1230-
python {entry_point} "$@"
1231-
"""
1221+
framework_entrypoint_command = ["/bin/bash"]
12321222

12331223
# Added new (kw)args for estimator. The rest are from ScriptProcessor with same defaults.
12341224
def __init__(
@@ -1240,6 +1230,7 @@ def __init__(
12401230
instance_type,
12411231
py_version="py3", # New kwarg
12421232
image_uri=None,
1233+
command=["python"],
12431234
volume_size_in_gb=30,
12441235
volume_kms_key=None,
12451236
output_kms_key=None,
@@ -1272,6 +1263,8 @@ def __init__(
12721263
is ignored when ``image_uri`` is provided.
12731264
image_uri (str): The URI of the Docker image to use for the
12741265
processing jobs (default: None).
1266+
command ([str]): The command to run, along with any command-line flags
1267+
to *precede* the ``entry_point`` script (default: ['python']).
12751268
volume_size_in_gb (int): Size in GB of the EBS volume
12761269
to use for storing data during processing (default: 30).
12771270
volume_kms_key (str): A KMS key for the processing volume (default: None).
@@ -1312,7 +1305,7 @@ def __init__(
13121305
super().__init__(
13131306
role=role,
13141307
image_uri=image_uri,
1315-
command=["/bin/bash"],
1308+
command=command,
13161309
instance_count=instance_count,
13171310
instance_type=instance_type,
13181311
volume_size_in_gb=volume_size_in_gb,
@@ -1493,7 +1486,7 @@ def run( # type: ignore[override]
14931486
)
14941487
script = estimator.uploaded_code.script_name
14951488
s3_runproc_sh = S3Uploader.upload_string_as_file_body(
1496-
self.runproc_sh.format(entry_point=script),
1489+
self._generate_framework_script(script),
14971490
desired_s3_uri=entrypoint_s3_uri,
14981491
sagemaker_session=self.sagemaker_session,
14991492
)
@@ -1512,6 +1505,35 @@ def run( # type: ignore[override]
15121505
kms_key=kms_key,
15131506
)
15141507

1508+
def _generate_framework_script(self, user_script: str) -> str:
1509+
"""Generate the framework entrypoint file (as text) for a processing job.
1510+
1511+
This script implements the "framework" functionality for setting up your code:
1512+
Untar-ing the sourcedir bundle in the ``code`` input; installing extra
1513+
runtime dependencies if specified; and then invoking the ``command`` and
1514+
``entry_point`` configured for the job.
1515+
1516+
Args:
1517+
user_script (str): Relative path to ``entry_point`` in the source bundle
1518+
- e.g. 'process.py'.
1519+
"""
1520+
return dedent("""\
1521+
#!/bin/bash
1522+
1523+
cd /opt/ml/processing/input/code/
1524+
tar -xzf sourcedir.tar.gz
1525+
1526+
# Exit on any error. SageMaker uses error code to mark failed job.
1527+
set -e
1528+
1529+
[[ -f 'requirements.txt' ]] && pip install -r requirements.txt
1530+
1531+
{entry_point_command} {entry_point} "$@"
1532+
""").format(
1533+
entry_point_command=" ".join(self.command),
1534+
entry_point=user_script,
1535+
)
1536+
15151537
def _upload_payload(
15161538
self,
15171539
entry_point: str,
@@ -1575,3 +1597,18 @@ def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput
15751597
)
15761598
)
15771599
return inputs
1600+
1601+
def _set_entrypoint(self, command, user_script_name):
1602+
"""FrameworkProcessor override for setting processing job entrypoint.
1603+
1604+
Args:
1605+
command ([str]): Ignored in favor of self.framework_entrypoint_command
1606+
user_script_name (str): A filename with an extension.
1607+
"""
1608+
1609+
user_script_location = str(
1610+
pathlib.PurePosixPath(
1611+
self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME, user_script_name
1612+
)
1613+
)
1614+
self.entrypoint = self.framework_entrypoint_command + [user_script_location]

src/sagemaker/pytorch/processing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
instance_type,
3535
py_version="py3", # New kwarg
3636
image_uri=None,
37+
command=["python"],
3738
volume_size_in_gb=30,
3839
volume_kms_key=None,
3940
output_kms_key=None,
@@ -66,6 +67,7 @@ def __init__(
6667
instance_type,
6768
py_version,
6869
image_uri,
70+
command,
6971
volume_size_in_gb,
7072
volume_kms_key,
7173
output_kms_key,

src/sagemaker/sklearn/processing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
instance_type,
4949
py_version="py3", # New kwarg
5050
image_uri=None,
51+
command=["python"],
5152
volume_size_in_gb=30,
5253
volume_kms_key=None,
5354
output_kms_key=None,
@@ -68,6 +69,7 @@ def __init__(
6869
instance_type,
6970
py_version,
7071
image_uri,
72+
command,
7173
volume_size_in_gb,
7274
volume_kms_key,
7375
output_kms_key,

src/sagemaker/tensorflow/processing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
instance_type,
3535
py_version="py3", # New kwarg
3636
image_uri=None,
37+
command=["python"],
3738
volume_size_in_gb=30,
3839
volume_kms_key=None,
3940
output_kms_key=None,
@@ -66,6 +67,7 @@ def __init__(
6667
instance_type,
6768
py_version,
6869
image_uri,
70+
command,
6971
volume_size_in_gb,
7072
volume_kms_key,
7173
output_kms_key,

src/sagemaker/xgboost/processing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
instance_type,
3535
py_version="py3", # New kwarg
3636
image_uri=None,
37+
command=["python"],
3738
volume_size_in_gb=30,
3839
volume_kms_key=None,
3940
output_kms_key=None,
@@ -66,6 +67,7 @@ def __init__(
6667
instance_type,
6768
py_version,
6869
image_uri,
70+
command,
6971
volume_size_in_gb,
7072
volume_kms_key,
7173
output_kms_key,

tests/unit/test_processing.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def test_sklearn_with_all_parameters(
106106
processor = SKLearnProcessor(
107107
role=ROLE,
108108
framework_version=sklearn_version,
109+
command=["Rscript"],
109110
instance_type="ml.m4.xlarge",
110111
instance_count=1,
111112
volume_size_in_gb=100,
@@ -153,12 +154,14 @@ def test_sklearn_with_all_parameters_via_run_args(
153154
exists_mock, isfile_mock, botocore_resolver, sklearn_version, sagemaker_session
154155
):
155156
botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}
157+
custom_command = ["Rscript"]
156158

157159
processor = SKLearnProcessor(
158160
role=ROLE,
159161
framework_version=sklearn_version,
162+
command=custom_command,
160163
instance_type="ml.m4.xlarge",
161-
instance_count=1,
164+
instance_count=2,
162165
volume_size_in_gb=100,
163166
volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
164167
output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
@@ -195,14 +198,27 @@ def test_sklearn_with_all_parameters_via_run_args(
195198
experiment_config={"ExperimentName": "AnExperiment"},
196199
)
197200

198-
expected_args = _get_expected_args_all_parameters_modular_code(processor._current_job_name)
201+
expected_args = _get_expected_args_all_parameters_modular_code(
202+
processor._current_job_name,
203+
instance_count=2,
204+
)
199205
sklearn_image_uri = (
200206
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-py3"
201207
).format(sklearn_version)
202208
expected_args["app_specification"]["ImageUri"] = sklearn_image_uri
203209

204210
sagemaker_session.process.assert_called_with(**expected_args)
205211

212+
# Verify the alternate command was applied successfully:
213+
framework_script = processor._generate_framework_script("processing_code.py")
214+
expected_invocation = f"{' '.join(custom_command)} processing_code.py"
215+
assert f"\n{expected_invocation}" in framework_script, (
216+
"Framework script should contain customized invocation:\n{}\n\nGot:\n{}".format(
217+
expected_invocation,
218+
framework_script,
219+
)
220+
)
221+
206222

207223
@patch("sagemaker.utils._botocore_resolver")
208224
@patch("os.path.exists", return_value=True)
@@ -811,7 +827,7 @@ def _get_data_outputs_all_parameters():
811827
]
812828

813829

814-
def _get_expected_args_all_parameters_modular_code(job_name, code_s3_uri=MOCKED_S3_URI):
830+
def _get_expected_args_all_parameters_modular_code(job_name, code_s3_uri=MOCKED_S3_URI, instance_count=1):
815831
# Add something to inputs
816832
return {
817833
"inputs": [
@@ -927,7 +943,7 @@ def _get_expected_args_all_parameters_modular_code(job_name, code_s3_uri=MOCKED_
927943
"resources": {
928944
"ClusterConfig": {
929945
"InstanceType": "ml.m4.xlarge",
930-
"InstanceCount": 1,
946+
"InstanceCount": instance_count,
931947
"VolumeSizeInGB": 100,
932948
"VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
933949
}

0 commit comments

Comments
 (0)