
Commit 9e70866

Merge branch 'support-estimator-output-param'

2 parents 8c52f1b + e315711

4 files changed: +173 -11 lines changed

src/sagemaker/estimator.py

Lines changed: 11 additions & 3 deletions

@@ -695,14 +695,19 @@ def _stage_user_code_in_s3(self) -> str:
 
         Returns: S3 URI
         """
-        local_mode = self.output_path.startswith("file://")
+        local_mode = not is_pipeline_variable(self.output_path) and self.output_path.startswith(
+            "file://"
+        )
 
         if self.code_location is None and local_mode:
             code_bucket = self.sagemaker_session.default_bucket()
             code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
             kms_key = None
         elif self.code_location is None:
-            code_bucket, _ = parse_s3_url(self.output_path)
+            if is_pipeline_variable(self.output_path):
+                code_bucket = self.sagemaker_session.default_bucket()
+            else:
+                code_bucket, _ = parse_s3_url(self.output_path)
             code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
             kms_key = self.output_kms_key
         elif local_mode:
@@ -713,7 +718,10 @@ def _stage_user_code_in_s3(self) -> str:
             code_bucket, key_prefix = parse_s3_url(self.code_location)
             code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"]))
 
-            output_bucket, _ = parse_s3_url(self.output_path)
+            if is_pipeline_variable(self.output_path):
+                output_bucket = self.sagemaker_session.default_bucket()
+            else:
+                output_bucket, _ = parse_s3_url(self.output_path)
             kms_key = self.output_kms_key if code_bucket == output_bucket else None
 
         return tar_and_upload_dir(
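
The guard above matters because a pipeline parameter is a placeholder object, not a string, so string operations on output_path cannot run at pipeline-definition time. A minimal sketch of the distinction (not part of the commit; the parameter name is illustrative):

    from sagemaker.workflow import is_pipeline_variable
    from sagemaker.workflow.parameters import ParameterString

    output_path = ParameterString(name="OutputPath")

    # String operations such as output_path.startswith("file://") or
    # parse_s3_url(output_path) would fail or misbehave here, because the
    # value is only resolved by the Pipelines service at execution time.
    assert is_pipeline_variable(output_path)

    # What the SDK serializes into the pipeline definition instead:
    print(output_path.expr)  # {"Get": "Parameters.OutputPath"}

This is why code staging falls back to sagemaker_session.default_bucket() whenever output_path is a pipeline variable.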

tests/data/_repack_model.py

Lines changed: 110 additions & 0 deletions (new file)

@@ -0,0 +1,110 @@

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Repack model script for training jobs to inject entry points"""
from __future__ import absolute_import

import argparse
import os
import shutil
import tarfile
import tempfile

# Repack Model
# The following script is run via a training job which takes an existing model and a custom
# entry point script as arguments. The script creates a new model archive with the custom
# entry point in the "code" directory along with the existing model. Subsequently, when the model
# is unpacked for inference, the custom entry point will be used.
# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html

# distutils.dir_util.copy_tree works way better than the half-baked
# shutil.copytree which bombs on previously existing target dirs...
# alas ... https://bugs.python.org/issue10948
# we'll go ahead and use the copy_tree function anyways because this
# repacking is some short-lived hackery, right??
from distutils.dir_util import copy_tree


def repack(inference_script, model_archive, dependencies=None, source_dir=None):  # pragma: no cover
    """Repack custom dependencies and code into an existing model TAR archive

    Args:
        inference_script (str): The path to the custom entry point.
        model_archive (str): The name or path (e.g. s3 uri) of the model TAR archive.
        dependencies (str): A space-delimited string of paths to custom dependencies.
        source_dir (str): The path to a custom source directory.
    """

    # the data directory contains a model archive generated by a previous training job
    data_directory = "/opt/ml/input/data/training"
    model_path = os.path.join(data_directory, model_archive.split("/")[-1])

    # create a temporary directory
    with tempfile.TemporaryDirectory() as tmp:
        local_path = os.path.join(tmp, "local.tar.gz")
        # copy the previous training job's model archive to the temporary directory
        shutil.copy2(model_path, local_path)
        src_dir = os.path.join(tmp, "src")
        # create the "code" directory which will contain the inference script
        code_dir = os.path.join(src_dir, "code")
        os.makedirs(code_dir)
        # extract the contents of the previous training job's model archive to the "src"
        # directory of this training job
        with tarfile.open(name=local_path, mode="r:gz") as tf:
            tf.extractall(path=src_dir)

        if source_dir:
            # copy /opt/ml/code to code/
            if os.path.exists(code_dir):
                shutil.rmtree(code_dir)
            shutil.copytree("/opt/ml/code", code_dir)
        else:
            # copy the custom inference script to code/
            entry_point = os.path.join("/opt/ml/code", inference_script)
            shutil.copy2(entry_point, os.path.join(code_dir, inference_script))

        # copy any dependencies to code/lib/
        if dependencies:
            for dependency in dependencies.split(" "):
                actual_dependency_path = os.path.join("/opt/ml/code", dependency)
                lib_dir = os.path.join(code_dir, "lib")
                if not os.path.exists(lib_dir):
                    os.mkdir(lib_dir)
                if os.path.isfile(actual_dependency_path):
                    shutil.copy2(actual_dependency_path, lib_dir)
                else:
                    if os.path.exists(lib_dir):
                        shutil.rmtree(lib_dir)
                    # a directory is in the dependencies. we have to copy
                    # all of /opt/ml/code into the lib dir because the original directory
                    # was flattened by the SDK training job upload..
                    shutil.copytree("/opt/ml/code", lib_dir)
                    break

        # copy the "src" dir, which includes the previous training job's model and the
        # custom inference script, to the output of this training job
        copy_tree(src_dir, "/opt/ml/model")


if __name__ == "__main__":  # pragma: no cover
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_script", type=str, default="inference.py")
    parser.add_argument("--dependencies", type=str, default=None)
    parser.add_argument("--source_dir", type=str, default=None)
    parser.add_argument("--model_archive", type=str, default="model.tar.gz")
    args, extra = parser.parse_known_args()
    repack(
        inference_script=args.inference_script,
        dependencies=args.dependencies,
        source_dir=args.source_dir,
        model_archive=args.model_archive,
    )
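
For orientation, a small sketch (not part of the commit) of what repack() leaves under /opt/ml/model, which SageMaker then re-tars as the training job's model artifact:

    import os

    # Walk the output directory after repack() has run inside the container.
    for root, _, files in os.walk("/opt/ml/model"):
        for name in files:
            print(os.path.join(root, name))

    # Expected layout, given the copy operations above:
    #   /opt/ml/model/<model files from the previous training job>
    #   /opt/ml/model/code/<the inference script, or the whole source_dir>
    #   /opt/ml/model/code/lib/<dependencies, when provided>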

tests/integ/sagemaker/workflow/test_training_steps.py

Lines changed: 5 additions & 1 deletion

@@ -60,6 +60,9 @@ def test_training_job_with_debugger_and_profiler(
 ):
     instance_count = ParameterInteger(name="InstanceCount", default_value=1)
     instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    output_path = ParameterString(
+        name="OutputPath", default_value=f"s3://{sagemaker_session.default_bucket()}/test/"
+    )
 
     rules = [
         Rule.sagemaker(rule_configs.vanishing_gradient()),
@@ -88,6 +91,7 @@ def test_training_job_with_debugger_and_profiler(
         sagemaker_session=sagemaker_session,
         rules=rules,
         debugger_hook_config=debugger_hook_config,
+        output_path=output_path,
     )
 
     step_train = TrainingStep(
@@ -98,7 +102,7 @@ def test_training_job_with_debugger_and_profiler(
 
     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[instance_count, instance_type],
+        parameters=[instance_count, instance_type, output_path],
         steps=[step_train],
         sagemaker_session=sagemaker_session,
     )
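
With OutputPath defined as a pipeline parameter, the integration test (and any caller) can redirect training output per execution instead of baking the S3 URI into the pipeline definition. A minimal sketch (not part of the diff; the bucket name is made up):

    # Uses the parameter's default_value unless overridden at start time.
    execution = pipeline.start(
        parameters={"OutputPath": "s3://my-other-bucket/debugger-test/"}
    )
    execution.wait()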

tests/unit/sagemaker/workflow/test_training_step.py

Lines changed: 47 additions & 7 deletions

@@ -13,13 +13,21 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
+import os
 import json
 
 import pytest
 import sagemaker
 import warnings
 
 from sagemaker.workflow.pipeline_context import PipelineSession
+from sagemaker.workflow.parameters import ParameterString
 
 from sagemaker.workflow.steps import TrainingStep
 from sagemaker.workflow.pipeline import Pipeline
@@ -46,12 +54,14 @@
 from sagemaker.amazon.ntm import NTM
 from sagemaker.amazon.object2vec import Object2Vec
 
+from tests.integ import DATA_DIR
 
 from sagemaker.inputs import TrainingInput
 
 REGION = "us-west-2"
 IMAGE_URI = "fakeimage"
 MODEL_NAME = "gisele"
+DUMMY_LOCAL_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
 DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
 DUMMY_S3_SOURCE_DIR = "s3://dummy-s3-source-dir/"
 INSTANCE_TYPE = "ml.m4.xlarge"
@@ -119,6 +129,36 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
     assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
 
 
+def test_estimator_with_parameterized_output(pipeline_session, training_input):
+    output_path = ParameterString(name="OutputPath")
+    estimator = XGBoost(
+        framework_version="1.3-1",
+        py_version="py3",
+        role=sagemaker.get_execution_role(),
+        instance_type=INSTANCE_TYPE,
+        instance_count=1,
+        entry_point=DUMMY_LOCAL_SCRIPT_PATH,
+        output_path=output_path,
+        sagemaker_session=pipeline_session,
+    )
+    step_args = estimator.fit(inputs=training_input)
+    step = TrainingStep(
+        name="MyTrainingStep",
+        step_args=step_args,
+        description="TrainingStep description",
+        display_name="MyTrainingStep",
+    )
+    pipeline = Pipeline(
+        name="MyPipeline",
+        steps=[step],
+        sagemaker_session=pipeline_session,
+    )
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+    assert step_def["Arguments"]["OutputDataConfig"]["S3OutputPath"] == {
+        "Get": "Parameters.OutputPath"
+    }
+
+
 @pytest.mark.parametrize(
     "estimator",
     [
@@ -128,23 +168,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
             instance_type=INSTANCE_TYPE,
             instance_count=1,
             role=sagemaker.get_execution_role(),
-            entry_point="entry_point.py",
+            entry_point=DUMMY_LOCAL_SCRIPT_PATH,
         ),
         PyTorch(
             role=sagemaker.get_execution_role(),
             instance_type=INSTANCE_TYPE,
             instance_count=1,
             framework_version="1.8.0",
             py_version="py36",
-            entry_point="entry_point.py",
+            entry_point=DUMMY_LOCAL_SCRIPT_PATH,
         ),
         TensorFlow(
             role=sagemaker.get_execution_role(),
             instance_type=INSTANCE_TYPE,
             instance_count=1,
             framework_version="2.0",
             py_version="py3",
-            entry_point="entry_point.py",
+            entry_point=DUMMY_LOCAL_SCRIPT_PATH,
         ),
         HuggingFace(
             transformers_version="4.6",
@@ -153,23 +193,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
             instance_type="ml.p3.2xlarge",
             instance_count=1,
             py_version="py36",
-            entry_point="entry_point.py",
+            entry_point=DUMMY_LOCAL_SCRIPT_PATH,
         ),
         XGBoost(
            framework_version="1.3-1",
            py_version="py3",
            role=sagemaker.get_execution_role(),
            instance_type=INSTANCE_TYPE,
            instance_count=1,
-           entry_point="entry_point.py",
+           entry_point=DUMMY_LOCAL_SCRIPT_PATH,
         ),
         MXNet(
             framework_version="1.4.1",
             py_version="py3",
             role=sagemaker.get_execution_role(),
             instance_type=INSTANCE_TYPE,
             instance_count=1,
-            entry_point="entry_point.py",
+            entry_point=DUMMY_LOCAL_SCRIPT_PATH,
         ),
         RLEstimator(
             entry_point="cartpole.py",
@@ -182,7 +222,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
         ),
         Chainer(
             role=sagemaker.get_execution_role(),
-            entry_point="entry_point.py",
+            entry_point=DUMMY_LOCAL_SCRIPT_PATH,
             use_mpi=True,
             num_processes=4,
             framework_version="5.0.0",
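
The new unit test pins down the contract on both sides of the change: the rendered training step carries a deferred Get expression rather than a literal S3 URI, and user code is staged to the session's default bucket because no bucket can be parsed out of the parameter at definition time. A condensed sketch (not from the diff) of that contract:

    from sagemaker.workflow.parameters import ParameterString

    output_path = ParameterString(name="OutputPath")

    # The parameter serializes to a reference that the Pipelines service
    # resolves at execution time; the SDK never sees a concrete URI here.
    assert output_path.expr == {"Get": "Parameters.OutputPath"}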
