feature: Support for TFS preprocessing #797
Merged
Changes from 10 commits
13 commits
ee61563 Add sagemaker.utils.repack function (mvsusp)
83a1ff0 Fix flake8 (mvsusp)
012f16a Fix flake8 (mvsusp)
c354878 Fix flake8 (mvsusp)
fde4d9f Handle PR comments (mvsusp)
3bbc3bb Fix flake8 (mvsusp)
eeec58d Handle PR comments (mvsusp)
abd0d77 Fix integ test (mvsusp)
1a96e0e Fix integ test (mvsusp)
82d5366 Fix PR comments (mvsusp)
ada8b2a Merge remote-tracking branch 'origin/master' into mvs-tfs (mvsusp)
5392854 Fix PR comments (mvsusp)
09bd185 Fix unit tests (mvsusp)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,20 +12,24 @@ | |
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import contextlib | ||
import errno | ||
import os | ||
import random | ||
import re | ||
import shutil | ||
import sys | ||
import tarfile | ||
import tempfile | ||
import time | ||
|
||
from datetime import datetime | ||
from functools import wraps | ||
from six.moves.urllib import parse | ||
|
||
import six | ||
|
||
import sagemaker | ||
|
||
ECR_URI_PATTERN = r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$' | ||
|
||
|
@@ -258,13 +262,10 @@ def download_folder(bucket_name, prefix, target, sagemaker_session): | |
|
||
def create_tar_file(source_files, target=None): | ||
"""Create a tar file containing all the source_files | ||
|
||
Args: | ||
source_files (List[str]): List of file paths that will be contained in the tar file | ||
|
||
Returns: | ||
(str): path to created tar file | ||
|
||
""" | ||
if target: | ||
filename = target | ||
|
@@ -278,6 +279,100 @@ def create_tar_file(source_files, target=None): | |
return filename | ||
|
||
|
||
@contextlib.contextmanager | ||
def _tmpdir(suffix='', prefix='tmp'): | ||
"""Create a temporary directory with a context manager. The directory is deleted when the context exits. | ||
|
||
The prefix and suffix arguments are the same as for tempfile.mkdtemp(). | ||
|
||
Args: | ||
suffix (str): If suffix is specified, the directory name will end with that suffix; otherwise, there will be no | ||
suffix. | ||
prefix (str): If prefix is specified, the directory name will begin with that prefix; otherwise, | ||
a default prefix is used. | ||
Returns: | ||
str: path to the directory | ||
""" | ||
tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=None) | ||
yield tmp | ||
shutil.rmtree(tmp) | ||
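Note on `_tmpdir`: as written, `shutil.rmtree(tmp)` runs only when the body of the `with` block finishes without raising, so an exception inside the block would leave the directory behind. A minimal sketch of an exception-safe variant, assuming the same `suffix`/`prefix` arguments; the name `tmpdir_safe` is hypothetical and not part of this PR:

```python
import contextlib
import os
import shutil
import tempfile


@contextlib.contextmanager
def tmpdir_safe(suffix='', prefix='tmp'):
    """Yield a temporary directory and remove it even if the body raises.

    Hypothetical variant of _tmpdir, shown for illustration only.
    """
    tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix)
    try:
        yield tmp
    finally:
        # Runs on normal exit and when an exception propagates out of the with block.
        shutil.rmtree(tmp, ignore_errors=True)
```

Usage is unchanged: `with tmpdir_safe() as tmp: ...`. The only difference is that cleanup no longer depends on the body succeeding.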
|
||
|
||
def repack_model(inference_script, source_directory, model_uri, sagemaker_session): | ||
"""Unpack model tarball and creates a new model tarball with the provided code script. | ||
|
||
This function does the following: | ||
- uncompresses model tarball from S3 or local system into a temp folder | ||
- replaces the inference code from the model with the new code provided | ||
- compresses the new model tarball and saves it in S3 or local file system | ||
|
||
Args: | ||
inference_script (str): path or basename of the inference script that will be packed into the model | ||
source_directory (str): path including all the files that will be packed into the model | ||
model_uri (str): S3 or file system location of the original model tar | ||
sagemaker_session (:class:`sagemaker.session.Session`): a sagemaker session to interact with S3. | ||
|
||
Returns: | ||
str: path to the new packed model | ||
""" | ||
new_model_name = 'model-%s.tar.gz' % sagemaker.utils.sagemaker_short_timestamp() | ||
|
||
with _tmpdir() as tmp: | ||
tmp_model_dir = os.path.join(tmp, 'model') | ||
os.mkdir(tmp_model_dir) | ||
|
||
model_from_s3 = model_uri.startswith('s3://') | ||
if model_from_s3: | ||
local_model_path = os.path.join(tmp, 'tar_file') | ||
download_file_from_url(model_uri, local_model_path, sagemaker_session) | ||
|
||
new_model_path = os.path.join(tmp, new_model_name) | ||
else: | ||
local_model_path = model_uri.replace('file://', '') | ||
new_model_path = os.path.join(os.path.dirname(local_model_path), new_model_name) | ||
|
||
with tarfile.open(name=local_model_path, mode='r:gz') as t: | ||
t.extractall(path=tmp_model_dir) | ||
|
||
code_dir = os.path.join(tmp_model_dir, 'code') | ||
if os.path.exists(code_dir): | ||
shutil.rmtree(code_dir, ignore_errors=True) | ||
|
||
dirname = source_directory if source_directory else os.path.dirname(inference_script) | ||
|
||
shutil.copytree(dirname, code_dir) | ||
|
||
with tarfile.open(new_model_path, mode='w:gz') as t: | ||
t.add(tmp_model_dir, arcname=os.path.sep) | ||
|
||
if model_from_s3: | ||
url = parse.urlparse(model_uri) | ||
bucket, key = url.netloc, url.path.lstrip('/') | ||
new_key = key.replace(os.path.basename(key), new_model_name) | ||
|
||
sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file(new_model_path) | ||
return 's3://%s/%s' % (bucket, new_key) | ||
else: | ||
return 'file://%s' % new_model_path | ||
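The three steps listed in the `repack_model` docstring (unpack, replace `code/`, re-compress) can be sketched for the local-file case as follows. This is a simplified illustration, not the PR's implementation: the function name `repack_local` and its parameters are hypothetical, the S3 download and upload are left out, and `arcname='.'` is an assumption chosen here so that archive members stay relative to the archive root.

```python
import os
import shutil
import tarfile
import tempfile


def repack_local(model_tar, source_dir, out_tar):
    """Rebuild model_tar as out_tar with its code/ directory replaced by source_dir.

    Hypothetical helper for illustration; local files only.
    """
    tmp = tempfile.mkdtemp()
    try:
        model_dir = os.path.join(tmp, 'model')
        os.mkdir(model_dir)

        # 1. Unpack the original model archive into a scratch directory.
        with tarfile.open(model_tar, mode='r:gz') as t:
            t.extractall(path=model_dir)

        # 2. Drop any existing code/ directory and copy the new sources in.
        code_dir = os.path.join(model_dir, 'code')
        shutil.rmtree(code_dir, ignore_errors=True)
        shutil.copytree(source_dir, code_dir)

        # 3. Re-compress; arcname='.' keeps member paths relative to the archive root.
        with tarfile.open(out_tar, mode='w:gz') as t:
            t.add(model_dir, arcname='.')
    finally:
        shutil.rmtree(tmp, ignore_errors=True)
    return out_tar
```

Everything outside `code/` in the original archive is carried over unchanged, which is the property the test fixtures under `tests/data/tfs/` rely on.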
|
||
|
||
def _list_files(script, directory): | ||
Review comment: Is this still needed? It doesn't seem to be called anywhere anymore.
Reply: Good call!
||
if directory is None: | ||
return [script] | ||
|
||
basedir = directory if directory else os.path.dirname(script) | ||
return [os.path.join(basedir, name) for name in os.listdir(basedir)] | ||
|
||
|
||
def download_file_from_url(url, dst, sagemaker_session): | ||
url = parse.urlparse(url) | ||
bucket, key = url.netloc, url.path.lstrip('/') | ||
|
||
download_file(bucket, key, dst, sagemaker_session) | ||
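Both `download_file_from_url` and the upload branch of `repack_model` split an `s3://` URI into bucket and key the same way. A small sketch of that parsing, plus one possible way to derive the new object key: `repack_model` uses `key.replace(os.path.basename(key), new_model_name)`, and `str.replace` substitutes every occurrence of the basename in the key, not only the last path component. The helper names `split_s3_uri` and `sibling_key` are hypothetical, and the sketch assumes Python 3's `urllib.parse` rather than `six.moves`.

```python
import posixpath
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Return (bucket, key) for an s3://bucket/key URI."""
    parsed = urlparse(uri)
    if parsed.scheme != 's3':
        raise ValueError('not an S3 URI: %s' % uri)
    return parsed.netloc, parsed.path.lstrip('/')


def sibling_key(key, new_name):
    """Return a key under the same prefix as key, with only the last component replaced."""
    # S3 keys always use '/', so posixpath is used instead of os.path.
    return posixpath.join(posixpath.dirname(key), new_name)
```

For a key such as `path/to/model.tar.gz` both approaches give the same result; they differ only when the file name also appears earlier in the key.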
|
||
|
||
def download_file(bucket_name, path, target, sagemaker_session): | ||
"""Download a Single File from S3 into a local path | ||
|
||
|
1 change: 1 addition & 0 deletions
tests/data/tfs/tfs-test-model-with-inference/00000123/assets/foo.txt
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
asset-file-contents |
Binary file not shown.
Binary file added
BIN
+12 Bytes
...s/data/tfs/tfs-test-model-with-inference/00000123/variables/variables.data-00000-of-00001
Binary file not shown.
Binary file added
BIN
+151 Bytes
tests/data/tfs/tfs-test-model-with-inference/00000123/variables/variables.index
Binary file not shown.
26 changes: 26 additions & 0 deletions
tests/data/tfs/tfs-test-model-with-inference/code/inference.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
import json | ||
|
||
|
||
def input_handler(data, context): | ||
data = json.loads(data.read().decode('utf-8')) | ||
new_values = [x + 1 for x in data['instances']] | ||
dumps = json.dumps({'instances': new_values}) | ||
return dumps | ||
|
||
|
||
def output_handler(data, context): | ||
response_content_type = context.accept_header | ||
prediction = data.content | ||
return prediction, response_content_type |
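To see what the test handlers in `inference.py` do, they can be exercised without a serving container. The sketch below repeats the two handlers and calls them with stand-in objects; `FakeContext` and `FakeResponse` are hypothetical and model only the attributes the handlers read (`accept_header` and `content`), so the real objects passed by the container may look different.

```python
import io
import json
from collections import namedtuple

# Hypothetical stand-ins for the objects the serving container passes in.
FakeContext = namedtuple('FakeContext', ['accept_header'])
FakeResponse = namedtuple('FakeResponse', ['content'])


def input_handler(data, context):
    # Read the raw request body, add 1 to every instance, and re-serialize it.
    data = json.loads(data.read().decode('utf-8'))
    new_values = [x + 1 for x in data['instances']]
    return json.dumps({'instances': new_values})


def output_handler(data, context):
    # Pass the model response through unchanged, using the requested content type.
    return data.content, context.accept_header


context = FakeContext(accept_header='application/json')
request_body = io.BytesIO(b'{"instances": [1.0, 2.0, 5.0]}')

transformed = input_handler(request_body, context)
print(transformed)  # {"instances": [2.0, 3.0, 6.0]}

body, content_type = output_handler(FakeResponse(content=b'{"predictions": []}'), context)
print(content_type)  # application/json
```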
shouldn't be necessary
I mentioned above why we need to list the files instead of passing the entire folder. This function also handles the case where directory is None and only the entry point should be copied into the tarball.