feature: add git_config and git_clone, validate method #832
Merged
36 commits
10d27c5
add git_config and validate method
db8652c
Merge branch 'master' of github.com:aws/sagemaker-python-sdk into clo…
6b78ed4
modify the order of git_config, add tests
e59bb79
move validate_git_config, add integ test
7808faa
modify location _git_clone_code called
2783c4a
add documentation
db3b69f
Merge branch 'master' of github.com:aws/sagemaker-python-sdk into clo…
f397850
Update doc/overview.rst
GaryTu1020 5b8d684
Update doc/overview.rst
GaryTu1020 a9e2932
add more integ tests
241ac92
write unit tests for git_utils
a81859a
fix conflict on overview.rst
c39c344
delete a line
068a7b1
modify an assertion in test_with_mxnet
2b1622b
add assertion to some test functions
28a5c58
remove deploy part in test_git
0797060
change testing git repo
e2e5c20
change the testing repo
c6daa5d
correct an error message
e8bede0
pull master
e5bd806
stop patching private methods
c1bae10
modified overview.rst, add lock for tests
2af9b24
slight change to overview.rst
e15a22d
Merge branch 'master' into clone_from_github
chuyang-deng b102563
add a comment for lock
9ae910e
merge with remote branch
3383bfc
Merge branch 'master' into clone_from_github
GaryTu1020 9a7f4e1
Merge branch 'master' into clone_from_github
chuyang-deng d4bb0bb
merge with master
e6a01f0
merge with master
0c5e32b
merge aws master
b6e75d0
merge with master
3621bd4
merge with master
c7af978
merge with aws master
c2f7a43
merge with aws master
0790f41
Merge branch 'master' into clone_from_github
mvsusp
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,8 @@ | |
import json | ||
import logging | ||
import os | ||
import subprocess | ||
import tempfile | ||
import warnings | ||
from abc import ABCMeta | ||
from abc import abstractmethod | ||
|
@@ -774,16 +776,54 @@ class Framework(EstimatorBase): | |
MPI_NUM_PROCESSES_PER_HOST = 'sagemaker_mpi_num_of_processes_per_host' | ||
MPI_CUSTOM_MPI_OPTIONS = 'sagemaker_mpi_custom_mpi_options' | ||
|
||
def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False, | ||
container_log_level=logging.INFO, code_location=None, image_name=None, dependencies=None, **kwargs): | ||
def __init__(self, entry_point, source_dir=None, hyperparameters=None, | ||
enable_cloudwatch_metrics=False, container_log_level=logging.INFO, code_location=None, | ||
image_name=None, dependencies=None, git_config=None, **kwargs): | ||
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()`` | ||
|
||
Args: | ||
entry_point (str): Path (absolute or relative) to the local Python source file which should be executed | ||
as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5. | ||
If 'git_config' is provided, 'entry_point' should be a relative location to the Python source file in | ||
the Git repo. | ||
Example: | ||
|
||
With the following GitHub repo directory structure: | ||
|
||
>>> |----- README.md | ||
>>> |----- src | ||
>>> |----- train.py | ||
>>> |----- test.py | ||
|
||
You can assign entry_point='src/train.py'. | ||
git_config (dict[str, str]): Git configurations used for cloning files, including 'repo', 'branch' | ||
and 'commit' (default: None). | ||
'branch' and 'commit' are optional. If 'branch' is not specified, 'master' branch will be used. If | ||
'commit' is not specified, the latest commit in the required branch will be used. | ||
Example: | ||
|
||
The following config: | ||
|
||
>>> git_config = {'repo': 'https://github.com/GaryTu1020/python-sdk-testing.git', | ||
>>> 'branch': 'master', | ||
>>> 'commit': 'aea6f3acef9619f77f94772d9d654f041e16bf49'} | ||
|
||
results in cloning the repo specified in 'repo', then checking out the 'master' branch, and then | ||
checking out the specified commit. | ||
source_dir (str): Path (absolute or relative) to a directory with any other training | ||
source code dependencies aside from the entry point file (default: None). Structure within this | ||
directory are preserved when training on Amazon SageMaker. | ||
directory are preserved when training on Amazon SageMaker. If 'git_config' is provided, | ||
source_dir should be a relative location to a directory in the Git repo. | ||
Example: | ||
|
||
With the following GitHub repo directory structure: | ||
|
||
>>> |----- README.md | ||
>>> |----- src | ||
>>> |----- train.py | ||
>>> |----- test.py | ||
|
||
You can assign entry_point='train.py', source_dir='src'. | ||
dependencies (list[str]): A list of paths to directories (absolute or relative) with | ||
any additional libraries that will be exported to the container (default: []). | ||
The library folders will be copied to SageMaker in the same folder where the entrypoint is copied. | ||
|
@@ -818,9 +858,11 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl | |
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor. | ||
""" | ||
super(Framework, self).__init__(**kwargs) | ||
|
||
if entry_point.startswith('s3://'): | ||
raise ValueError('Invalid entry point script: {}. Must be a path to a local file.'.format(entry_point)) | ||
self.entry_point = entry_point | ||
self.git_config = git_config | ||
self.source_dir = source_dir | ||
self.dependencies = dependencies or [] | ||
if enable_cloudwatch_metrics: | ||
|
@@ -833,6 +875,82 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl | |
|
||
self._hyperparameters = hyperparameters or {} | ||
|
||
def _git_clone_code(self): | ||
"""Git clone repo containing the training scripts. This method also validates ``git_config``, | ||
and sets ``entry_point`` and ``source_dir`` to the right file or directory in the cloned repo. | ||
|
||
Raises: | ||
CalledProcessError: If 1. failed to clone git repo | ||
2. failed to checkout the required branch | ||
3. failed to checkout the required commit | ||
ValueError: If 1. entry point specified does not exist in the repo | ||
2. source dir specified does not exist in the repo | ||
""" | ||
self._validate_git_config() | ||
# create a temporary directory to store the cloned repo | ||
repo_dir = tempfile.mkdtemp() | ||
try: | ||
subprocess.check_call(['git', 'clone', self.git_config['repo'], repo_dir]) | ||
except subprocess.CalledProcessError: | ||
raise subprocess.CalledProcessError(1, cmd='git clone {} {}'.format(self.git_config['repo'], repo_dir)) | ||
Review comment (truncated in source): Why do you need to create a new
|
||
self._checkout_branch_and_commit(repo_dir) | ||
|
||
# check if the cloned repo contains entry point and source dir; if so, set ``entry_point`` and | ||
# ``source_dir`` to the paths to local file system. | ||
if self.source_dir: | ||
if os.path.isdir(os.path.join(repo_dir, self.source_dir)): | ||
self.source_dir = os.path.join(repo_dir, self.source_dir) | ||
os.chdir(self.source_dir) | ||
else: | ||
raise ValueError('Source directory does not exist in the repo.') | ||
if not os.path.isfile(os.path.join(self.source_dir, self.entry_point)): | ||
raise ValueError('Entry point does not exist in the repo.') | ||
else: | ||
if not os.path.isfile(os.path.join(repo_dir, self.entry_point)): | ||
raise ValueError('Entry point does not exist in the repo.') | ||
for path in self.dependencies: | ||
if not os.path.isdir(os.path.join(repo_dir, path)): | ||
raise ValueError('Dependency {} does not exist in the repo.'.format(path)) | ||
|
||
def _checkout_branch_and_commit(self, repo_dir): | ||
"""Enter the directory where the repo is cloned, and checkout the required branch and commit. | ||
|
||
Args: | ||
repo_dir: the directory where the repo is cloned | ||
|
||
Raises: | ||
CalledProcessError: If 1. failed to checkout the required branch | ||
2. failed to checkout the required commit | ||
""" | ||
os.chdir(repo_dir) | ||
if 'branch' in self.git_config: | ||
try: | ||
subprocess.check_call(['git', 'checkout', self.git_config['branch']]) | ||
except subprocess.CalledProcessError: | ||
raise subprocess.CalledProcessError(1, cmd='git checkout {}'.format(self.git_config['branch'])) | ||
if 'commit' in self.git_config: | ||
try: | ||
subprocess.check_call(['git', 'checkout', self.git_config['commit']]) | ||
except subprocess.CalledProcessError: | ||
raise subprocess.CalledProcessError(1, cmd='git checkout {}'.format(self.git_config['commit'])) | ||
|
||
def _validate_git_config(self): | ||
"""Check if a git_config param is valid. | ||
|
||
Raises: | ||
ValueError: If: | ||
1. git_config has no key 'repo' | ||
2. git_config['repo'] is in the wrong format. | ||
""" | ||
if 'repo' not in self.git_config: | ||
raise ValueError('Please provide a repo for git_config.') | ||
repo = self.git_config['repo'] | ||
codecommit_url = repo.startswith('https://git-codecommit') or repo.startswith('ssh://git-codecommit') | ||
github_url = repo.startswith('https://github') or repo.startswith('git@github') | ||
if not codecommit_url and not github_url: | ||
raise ValueError('Please provide a valid git repo url.') | ||
|
||
def _prepare_for_training(self, job_name=None): | ||
"""Set hyperparameters needed for training. This method will also validate ``source_dir``. | ||
|
||
|
@@ -842,6 +960,9 @@ def _prepare_for_training(self, job_name=None): | |
""" | ||
super(Framework, self)._prepare_for_training(job_name=job_name) | ||
|
||
if self.git_config: | ||
self._git_clone_code() | ||
|
||
# validate source dir will raise a ValueError if there is something wrong with the | ||
# source directory. We are intentionally not handling it because this is a critical error. | ||
if self.source_dir and not self.source_dir.lower().startswith('s3://'): | ||
|
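The clone-and-checkout flow in this diff changes the process working directory with `os.chdir`. As a point of comparison, here is a minimal standalone sketch of the same flow that scopes each `git checkout` to the clone with the `cwd` argument of `subprocess.check_call` instead. It is not the code merged in this PR: the function names `validate_git_config` and `clone_and_checkout`, and the `run` parameter, are hypothetical and exist only for illustration. The accepted URL prefixes and error messages are copied from `_validate_git_config` above.

```python
import os
import subprocess
import tempfile

# URL prefixes accepted by _validate_git_config in the diff above.
_VALID_PREFIXES = ('https://git-codecommit', 'ssh://git-codecommit',
                   'https://github', 'git@github')


def validate_git_config(git_config):
    """Raise ValueError if 'repo' is missing or has an unsupported prefix."""
    if 'repo' not in git_config:
        raise ValueError('Please provide a repo for git_config.')
    # str.startswith accepts a tuple of candidate prefixes.
    if not git_config['repo'].startswith(_VALID_PREFIXES):
        raise ValueError('Please provide a valid git repo url.')


def clone_and_checkout(git_config, run=subprocess.check_call):
    """Clone the repo into a temp dir, then check out branch and commit.

    ``run`` is a hypothetical injection point so the flow can be exercised
    without git or network access; it defaults to subprocess.check_call.
    """
    validate_git_config(git_config)
    repo_dir = tempfile.mkdtemp()
    run(['git', 'clone', git_config['repo'], repo_dir])
    for key in ('branch', 'commit'):
        if key in git_config:
            # cwd= runs git inside the clone; the caller's cwd is untouched.
            run(['git', 'checkout', git_config[key]], cwd=repo_dir)
    return repo_dir
```

Because `check_call` already raises `CalledProcessError` with the real return code and command, this sketch lets that exception propagate rather than constructing a new one.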
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import os | ||
|
||
import numpy | ||
|
||
from sagemaker.pytorch.estimator import PyTorch | ||
from sagemaker.pytorch.model import PyTorchModel | ||
from sagemaker.utils import sagemaker_timestamp | ||
from tests.integ import DATA_DIR, PYTHON_VERSION, TRAINING_DEFAULT_TIMEOUT_MINUTES | ||
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name | ||
GIT_REPO = 'https://github.com/GaryTu1020/python-sdk-testing.git' | ||
BRANCH = 'branch1' | ||
COMMIT = '4893e528afa4a790331e1b5286954f073b0f14a2' | ||
|
||
|
||
def test_git_support_with_pytorch(sagemaker_local_session): | ||
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): | ||
script_path = 'mnist.py' | ||
data_path = os.path.join(DATA_DIR, 'pytorch_mnist') | ||
git_config = {'repo': GIT_REPO, 'branch': BRANCH, 'commit': COMMIT} | ||
pytorch = PyTorch(entry_point=script_path, role='SageMakerRole', source_dir='pytorch', | ||
framework_version=PyTorch.LATEST_VERSION, py_version=PYTHON_VERSION, | ||
train_instance_count=1, train_instance_type='ml.c4.xlarge', | ||
sagemaker_session=sagemaker_local_session, git_config=git_config) | ||
|
||
train_input = pytorch.sagemaker_session.upload_data(path=os.path.join(data_path, 'training'), | ||
key_prefix='integ-test-data/pytorch_mnist/training') | ||
pytorch.fit({'training': train_input}) | ||
|
||
files = sorted(os.listdir(pytorch.source_dir)) | ||
assert files == ['mnist.py', 'some-file'] | ||
|
||
endpoint_name = 'test-git_support-with-pytorch-{}'.format(sagemaker_timestamp()) | ||
|
||
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_local_session): | ||
desc = sagemaker_local_session.sagemaker_client.describe_training_job(pytorch.latest_training_job.name) | ||
model_data = desc['ModelArtifacts']['S3ModelArtifacts'] | ||
model = PyTorchModel(model_data, 'SageMakerRole', entry_point=script_path, | ||
sagemaker_session=sagemaker_local_session) | ||
predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name) | ||
|
||
data = numpy.zeros(shape=(1, 1, 28, 28)) | ||
result = predictor.predict(data) | ||
assert result is not None |
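`os.listdir` returns entries in arbitrary order, so a check on the contents of a cloned source directory is more robust when it compares sorted lists. A minimal sketch, assuming a hypothetical helper named `assert_dir_contains` that is not part of this PR or of the SDK:

```python
import os


def assert_dir_contains(directory, expected_names):
    """Assert that ``directory`` holds exactly ``expected_names``, in any order."""
    actual = sorted(os.listdir(directory))
    assert actual == sorted(expected_names), actual
```

In the test above, this would be called as `assert_dir_contains(pytorch.source_dir, ['some-file', 'mnist.py'])`.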