Skip to content

Commit 5b16b4f

Browse files
author
Payton Staub
committed
fix: make the _repack_model script used in pipelines support the dependencies and source_dir arguments, which script-mode models rely on
1 parent e3c4a59 commit 5b16b4f

File tree

3 files changed

+226
-13
lines changed

3 files changed

+226
-13
lines changed

src/sagemaker/workflow/_repack_model.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,11 @@
3333
# repacking is some short-lived hackery, right??
3434
from distutils.dir_util import copy_tree
3535

36-
37-
if __name__ == "__main__":
38-
parser = argparse.ArgumentParser()
39-
parser.add_argument("--inference_script", type=str, default="inference.py")
40-
parser.add_argument("--model_archive", type=str, default="model.tar.gz")
41-
args = parser.parse_args()
36+
def repack(inference_script, model_archive, dependencies=None, source_dir=None):
4237

4338
# the data directory contains a model archive generated by a previous training job
4439
data_directory = "/opt/ml/input/data/training"
45-
model_path = os.path.join(data_directory, args.model_archive)
40+
model_path = os.path.join(data_directory, model_archive)
4641

4742
# create a temporary directory
4843
with tempfile.TemporaryDirectory() as tmp:
@@ -51,17 +46,55 @@
5146
shutil.copy2(model_path, local_path)
5247
src_dir = os.path.join(tmp, "src")
5348
# create the "code" directory which will contain the inference script
54-
os.makedirs(os.path.join(src_dir, "code"))
49+
code_dir = os.path.join(src_dir, "code")
50+
os.makedirs(code_dir)
5551
# extract the contents of the previous training job's model archive to the "src"
5652
# directory of this training job
5753
with tarfile.open(name=local_path, mode="r:gz") as tf:
5854
tf.extractall(path=src_dir)
5955

60-
# generate a path to the custom inference script
61-
entry_point = os.path.join("/opt/ml/code", args.inference_script)
62-
# copy the custom inference script to the "src" dir
63-
shutil.copy2(entry_point, os.path.join(src_dir, "code", args.inference_script))
56+
# copy the custom inference script to code/
57+
entry_point = os.path.join("/opt/ml/code", inference_script)
58+
shutil.copy2(entry_point, os.path.join(src_dir, "code", inference_script))
59+
60+
# copy source_dir to code/
61+
if source_dir:
62+
actual_source_dir_path = os.path.join("/opt/ml/code", source_dir)
63+
if os.path.exists(actual_source_dir_path):
64+
for item in os.listdir(actual_source_dir_path):
65+
s = os.path.join(actual_source_dir_path, item)
66+
d = os.path.join(os.path.join(code_dir, source_dir), item)
67+
if os.path.isdir(s):
68+
shutil.copytree(s, d)
69+
else:
70+
shutil.copy2(s, d)
71+
72+
# copy any dependencies to code/lib/
73+
if dependencies:
74+
for dependency in dependencies.split(' '):
75+
actual_dependency_path = os.path.join("/opt/ml/code", dependency)
76+
lib_dir = os.path.join(code_dir, "lib")
77+
if not os.path.exists(lib_dir):
78+
os.mkdir(lib_dir)
79+
if os.path.isdir(actual_dependency_path):
80+
shutil.copytree(actual_dependency_path, os.path.join(lib_dir, os.path.basename(actual_dependency_path)))
81+
else:
82+
shutil.copy2(actual_dependency_path, lib_dir)
6483

6584
# copy the "src" dir, which includes the previous training job's model and the
6685
# custom inference script, to the output of this training job
67-
copy_tree(src_dir, "/opt/ml/model")
86+
copy_tree(src_dir, "/opt/ml/model")
87+
88+
if __name__ == "__main__":
    # Command-line entry point: every option is a string flag with a default,
    # so the flags are declared from a single table.
    cli = argparse.ArgumentParser()
    for flag, fallback in (
        ("--inference_script", "inference.py"),
        ("--dependencies", None),
        ("--source_dir", None),
        ("--model_archive", "model.tar.gz"),
    ):
        cli.add_argument(flag, type=str, default=fallback)

    # parse_known_args() tolerates flags this script does not declare; the
    # unrecognized remainder is intentionally ignored.
    known, _unrecognized = cli.parse_known_args()

    repack(
        inference_script=known.inference_script,
        model_archive=known.model_archive,
        dependencies=known.dependencies,
        source_dir=known.source_dir,
    )

src/sagemaker/workflow/_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,11 @@ def __init__(
145145
self._source_dir = source_dir
146146
self._dependencies = dependencies
147147

148+
# convert dependencies array into space-delimited string
149+
dependencies_hyperparameter = None
150+
if self._dependencies:
151+
dependencies_hyperparameter = ' '.join(self._dependencies)
152+
148153
# the real estimator and inputs
149154
repacker = SKLearn(
150155
framework_version=FRAMEWORK_VERSION,
@@ -157,6 +162,8 @@ def __init__(
157162
hyperparameters={
158163
"inference_script": self._entry_point_basename,
159164
"model_archive": self._model_archive,
165+
"dependencies": dependencies_hyperparameter,
166+
"source_dir": self._source_dir
160167
},
161168
subnets=subnets,
162169
security_group_ids=security_group_ids,
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
# language governing permissions and limitations under the License.
14+
from __future__ import absolute_import
15+
from sagemaker.workflow import _repack_model
16+
17+
from pathlib import Path
18+
import shutil
19+
import tarfile
20+
import os
21+
import pytest
22+
import time
23+
24+
@pytest.mark.skip(reason="""This test operates on the root file system
and will likely fail due to permission errors.
Temporarily remove this skip decorator and run
the test after making changes to _repack_model.py""")
def test_repack_entry_point_only(tmp):
    """Repack a model with only an inference script and check the output layout.

    Stages a fake model archive in /opt/ml/input/data/training and an
    inference script in /opt/ml/code, runs ``_repack_model.repack`` and
    asserts that /opt/ml/model holds both the model and code/inference.py.

    Args:
        tmp: Path of a temporary directory (string) supplied by the ``tmp`` fixture.
    """
    model_name = "xg-boost-model"
    fake_model_path = os.path.join(tmp, model_name)

    # create an empty fake model; the context manager closes the handle
    # instead of leaking it
    with open(fake_model_path, "w"):
        pass

    # create model.tar.gz (time-stamped so repeated runs use distinct names)
    model_tar_name = "model-%s.tar.gz" % time.time()
    model_tar_location = os.path.join(tmp, model_tar_name)
    with tarfile.open(model_tar_location, mode="w:gz") as t:
        t.add(fake_model_path, arcname=model_name)

    # move model.tar.gz to /opt/ml/input/data/training
    Path("/opt/ml/input/data/training").mkdir(parents=True, exist_ok=True)
    shutil.move(model_tar_location, os.path.join("/opt/ml/input/data/training", model_tar_name))

    # create files that will be added to model.tar.gz
    create_file_tree(
        "/opt/ml/code",
        [
            "inference.py",
        ],
    )

    # repack
    _repack_model.repack(
        inference_script="inference.py",
        model_archive=model_tar_name,
    )

    # /opt/ml/model should now have the original model and the inference script
    assert os.path.exists(os.path.join("/opt/ml/model", model_name))
    assert os.path.exists(os.path.join("/opt/ml/model/code", "inference.py"))
62+
63+
@pytest.mark.skip(reason="""This test operates on the root file system
and will likely fail due to permission errors.
Temporarily remove this skip decorator and run
the test after making changes to _repack_model.py""")
def test_repack_with_dependencies(tmp):
    """Repack a model with dependencies and check they land in code/lib.

    Stages a fake model archive in /opt/ml/input/data/training, an inference
    script plus dependency files/directories in /opt/ml/code, runs
    ``_repack_model.repack`` and asserts the resulting /opt/ml/model layout.

    Args:
        tmp: Path of a temporary directory (string) supplied by the ``tmp`` fixture.
    """
    model_name = "xg-boost-model"
    fake_model_path = os.path.join(tmp, model_name)

    # create an empty fake model; the context manager closes the handle
    # instead of leaking it
    with open(fake_model_path, "w"):
        pass

    # create model.tar.gz (time-stamped so repeated runs use distinct names)
    model_tar_name = "model-%s.tar.gz" % time.time()
    model_tar_location = os.path.join(tmp, model_tar_name)
    with tarfile.open(model_tar_location, mode="w:gz") as t:
        t.add(fake_model_path, arcname=model_name)

    # move model.tar.gz to /opt/ml/input/data/training
    Path("/opt/ml/input/data/training").mkdir(parents=True, exist_ok=True)
    shutil.move(model_tar_location, os.path.join("/opt/ml/input/data/training", model_tar_name))

    # create files that will be added to model.tar.gz
    create_file_tree(
        "/opt/ml/code",
        [
            "inference.py",
            "dependencies/a",
            "bb",
            "dependencies/some/dir/b",
        ],
    )

    # repack; repack() calls dependencies.split(' '), so the dependencies must
    # be passed as one space-delimited string rather than a list
    _repack_model.repack(
        inference_script="inference.py",
        model_archive=model_tar_name,
        dependencies="dependencies/a bb dependencies/some/dir",
    )

    # /opt/ml/model should now have the original model and the inference script
    assert os.path.exists(os.path.join("/opt/ml/model", model_name))
    assert os.path.exists(os.path.join("/opt/ml/model/code", "inference.py"))
    # file dependencies are copied by basename, directories as lib/<basename>
    assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "a"))
    assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "bb"))
    assert os.path.exists(os.path.join("/opt/ml/model/code/lib/dir", "b"))
108+
109+
@pytest.mark.skip(reason="""This test operates on the root file system
and will likely fail due to permission errors.
Temporarily remove this skip decorator and run
the test after making changes to _repack_model.py""")
def test_repack_with_source_dir_and_dependencies(tmp):
    """Repack a model with both a source_dir and dependencies.

    Stages a fake model archive in /opt/ml/input/data/training, an inference
    script, dependencies and a source directory in /opt/ml/code, runs
    ``_repack_model.repack`` and asserts the resulting /opt/ml/model layout.

    Args:
        tmp: Path of a temporary directory (string) supplied by the ``tmp`` fixture.
    """
    model_name = "xg-boost-model"
    fake_model_path = os.path.join(tmp, model_name)

    # create an empty fake model; the context manager closes the handle
    # instead of leaking it
    with open(fake_model_path, "w"):
        pass

    # create model.tar.gz (time-stamped so repeated runs use distinct names)
    model_tar_name = "model-%s.tar.gz" % time.time()
    model_tar_location = os.path.join(tmp, model_tar_name)
    with tarfile.open(model_tar_location, mode="w:gz") as t:
        t.add(fake_model_path, arcname=model_name)

    # move model.tar.gz to /opt/ml/input/data/training
    Path("/opt/ml/input/data/training").mkdir(parents=True, exist_ok=True)
    shutil.move(model_tar_location, os.path.join("/opt/ml/input/data/training", model_tar_name))

    # create files that will be added to model.tar.gz
    create_file_tree(
        "/opt/ml/code",
        [
            "inference.py",
            "dependencies/a",
            "bb",
            "dependencies/some/dir/b",
            "sourcedir/foo.py",
            "sourcedir/some/dir/a",
        ],
    )

    # repack; repack() calls dependencies.split(' '), so the dependencies must
    # be passed as one space-delimited string rather than a list
    # NOTE(review): repack() appears to copy plain files into
    # code/<source_dir>/ without creating that directory first, so the
    # sourcedir assertions below may fail depending on os.listdir order --
    # confirm against _repack_model.py.
    _repack_model.repack(
        inference_script="inference.py",
        model_archive=model_tar_name,
        dependencies="dependencies/a bb dependencies/some/dir",
        source_dir="sourcedir",
    )

    # /opt/ml/model should now have the original model and the inference script
    assert os.path.exists(os.path.join("/opt/ml/model", model_name))
    assert os.path.exists(os.path.join("/opt/ml/model/code", "inference.py"))
    # file dependencies are copied by basename, directories as lib/<basename>
    assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "a"))
    assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "bb"))
    assert os.path.exists(os.path.join("/opt/ml/model/code/lib/dir", "b"))
    # the source directory is expected under code/<source_dir>
    assert os.path.exists(os.path.join("/opt/ml/model/code/sourcedir", "foo.py"))
    assert os.path.exists(os.path.join("/opt/ml/model/code/sourcedir/some/dir", "a"))
159+
160+
161+
def create_file_tree(root, tree):
    """Create files under ``root``, each containing its own relative path.

    Args:
        root: Directory under which the files are created.
        tree: Iterable of file paths relative to ``root``. Missing parent
            directories are created as needed.
    """
    for relative_path in tree:
        # exist_ok=True tolerates a parent directory created by an earlier
        # entry without hiding real failures (e.g. permission errors), which
        # the previous bare ``except: pass`` silently swallowed.
        os.makedirs(os.path.join(root, os.path.dirname(relative_path)), exist_ok=True)
        # append mode: listing a path again adds to the file instead of
        # truncating it
        with open(os.path.join(root, relative_path), "a") as f:
            f.write(relative_path)
169+
170+
171+
@pytest.fixture()
def tmp(tmpdir):
    """Yield the ``tmpdir`` fixture value converted to a plain string."""
    tmp_path_str = str(tmpdir)
    yield tmp_path_str

0 commit comments

Comments
 (0)