Skip to content

fix: Refactor repack_model script injection, fixes tar.gz error #3039

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 21 additions & 28 deletions src/sagemaker/workflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@
from typing import List, Union
from sagemaker import image_uris
from sagemaker.inputs import TrainingInput
from sagemaker.s3 import (
S3Downloader,
S3Uploader,
)
from sagemaker.estimator import EstimatorBase
from sagemaker.sklearn.estimator import SKLearn
from sagemaker.workflow.entities import RequestType
Expand All @@ -35,6 +31,7 @@
Step,
ConfigurableRetryStep,
)
from sagemaker.utils import _save_model, download_file_from_url
from sagemaker.workflow.retry import RetryPolicy

FRAMEWORK_VERSION = "0.23-1"
Expand Down Expand Up @@ -203,40 +200,36 @@ def _establish_source_dir(self):
self._entry_point = self._entry_point_basename

def _inject_repack_script(self):
"""Injects the _repack_model.py script where it belongs.
"""Injects the _repack_model.py script into S3 or local source directory.

If the source_dir is an S3 path:
1) downloads the source_dir tar.gz
2) copies the _repack_model.py script where it belongs
3) uploads the mutated source_dir
2) extracts it
3) copies the _repack_model.py script into the extracted directory
4) rezips the directory
5) overwrites the S3 source_dir with the new tar.gz

If the source_dir is a local path:
1) copies the _repack_model.py script into the source dir
"""
fname = os.path.join(os.path.dirname(__file__), REPACK_SCRIPT)
if self._source_dir.lower().startswith("s3://"):
with tempfile.TemporaryDirectory() as tmp:
local_path = os.path.join(tmp, "local.tar.gz")

S3Downloader.download(
s3_uri=self._source_dir,
local_path=local_path,
sagemaker_session=self.sagemaker_session,
)

src_dir = os.path.join(tmp, "src")
with tarfile.open(name=local_path, mode="r:gz") as tf:
tf.extractall(path=src_dir)

shutil.copy2(fname, os.path.join(src_dir, REPACK_SCRIPT))
with tarfile.open(name=local_path, mode="w:gz") as tf:
tf.add(src_dir, arcname=".")

S3Uploader.upload(
local_path=local_path,
desired_s3_uri=self._source_dir,
sagemaker_session=self.sagemaker_session,
)
targz_contents_dir = os.path.join(tmp, "extracted")

old_targz_path = os.path.join(tmp, "old.tar.gz")
download_file_from_url(self._source_dir, old_targz_path, self.sagemaker_session)

with tarfile.open(name=old_targz_path, mode="r:gz") as t:
t.extractall(path=targz_contents_dir)

shutil.copy2(fname, os.path.join(targz_contents_dir, REPACK_SCRIPT))

new_targz_path = os.path.join(tmp, "new.tar.gz")
with tarfile.open(new_targz_path, mode="w:gz") as t:
t.add(targz_contents_dir, arcname=os.path.sep)

_save_model(self._source_dir, new_targz_path, self.sagemaker_session, kms_key=None)
else:
shutil.copy2(fname, os.path.join(self._source_dir, REPACK_SCRIPT))

Expand Down
55 changes: 55 additions & 0 deletions tests/unit/sagemaker/workflow/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sagemaker.estimator import Estimator
from sagemaker.workflow import Properties
from sagemaker.workflow._utils import _RepackModelStep
from tests.unit.test_utils import FakeS3, list_tar_files
from tests.unit import DATA_DIR

REGION = "us-west-2"
Expand Down Expand Up @@ -210,3 +211,57 @@ def test_repack_model_step_with_source_dir(estimator, source_dir):
assert step.properties.TrainingJobName.expr == {
"Get": "Steps.MyRepackModelStep.TrainingJobName"
}


@pytest.fixture()
def tmp(tmpdir):
yield str(tmpdir)


@pytest.fixture()
def fake_s3(tmp):
return FakeS3(tmp)


def test_inject_repack_script_s3(estimator, tmp, fake_s3):

create_file_tree(
tmp,
[
"model-dir/aa",
"model-dir/foo/inference.py",
],
)

model_data = Properties(path="Steps.MyStep", shape_name="DescribeModelOutput")
entry_point = "inference.py"
source_dir_path = "s3://fake/location"
step = _RepackModelStep(
name="MyRepackModelStep",
sagemaker_session=fake_s3.sagemaker_session,
role=estimator.role,
image_uri="foo",
model_data=model_data,
entry_point=entry_point,
source_dir=source_dir_path,
)

fake_s3.tar_and_upload("model-dir", "s3://fake/location")

step._inject_repack_script()

assert list_tar_files(fake_s3.fake_upload_path, tmp) == {
"/aa",
"/foo/inference.py",
"/_repack_model.py",
}


def create_file_tree(root, tree):
for file in tree:
try:
os.makedirs(os.path.join(root, os.path.dirname(file)))
except: # noqa: E722 Using bare except because p2/3 incompatibility issues.
pass
with open(os.path.join(root, file), "a") as f:
f.write(file)