Skip to content

Commit c35c8b2

Browse files
authored
Merge branch 'master' into change/add_data_wrangler_processor
2 parents fdd66c2 + 544ccac commit c35c8b2

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

src/sagemaker/workflow/_repack_model.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@
1919
import tarfile
2020
import tempfile
2121

22+
# Repack Model
23+
# The following script is run via a training job which takes an existing model and a custom
24+
# entry point script as arguments. The script creates a new model archive with the custom
25+
# entry point in the "code" directory along with the existing model. Subsequently, when the model
26+
# is unpacked for inference, the custom entry point will be used.
27+
# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html
28+
2229
# distutils.dir_util.copy_tree works way better than the half-baked
2330
# shutil.copytree which bombs on previously existing target dirs...
2431
# alas ... https://bugs.python.org/issue10948
@@ -33,17 +40,28 @@
3340
parser.add_argument("--model_archive", type=str, default="model.tar.gz")
3441
args = parser.parse_args()
3542

43+
# the data directory contains a model archive generated by a previous training job
3644
data_directory = "/opt/ml/input/data/training"
3745
model_path = os.path.join(data_directory, args.model_archive)
3846

47+
# create a temporary directory
3948
with tempfile.TemporaryDirectory() as tmp:
4049
local_path = os.path.join(tmp, "local.tar.gz")
50+
# copy the previous training job's model archive to the temporary directory
4151
shutil.copy2(model_path, local_path)
4252
src_dir = os.path.join(tmp, "src")
53+
# create the "code" directory which will contain the inference script
54+
os.makedirs(os.path.join(src_dir, "code"))
55+
# extract the contents of the previous training job's model archive to the "src"
56+
# directory of this training job
4357
with tarfile.open(name=local_path, mode="r:gz") as tf:
4458
tf.extractall(path=src_dir)
4559

60+
# generate a path to the custom inference script
4661
entry_point = os.path.join("/opt/ml/code", args.inference_script)
47-
shutil.copy2(entry_point, os.path.join(src_dir, args.inference_script))
62+
# copy the custom inference script to the "code" dir inside the "src" dir
63+
shutil.copy2(entry_point, os.path.join(src_dir, "code", args.inference_script))
4864

65+
# copy the "src" dir, which includes the previous training job's model and the
66+
# custom inference script, to the output of this training job
4967
copy_tree(src_dir, "/opt/ml/model")

0 commit comments

Comments
 (0)