
Commit f33a8bc

make test more robust
1 parent 14a89eb commit f33a8bc

File tree

2 files changed: +118 -1 lines changed

tests/data/_repack_model.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Repack model script for training jobs to inject entry points"""
from __future__ import absolute_import

import argparse
import os
import shutil
import tarfile
import tempfile

# Repack Model
# The following script is run via a training job which takes an existing model and a custom
# entry point script as arguments. The script creates a new model archive with the custom
# entry point in the "code" directory along with the existing model. Subsequently, when the model
# is unpacked for inference, the custom entry point will be used.
# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html

# distutils.dir_util.copy_tree works way better than the half-baked
# shutil.copytree which bombs on previously existing target dirs...
# alas ... https://bugs.python.org/issue10948
# we'll go ahead and use the copy_tree function anyways because this
# repacking is some short-lived hackery, right??
from distutils.dir_util import copy_tree


def repack(inference_script, model_archive, dependencies=None, source_dir=None):  # pragma: no cover
    """Repack custom dependencies and code into an existing model TAR archive

    Args:
        inference_script (str): The path to the custom entry point.
        model_archive (str): The name or path (e.g. s3 uri) of the model TAR archive.
        dependencies (str): A space-delimited string of paths to custom dependencies.
        source_dir (str): The path to a custom source directory.
    """

    # the data directory contains a model archive generated by a previous training job
    data_directory = "/opt/ml/input/data/training"
    model_path = os.path.join(data_directory, model_archive.split("/")[-1])

    # create a temporary directory
    with tempfile.TemporaryDirectory() as tmp:
        local_path = os.path.join(tmp, "local.tar.gz")
        # copy the previous training job's model archive to the temporary directory
        shutil.copy2(model_path, local_path)
        src_dir = os.path.join(tmp, "src")
        # create the "code" directory which will contain the inference script
        code_dir = os.path.join(src_dir, "code")
        os.makedirs(code_dir)
        # extract the contents of the previous training job's model archive to the "src"
        # directory of this training job
        with tarfile.open(name=local_path, mode="r:gz") as tf:
            tf.extractall(path=src_dir)

        if source_dir:
            # copy /opt/ml/code to code/
            if os.path.exists(code_dir):
                shutil.rmtree(code_dir)
            shutil.copytree("/opt/ml/code", code_dir)
        else:
            # copy the custom inference script to code/
            entry_point = os.path.join("/opt/ml/code", inference_script)
            shutil.copy2(entry_point, os.path.join(code_dir, inference_script))

        # copy any dependencies to code/lib/
        if dependencies:
            for dependency in dependencies.split(" "):
                actual_dependency_path = os.path.join("/opt/ml/code", dependency)
                lib_dir = os.path.join(code_dir, "lib")
                if not os.path.exists(lib_dir):
                    os.mkdir(lib_dir)
                if os.path.isfile(actual_dependency_path):
                    shutil.copy2(actual_dependency_path, lib_dir)
                else:
                    if os.path.exists(lib_dir):
                        shutil.rmtree(lib_dir)
                    # a directory is in the dependencies. we have to copy
                    # all of /opt/ml/code into the lib dir because the original directory
                    # was flattened by the SDK training job upload..
                    shutil.copytree("/opt/ml/code", lib_dir)
                    break

        # copy the "src" dir, which includes the previous training job's model and the
        # custom inference script, to the output of this training job
        copy_tree(src_dir, "/opt/ml/model")


if __name__ == "__main__":  # pragma: no cover
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_script", type=str, default="inference.py")
    parser.add_argument("--dependencies", type=str, default=None)
    parser.add_argument("--source_dir", type=str, default=None)
    parser.add_argument("--model_archive", type=str, default="model.tar.gz")
    args, extra = parser.parse_known_args()
    repack(
        inference_script=args.inference_script,
        dependencies=args.dependencies,
        source_dir=args.source_dir,
        model_archive=args.model_archive,
    )
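
For orientation, here is a minimal sketch of exercising repack() locally. It assumes a writable /opt/ml tree staged to mimic the SageMaker training container and that _repack_model.py is importable from the working directory; the file names and archive contents are hypothetical, not part of the diff.

# Minimal local smoke test for repack() (hypothetical setup, not part of
# the diff); it needs a writable /opt/ml tree mimicking the container.
import os
import tarfile

from _repack_model import repack

os.makedirs("/opt/ml/input/data/training", exist_ok=True)
os.makedirs("/opt/ml/code", exist_ok=True)
os.makedirs("/opt/ml/model", exist_ok=True)

# stage a stand-in entry point and a dummy "previous job" model archive
with open("/opt/ml/code/inference.py", "w") as f:
    f.write("def model_fn(model_dir):\n    return None\n")
with open("/tmp/weights.bin", "wb") as f:
    f.write(b"\x00")
with tarfile.open("/opt/ml/input/data/training/model.tar.gz", "w:gz") as tf:
    tf.add("/tmp/weights.bin", arcname="weights.bin")

repack(inference_script="inference.py", model_archive="model.tar.gz")
# /opt/ml/model now holds weights.bin plus code/inference.py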

tests/unit/sagemaker/workflow/test_model_step.py

Lines changed: 8 additions & 1 deletion
@@ -873,4 +873,11 @@ def test_model_step_with_lambda_property_reference(pipeline_session):
         steps=[lambda_step, step_create_model],
         sagemaker_session=pipeline_session,
     )
-    assert pipeline.definition() is not None
+    steps = json.loads(pipeline.definition())["Steps"]
+    repack_step = steps[1]
+    assert repack_step['Arguments']['InputDataConfig'][0]['DataSource']['S3DataSource']['S3Uri'] == \
+        {'Get': "Steps.MyLambda.OutputParameters['model_artifact']"}
+    register_step = steps[2]
+    assert register_step['Arguments']['PrimaryContainer']['Image'] == \
+        {'Get': "Steps.MyLambda.OutputParameters['model_image']"}
+
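
The tightened assertions check that the Lambda step's outputs survive serialization as unresolved property references rather than the definition merely being non-None. As a hedged illustration (step names mirror the test; `pipeline` is assumed to be the Pipeline object the test builds), such references appear in the definition JSON as {"Get": ...} expressions:

# A sketch of inspecting the serialized definition for "Get" references;
# `pipeline` is assumed to be the Pipeline built in the test above.
import json

steps = json.loads(pipeline.definition())["Steps"]
for step in steps:
    print(step["Name"])  # e.g. MyLambda, then the repack and register steps

# an unresolved reference serializes as a {"Get": <property path>} expression:
# {"Get": "Steps.MyLambda.OutputParameters['model_artifact']"}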
