-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: add git_config and git_clone, validate method #832
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 21 commits
10d27c5
db8652c
6b78ed4
e59bb79
7808faa
2783c4a
db3b69f
f397850
5b8d684
a9e2932
241ac92
a81859a
c39c344
068a7b1
2b1622b
28a5c58
0797060
e2e5c20
c6daa5d
e8bede0
e5bd806
c1bae10
2af9b24
e15a22d
b102563
9ae910e
3383bfc
9a7f4e1
d4bb0bb
e6a01f0
0c5e32b
b6e75d0
3621bd4
c7af978
c2f7a43
0790f41
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
3.6.8 | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import os | ||
import subprocess | ||
import tempfile | ||
|
||
|
||
def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None): | ||
"""Git clone repo containing the training code and serving code. This method also validate ``git_config``, | ||
and set ``entry_point``, ``source_dir`` and ``dependencies`` to the right file or directory in the repo cloned. | ||
|
||
Args: | ||
git_config (dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch`` | ||
and ``commit``. ``branch`` and ``commit`` are optional. If ``branch`` is not specified, master branch | ||
will be used. If ``commit`` is not specified, the latest commit in the required branch will be used. | ||
entry_point (str): A relative location to the Python source file which should be executed as the entry point | ||
to training or model hosting in the Git repo. | ||
source_dir (str): A relative location to a directory with other training or model hosting source code | ||
dependencies aside from the entry point file in the Git repo (default: None). Structure within this | ||
directory are preserved when training on Amazon SageMaker. | ||
dependencies (list[str]): A list of relative locations to directories with any additional libraries that will | ||
be exported to the container in the Git repo (default: []). | ||
|
||
Raises: | ||
CalledProcessError: If 1. failed to clone git repo | ||
2. failed to checkout the required branch | ||
3. failed to checkout the required commit | ||
ValueError: If 1. entry point specified does not exist in the repo | ||
2. source dir specified does not exist in the repo | ||
|
||
Returns: | ||
dict: A dict that contains the updated values of entry_point, source_dir and dependencies | ||
""" | ||
_validate_git_config(git_config) | ||
repo_dir = tempfile.mkdtemp() | ||
subprocess.check_call(['git', 'clone', git_config['repo'], repo_dir]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's do a shallow clone of the repo - if the git repository is large (e.g. is managing large binaries, has a lot of branches), then the performance of doing a full clone will be unnecessarily slow for what's needed here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. talked offline - this can happen in a subsequent PR |
||
|
||
_checkout_branch_and_commit(git_config, repo_dir) | ||
|
||
ret = {'entry_point': entry_point, 'source_dir': source_dir, 'dependencies': dependencies} | ||
# check if the cloned repo contains entry point, source directory and dependencies | ||
if source_dir: | ||
if not os.path.isdir(os.path.join(repo_dir, source_dir)): | ||
raise ValueError('Source directory does not exist in the repo.') | ||
if not os.path.isfile(os.path.join(repo_dir, source_dir, entry_point)): | ||
raise ValueError('Entry point does not exist in the repo.') | ||
ret['source_dir'] = os.path.join(repo_dir, source_dir) | ||
else: | ||
if not os.path.isfile(os.path.join(repo_dir, entry_point)): | ||
raise ValueError('Entry point does not exist in the repo.') | ||
ret['entry_point'] = os.path.join(repo_dir, entry_point) | ||
|
||
ret['dependencies'] = [] | ||
for path in dependencies: | ||
if not os.path.exists(os.path.join(repo_dir, path)): | ||
raise ValueError('Dependency {} does not exist in the repo.'.format(path)) | ||
ret['dependencies'].append(os.path.join(repo_dir, path)) | ||
return ret | ||
|
||
|
||
def _validate_git_config(git_config): | ||
"""check if a git_config param is valid | ||
|
||
Args: | ||
git_config ((dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch`` | ||
and ``commit``. | ||
|
||
Raises: | ||
ValueError: If: | ||
1. git_config has no key 'repo' | ||
2. git_config['repo'] is in the wrong format. | ||
""" | ||
if 'repo' not in git_config: | ||
raise ValueError('Please provide a repo for git_config.') | ||
|
||
|
||
def _checkout_branch_and_commit(git_config, repo_dir): | ||
"""Checkout the required branch and commit. | ||
|
||
Args: | ||
git_config: (dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch`` | ||
and ``commit``. | ||
repo_dir (str): the directory where the repo is cloned | ||
|
||
Raises: | ||
ValueError: If 1. entry point specified does not exist in the repo | ||
2. source dir specified does not exist in the repo | ||
""" | ||
if 'branch' in git_config: | ||
subprocess.check_call(args=['git', 'checkout', git_config['branch']], cwd=str(repo_dir)) | ||
if 'commit' in git_config: | ||
subprocess.check_call(args=['git', 'checkout', git_config['commit']], cwd=str(repo_dir)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import os | ||
|
||
import numpy | ||
|
||
from sagemaker.mxnet.estimator import MXNet | ||
from sagemaker.pytorch.estimator import PyTorch | ||
from tests.integ import DATA_DIR, PYTHON_VERSION | ||
|
||
GIT_REPO = 'https://github.com/aws/sagemaker-python-sdk.git' | ||
BRANCH = 'test-branch-git-config' | ||
COMMIT = '329bfcf884482002c05ff7f44f62599ebc9f445a' | ||
|
||
|
||
def test_git_support_with_pytorch(sagemaker_local_session): | ||
GaryTu1020 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
script_path = 'mnist.py' | ||
data_path = os.path.join(DATA_DIR, 'pytorch_mnist') | ||
git_config = {'repo': GIT_REPO, 'branch': BRANCH, 'commit': COMMIT} | ||
pytorch = PyTorch(entry_point=script_path, role='SageMakerRole', source_dir='pytorch', | ||
framework_version=PyTorch.LATEST_VERSION, py_version=PYTHON_VERSION, | ||
train_instance_count=1, train_instance_type='local', | ||
sagemaker_session=sagemaker_local_session, git_config=git_config) | ||
|
||
pytorch.fit({'training': 'file://' + os.path.join(data_path, 'training')}) | ||
|
||
try: | ||
GaryTu1020 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
predictor = pytorch.deploy(initial_instance_count=1, instance_type='local') | ||
|
||
data = numpy.zeros(shape=(1, 1, 28, 28)).astype(numpy.float32) | ||
result = predictor.predict(data) | ||
assert result is not None | ||
finally: | ||
predictor.delete_endpoint() | ||
|
||
|
||
def test_git_support_with_mxnet(sagemaker_local_session, mxnet_full_version): | ||
script_path = 'mnist.py' | ||
data_path = os.path.join(DATA_DIR, 'mxnet_mnist') | ||
git_config = {'repo': GIT_REPO, 'branch': BRANCH, 'commit': COMMIT} | ||
dependencies = ['foo/bar.py'] | ||
mx = MXNet(entry_point=script_path, role='SageMakerRole', | ||
source_dir='mxnet', dependencies=dependencies, | ||
framework_version=MXNet.LATEST_VERSION, py_version=PYTHON_VERSION, | ||
train_instance_count=1, train_instance_type='local', | ||
sagemaker_session=sagemaker_local_session, git_config=git_config) | ||
|
||
mx.fit({'train': 'file://' + os.path.join(data_path, 'train'), | ||
'test': 'file://' + os.path.join(data_path, 'test')}) | ||
|
||
files = [file for file in os.listdir(mx.source_dir)] | ||
assert 'some_file' in files | ||
assert 'mnist.py' in files | ||
assert os.path.exists(mx.dependencies[0]) | ||
|
||
try: | ||
predictor = mx.deploy(initial_instance_count=1, instance_type='local') | ||
|
||
data = numpy.zeros(shape=(1, 1, 28, 28)) | ||
result = predictor.predict(data) | ||
assert result is not None | ||
finally: | ||
predictor.delete_endpoint() | ||
GaryTu1020 marked this conversation as resolved.
Show resolved
Hide resolved
|
Uh oh!
There was an error while loading. Please reload this page.