feature: add git_config and git_clone, validate method #832

Merged 36 commits on Jun 24, 2019
Changes from 7 commits
Commits
10d27c5
add git_config and validate method
Jun 6, 2019
db8652c
Merge branch 'master' of github.com:aws/sagemaker-python-sdk into clo…
Jun 6, 2019
6b78ed4
modify the order of git_config, add tests
Jun 6, 2019
e59bb79
move validate_git_config, add integ test
Jun 8, 2019
7808faa
modify location _git_clone_code called
Jun 10, 2019
2783c4a
add documentation
Jun 11, 2019
db3b69f
Merge branch 'master' of github.com:aws/sagemaker-python-sdk into clo…
Jun 11, 2019
f397850
Update doc/overview.rst
GaryTu1020 Jun 12, 2019
5b8d684
Update doc/overview.rst
GaryTu1020 Jun 12, 2019
a9e2932
add more integ tests
Jun 13, 2019
241ac92
write unit tests for git_utils
Jun 15, 2019
a81859a
fix conflict on overview.rst
Jun 15, 2019
c39c344
delete a line
Jun 15, 2019
068a7b1
modify an assertion in test_with_mxnet
Jun 15, 2019
2b1622b
add assertion to some test functions
Jun 17, 2019
28a5c58
remove deploy part in test_git
Jun 17, 2019
0797060
change testing git repo
Jun 17, 2019
e2e5c20
change the testing repo
Jun 17, 2019
c6daa5d
correct an error message
Jun 18, 2019
e8bede0
pull master
Jun 18, 2019
e5bd806
stop patching private methods
Jun 18, 2019
c1bae10
modified overview.rst, add lock for tests
Jun 19, 2019
2af9b24
slight change to overview.rst
Jun 19, 2019
e15a22d
Merge branch 'master' into clone_from_github
chuyang-deng Jun 19, 2019
b102563
add a comment for lock
Jun 19, 2019
9ae910e
merge with remote branch
Jun 19, 2019
3383bfc
Merge branch 'master' into clone_from_github
GaryTu1020 Jun 20, 2019
9a7f4e1
Merge branch 'master' into clone_from_github
chuyang-deng Jun 20, 2019
d4bb0bb
merge with master
Jun 21, 2019
e6a01f0
merge with master
Jun 21, 2019
0c5e32b
merge aws master
Jun 21, 2019
b6e75d0
merge with master
Jun 21, 2019
3621bd4
merge with master
Jun 21, 2019
c7af978
merge with aws master
Jun 23, 2019
c2f7a43
merge with aws master
Jun 24, 2019
0790f41
Merge branch 'master' into clone_from_github
mvsusp Jun 24, 2019
53 changes: 53 additions & 0 deletions doc/overview.rst
@@ -84,6 +84,58 @@ For more `information <https://boto3.amazonaws.com/v1/documentation/api/latest/r
# Deletes the SageMaker model
mxnet_predictor.delete_model()

Git Support
~~~~~~~~~~~
The SageMaker Python SDK supports Git. If your training scripts are in a GitHub repository, you can
use them directly without first downloading them to your local machine. To enable Git support, pass the
``git_config`` parameter when you initialize an estimator. When Git support is enabled, ``entry_point``,
``source_dir``, and ``dependencies`` must all be paths relative to the root of the Git repo. Git support works
with the framework images (TensorFlow, MXNet, Chainer, PyTorch, and Scikit-Learn). Note that everything you
specify for ``entry_point``, ``source_dir``, and ``dependencies`` must be in a single Git repo.

Here is an example:

.. code:: python

    from sagemaker.pytorch import PyTorch

    # Specifies the git_config parameter
    git_config = {'repo': 'https://github.com/GaryTu1020/python-sdk-testing.git',
                  'branch': 'branch1',
                  'commit': '4893e528afa4a790331e1b5286954f073b0f14a2'}

    # Configures a PyTorch Estimator (no training happens here).
    # 'entry_point' and 'source_dir' are relative paths inside the Git repo.
    pytorch_estimator = PyTorch(entry_point='mnist.py',
                                role='SageMakerRole',
                                source_dir='pytorch',
                                git_config=git_config,
                                framework_version=PyTorch.LATEST_VERSION,
                                py_version='py3',
                                train_instance_count=1,
                                train_instance_type='ml.c4.xlarge')

    # Clones the repo, starts a SageMaker training job, and waits until completion
    pytorch_estimator.fit('s3://my_bucket/my_training_data/')

    # Deploys the model that was generated by fit() to an existing SageMaker endpoint
    pytorch_predictor = pytorch_estimator.deploy(initial_instance_count=1,
                                                 instance_type='ml.p2.xlarge',
                                                 update_endpoint=True,
                                                 endpoint_name='existing-endpoint')

    # Serializes data and makes a prediction request to the SageMaker endpoint
    # ('data' is a placeholder for your own input payload)
    response = pytorch_predictor.predict(data)

    # Tears down the SageMaker endpoint and endpoint configuration
    pytorch_predictor.delete_endpoint()

    # Deletes the SageMaker model
    pytorch_predictor.delete_model()

When Git support is enabled, users can still use local mode in the same way.

Training Metrics
~~~~~~~~~~~~~~~~
The SageMaker Python SDK allows you to specify a name and a regular expression for metrics you want to track for training.
@@ -268,6 +320,7 @@ Currently, the following algorithms support incremental training:
- Object Detection
- Semantic Segmentation


Using SageMaker AlgorithmEstimators
-----------------------------------

127 changes: 124 additions & 3 deletions src/sagemaker/estimator.py
@@ -15,6 +15,8 @@
import json
import logging
import os
import subprocess
import tempfile
import warnings
from abc import ABCMeta
from abc import abstractmethod
@@ -774,16 +776,54 @@ class Framework(EstimatorBase):
MPI_NUM_PROCESSES_PER_HOST = 'sagemaker_mpi_num_of_processes_per_host'
MPI_CUSTOM_MPI_OPTIONS = 'sagemaker_mpi_custom_mpi_options'

- def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
- container_log_level=logging.INFO, code_location=None, image_name=None, dependencies=None, **kwargs):
def __init__(self, entry_point, source_dir=None, hyperparameters=None,
enable_cloudwatch_metrics=False, container_log_level=logging.INFO, code_location=None,
image_name=None, dependencies=None, git_config=None, **kwargs):
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``

Args:
entry_point (str): Path (absolute or relative) to the local Python source file which should be executed
as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5.
If 'git_config' is provided, 'entry_point' should be the path to the Python source file, relative to
the root of the Git repo.
Example:

With the following GitHub repo directory structure:

>>> |----- README.md
>>> |----- src
>>>         |----- train.py
>>>         |----- test.py

You can assign entry_point='src/train.py'.
git_config (dict[str, str]): Git configuration used for cloning files, including 'repo', 'branch'
and 'commit' (default: None).
'repo' is required; 'branch' and 'commit' are optional. If 'branch' is not specified, the branch
checked out by ``git clone`` (the repo's default branch, typically 'master') is used. If
'commit' is not specified, the latest commit in the required branch is used.
Example:

The following config:

>>> git_config = {'repo': 'https://github.com/GaryTu1020/python-sdk-testing.git',
>>>               'branch': 'master',
>>>               'commit': 'aea6f3acef9619f77f94772d9d654f041e16bf49'}

results in cloning the repo specified in 'repo', then checking out the 'master' branch, and then
checking out the specified commit.
source_dir (str): Path (absolute or relative) to a directory with any other training
source code dependencies aside from the entry point file (default: None). Structure within this
- directory are preserved when training on Amazon SageMaker.
directory is preserved when training on Amazon SageMaker. If 'git_config' is provided,
'source_dir' should be the path to a directory, relative to the root of the Git repo.
Example:

With the following GitHub repo directory structure:

>>> |----- README.md
>>> |----- src
>>>         |----- train.py
>>>         |----- test.py

You can assign entry_point='train.py', source_dir='src'.
dependencies (list[str]): A list of paths to directories (absolute or relative) with
any additional libraries that will be exported to the container (default: []).
The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
@@ -818,9 +858,11 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
"""
super(Framework, self).__init__(**kwargs)

if entry_point.startswith('s3://'):
raise ValueError('Invalid entry point script: {}. Must be a path to a local file.'.format(entry_point))
self.entry_point = entry_point
self.git_config = git_config
self.source_dir = source_dir
self.dependencies = dependencies or []
if enable_cloudwatch_metrics:
@@ -833,6 +875,82 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl

self._hyperparameters = hyperparameters or {}

    def _git_clone_code(self):
        """Git clone repo containing the training scripts. This method also validates ``git_config``,
        and sets ``source_dir`` to the right directory in the cloned repo.

        Raises:
            CalledProcessError: If 1. failed to clone git repo
                                   2. failed to checkout the required branch
                                   3. failed to checkout the required commit
            ValueError: If 1. entry point specified does not exist in the repo
                           2. source dir specified does not exist in the repo
                           3. a dependency specified does not exist in the repo
        """
        self._validate_git_config()
        # create a temporary directory to store the cloned repo
        repo_dir = tempfile.mkdtemp()
        try:
            subprocess.check_call(['git', 'clone', self.git_config['repo'], repo_dir])
        except subprocess.CalledProcessError:
            raise subprocess.CalledProcessError(1, cmd='git clone {} {}'.format(self.git_config['repo'], repo_dir))
Review comment (Contributor):

Why do you need to create a new CalledProcessError here instead of just using the original call?
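
One way to see what this question is getting at is a minimal sketch, not the SDK's code, that compares the re-raise pattern above with simply letting the original error propagate. The failing command is a hypothetical stand-in for ``git clone`` (a Python child process that exits with status 3), and all function names are made up for illustration.

```python
import subprocess
import sys

# Hypothetical stand-in for a failing ``git clone``: a child process that exits with status 3.
FAILING_CMD = [sys.executable, '-c', 'import sys; sys.exit(3)']


def run_and_rewrap(cmd):
    """Mirror the pattern in the diff: catch the error and raise a new one with a fixed code."""
    try:
        subprocess.check_call(cmd)
    except subprocess.CalledProcessError:
        raise subprocess.CalledProcessError(1, cmd=' '.join(cmd))


def run_and_propagate(cmd):
    """Let the original CalledProcessError propagate unchanged."""
    subprocess.check_call(cmd)


def returncode_of(func, cmd):
    """Return the ``returncode`` carried by the CalledProcessError that ``func`` raises."""
    try:
        func(cmd)
    except subprocess.CalledProcessError as error:
        return error.returncode
    return 0
```

With the re-raise, the caller always sees return code 1. Without it, the caller sees the status the child process actually exited with, along with the original argument list in ``cmd``.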


        self._checkout_branch_and_commit(repo_dir)

        # check if the cloned repo contains entry point and source dir; if so, set ``entry_point`` and
        # ``source_dir`` to the paths to local file system.
        if self.source_dir:
            if os.path.isdir(os.path.join(repo_dir, self.source_dir)):
                self.source_dir = os.path.join(repo_dir, self.source_dir)
                os.chdir(self.source_dir)
            else:
                raise ValueError('Source directory does not exist in the repo.')
            if not os.path.isfile(os.path.join(self.source_dir, self.entry_point)):
                raise ValueError('Entry point does not exist in the repo.')
        else:
            if not os.path.isfile(os.path.join(repo_dir, self.entry_point)):
                raise ValueError('Entry point does not exist in the repo.')
        for path in self.dependencies:
            if not os.path.isdir(os.path.join(repo_dir, path)):
                raise ValueError('Dependency {} does not exist in the repo.'.format(path))
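
The path checks above can be exercised without cloning anything. The following is a rough sketch under stated assumptions: ``resolve_repo_paths`` and ``make_fake_repo`` are hypothetical helpers that are not part of the SDK, and a temporary directory stands in for the cloned repo. Unlike the method in the diff, the sketch returns the resolved source directory instead of calling ``os.chdir``.

```python
import os
import tempfile


def resolve_repo_paths(repo_dir, entry_point, source_dir=None, dependencies=()):
    """Validate paths against a cloned repo and return the resolved source directory.

    Hypothetical helper, not part of the SDK. Raises ValueError on the first missing path.
    """
    if source_dir:
        resolved_source_dir = os.path.join(repo_dir, source_dir)
        if not os.path.isdir(resolved_source_dir):
            raise ValueError('Source directory does not exist in the repo.')
        # With a source dir, the entry point is looked up inside that directory.
        entry_point_path = os.path.join(resolved_source_dir, entry_point)
    else:
        resolved_source_dir = None
        # Without a source dir, the entry point is relative to the repo root.
        entry_point_path = os.path.join(repo_dir, entry_point)
    if not os.path.isfile(entry_point_path):
        raise ValueError('Entry point does not exist in the repo.')
    for path in dependencies:
        if not os.path.isdir(os.path.join(repo_dir, path)):
            raise ValueError('Dependency {} does not exist in the repo.'.format(path))
    return resolved_source_dir


def make_fake_repo():
    """Create a temporary directory laid out like the docstring example (src/train.py)."""
    repo_dir = tempfile.mkdtemp()
    os.mkdir(os.path.join(repo_dir, 'src'))
    with open(os.path.join(repo_dir, 'src', 'train.py'), 'w') as f:
        f.write('print("training")\n')
    return repo_dir
```

Returning the resolved path keeps the process working directory unchanged, which may matter when several estimators are prepared in the same process.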

    def _checkout_branch_and_commit(self, repo_dir):
        """Enter the directory where the repo is cloned, and checkout the required branch and commit.

        Args:
            repo_dir: the directory where the repo is cloned

        Raises:
            CalledProcessError: If 1. failed to checkout the required branch
                                   2. failed to checkout the required commit
        """
        os.chdir(repo_dir)
        if 'branch' in self.git_config:
            try:
                subprocess.check_call(['git', 'checkout', self.git_config['branch']])
            except subprocess.CalledProcessError:
                raise subprocess.CalledProcessError(1, cmd='git checkout {}'.format(self.git_config['branch']))
        if 'commit' in self.git_config:
            try:
                subprocess.check_call(['git', 'checkout', self.git_config['commit']])
            except subprocess.CalledProcessError:
                raise subprocess.CalledProcessError(1, cmd='git checkout {}'.format(self.git_config['commit']))
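
The checkout step changes the working directory of the whole process via ``os.chdir``. A possible alternative, sketched here as an assumption rather than as the SDK's implementation, is to pass ``cwd`` to ``subprocess.check_call`` so that only the child ``git`` process runs inside the repo. Both function names below are hypothetical.

```python
import subprocess


def checkout_commands(git_config):
    """Return the ``git checkout`` argument lists implied by ``git_config``, in order."""
    commands = []
    if 'branch' in git_config:
        commands.append(['git', 'checkout', git_config['branch']])
    if 'commit' in git_config:
        commands.append(['git', 'checkout', git_config['commit']])
    return commands


def checkout_branch_and_commit(git_config, repo_dir):
    """Run the checkouts inside ``repo_dir`` without changing the process working directory."""
    for command in checkout_commands(git_config):
        # ``cwd`` applies only to the child process; CalledProcessError propagates on failure.
        subprocess.check_call(command, cwd=repo_dir)
```

Separating command construction from execution also makes the branch-then-commit ordering checkable without a real repository.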

    def _validate_git_config(self):
        """Check if a git_config param is valid.

        Raises:
            ValueError: If:
                1. git_config has no key 'repo'
                2. git_config['repo'] is in the wrong format.
        """
        if 'repo' not in self.git_config:
            raise ValueError('Please provide a repo for git_config.')
        repo = self.git_config['repo']
        codecommit_url = repo.startswith('https://git-codecommit') or repo.startswith('ssh://git-codecommit')
        github_url = repo.startswith('https://github') or repo.startswith('git@github')
        if not codecommit_url and not github_url:
            raise ValueError('Please provide a valid git repo url.')
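
To make the accepted and rejected inputs concrete, here is a small standalone sketch of the same checks. ``validate_git_config`` copies the prefixes from the diff into a free function, ``is_valid`` is a hypothetical wrapper, and the non-GitHub URLs in the usage below are made-up examples.

```python
def validate_git_config(git_config):
    """Standalone copy of the checks in ``_validate_git_config`` (same prefixes as the diff)."""
    if 'repo' not in git_config:
        raise ValueError('Please provide a repo for git_config.')
    repo = git_config['repo']
    codecommit_url = repo.startswith('https://git-codecommit') or repo.startswith('ssh://git-codecommit')
    github_url = repo.startswith('https://github') or repo.startswith('git@github')
    if not codecommit_url and not github_url:
        raise ValueError('Please provide a valid git repo url.')


def is_valid(git_config):
    """Return True if ``validate_git_config`` accepts the config, False if it raises ValueError."""
    try:
        validate_git_config(git_config)
    except ValueError:
        return False
    return True


# A GitHub HTTPS URL passes; a config without 'repo' does not.
print(is_valid({'repo': 'https://github.com/GaryTu1020/python-sdk-testing.git'}))
print(is_valid({'branch': 'master'}))
# The check is prefix-only, so any host whose name starts with 'github' also passes.
print(is_valid({'repo': 'https://github.example.org/user/repo.git'}))
```

Because the check only looks at the start of the string, it does not confirm that the URL points at github.com or that the repository exists; those failures would surface later, at ``git clone`` time.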

def _prepare_for_training(self, job_name=None):
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.

@@ -842,6 +960,9 @@ def _prepare_for_training(self, job_name=None):
"""
super(Framework, self)._prepare_for_training(job_name=job_name)

if self.git_config:
self._git_clone_code()

# validate source dir will raise a ValueError if there is something wrong with the
# source directory. We are intentionally not handling it because this is a critical error.
if self.source_dir and not self.source_dir.lower().startswith('s3://'):
57 changes: 57 additions & 0 deletions tests/integ/test_git.py
@@ -0,0 +1,57 @@
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os

import numpy

from sagemaker.pytorch.estimator import PyTorch
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.utils import sagemaker_timestamp
from tests.integ import DATA_DIR, PYTHON_VERSION, TRAINING_DEFAULT_TIMEOUT_MINUTES
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
GIT_REPO = 'https://github.com/GaryTu1020/python-sdk-testing.git'
BRANCH = 'branch1'
COMMIT = '4893e528afa4a790331e1b5286954f073b0f14a2'


def test_git_support_with_pytorch(sagemaker_local_session):
    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        script_path = 'mnist.py'
        data_path = os.path.join(DATA_DIR, 'pytorch_mnist')
        git_config = {'repo': GIT_REPO, 'branch': BRANCH, 'commit': COMMIT}
        pytorch = PyTorch(entry_point=script_path, role='SageMakerRole', source_dir='pytorch',
                          framework_version=PyTorch.LATEST_VERSION, py_version=PYTHON_VERSION,
                          train_instance_count=1, train_instance_type='ml.c4.xlarge',
                          sagemaker_session=sagemaker_local_session, git_config=git_config)

        train_input = pytorch.sagemaker_session.upload_data(path=os.path.join(data_path, 'training'),
                                                            key_prefix='integ-test-data/pytorch_mnist/training')
        pytorch.fit({'training': train_input})

        # os.listdir returns entries in arbitrary order, so sort before comparing
        files = sorted(os.listdir(pytorch.source_dir))
        assert files == ['mnist.py', 'some-file']

    endpoint_name = 'test-git_support-with-pytorch-{}'.format(sagemaker_timestamp())

    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_local_session):
        desc = sagemaker_local_session.sagemaker_client.describe_training_job(pytorch.latest_training_job.name)
        model_data = desc['ModelArtifacts']['S3ModelArtifacts']
        model = PyTorchModel(model_data, 'SageMakerRole', entry_point=script_path,
                             sagemaker_session=sagemaker_local_session)
        predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)

        data = numpy.zeros(shape=(1, 1, 28, 28))
        result = predictor.predict(data)
        assert result is not None