Skip to content

feature: add git_config and git_clone, validate method #832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Jun 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
10d27c5
add git_config and validate method
Jun 6, 2019
db8652c
Merge branch 'master' of github.com:aws/sagemaker-python-sdk into clo…
Jun 6, 2019
6b78ed4
modify the order of git_config, add tests
Jun 6, 2019
e59bb79
move validate_git_config, add integ test
Jun 8, 2019
7808faa
modify location _git_clone_code called
Jun 10, 2019
2783c4a
add documentation
Jun 11, 2019
db3b69f
Merge branch 'master' of github.com:aws/sagemaker-python-sdk into clo…
Jun 11, 2019
f397850
Update doc/overview.rst
GaryTu1020 Jun 12, 2019
5b8d684
Update doc/overview.rst
GaryTu1020 Jun 12, 2019
a9e2932
add more integ tests
Jun 13, 2019
241ac92
write unit tests for git_utils
Jun 15, 2019
a81859a
fix conflict on overview.rst
Jun 15, 2019
c39c344
delete a line
Jun 15, 2019
068a7b1
modify an assertion in test_with_mxnet
Jun 15, 2019
2b1622b
add assertion to some test functions
Jun 17, 2019
28a5c58
remove deploy part in test_git
Jun 17, 2019
0797060
change testing git repo
Jun 17, 2019
e2e5c20
change the testing repo
Jun 17, 2019
c6daa5d
correct an error message
Jun 18, 2019
e8bede0
pull master
Jun 18, 2019
e5bd806
stop patching private methods
Jun 18, 2019
c1bae10
modified overview.rst, add lock for tests
Jun 19, 2019
2af9b24
slight change to overview.rst
Jun 19, 2019
e15a22d
Merge branch 'master' into clone_from_github
chuyang-deng Jun 19, 2019
b102563
add a comment for lock
Jun 19, 2019
9ae910e
merge with remote branch
Jun 19, 2019
3383bfc
Merge branch 'master' into clone_from_github
GaryTu1020 Jun 20, 2019
9a7f4e1
Merge branch 'master' into clone_from_github
chuyang-deng Jun 20, 2019
d4bb0bb
merge with master
Jun 21, 2019
e6a01f0
merge with master
Jun 21, 2019
0c5e32b
merge aws master
Jun 21, 2019
b6e75d0
merge with master
Jun 21, 2019
3621bd4
merge with master
Jun 21, 2019
c7af978
merge with aws master
Jun 23, 2019
c2f7a43
merge with aws master
Jun 24, 2019
0790f41
Merge branch 'master' into clone_from_github
mvsusp Jun 24, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions doc/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,65 @@ For more `information <https://boto3.amazonaws.com/v1/documentation/api/latest/r
# Deletes the SageMaker model
mxnet_predictor.delete_model()

Git Support
~~~~~~~~~~~
If you have your training scripts in your GitHub repository, you can use them directly without the trouble to download
them to local machine. Git support can be enabled simply by providing ``git_config`` parameter when initializing an
estimator. If Git support is enabled, then ``entry_point``, ``source_dir`` and ``dependencies`` should all be relative
paths in the Git repo. Note that if you decided to use Git support, then everything you need for ``entry_point``,
``source_dir`` and ``dependencies`` should be in a single Git repo.

Here are ways to specify ``git_config``:

.. code:: python

# Specifies the git_config parameter
git_config = {'repo': 'https://github.com/username/repo-with-training-scripts.git',
'branch': 'branch1',
'commit': '4893e528afa4a790331e1b5286954f073b0f14a2'}

# Alternatively, you can also specify git_config by providing only 'repo' and 'branch'.
# If this is the case, the latest commit in the branch will be used.
git_config = {'repo': 'https://github.com/username/repo-with-training-scripts.git',
'branch': 'branch1'}

# Only providing 'repo' is also allowed. If this is the case, latest commit in
# 'master' branch will be used.
git_config = {'repo': 'https://github.com/username/repo-with-training-scripts.git'

The following are some examples to define estimators with Git support:

.. code:: python

# In this example, the source directory 'pytorch' contains the entry point 'mnist.py' and other source code.
# and it is relative path inside the Git repo.
pytorch_estimator = PyTorch(entry_point='mnist.py',
role='SageMakerRole',
source_dir='pytorch',
git_config=git_config,
train_instance_count=1,
train_instance_type='ml.c4.xlarge')

# In this example, the entry point 'mnist.py' is all we need for source code.
# We need to specify the path to it in the Git repo.
mx_estimator = MXNet(entry_point='mxnet/mnist.py',
role='SageMakerRole',
git_config=git_config,
train_instance_count=1,
train_instance_type='ml.c4.xlarge')

# In this example, besides entry point and other source code in source directory, we still need some
# dependencies for the training job. Dependencies should also be paths inside the Git repo.
pytorch_estimator = PyTorch(entry_point='mnist.py',
role='SageMakerRole',
source_dir='pytorch',
dependencies=['dep.py', 'foo/bar.py'],
git_config=git_config,
train_instance_count=1,
train_instance_type='ml.c4.xlarge')

When Git support is enabled, users can still use local mode in the same way.

Training Metrics
~~~~~~~~~~~~~~~~
The SageMaker Python SDK allows you to specify a name and a regular expression for metrics you want to track for training.
Expand Down Expand Up @@ -268,6 +327,7 @@ Currently, the following algorithms support incremental training:
- Object Detection
- Semantic Segmentation


Using SageMaker AlgorithmEstimators
-----------------------------------

Expand Down
52 changes: 51 additions & 1 deletion src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from six import string_types

import sagemaker
from sagemaker import git_utils
from sagemaker.analytics import TrainingJobAnalytics
from sagemaker.fw_utils import (
create_image_uri,
Expand Down Expand Up @@ -933,6 +934,7 @@ class Framework(EstimatorBase):
"""

__framework_name__ = None

LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled"
LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled"
MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host"
Expand All @@ -949,6 +951,7 @@ def __init__(
code_location=None,
image_name=None,
dependencies=None,
git_config=None,
enable_network_isolation=False,
**kwargs
):
Expand All @@ -957,9 +960,47 @@ def __init__(
Args:
entry_point (str): Path (absolute or relative) to the local Python source file which should be executed
as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5.
If 'git_config' is provided, 'entry_point' should be a relative location to the Python source file in
the Git repo.
Example:

With the following GitHub repo directory structure:

>>> |----- README.md
>>> |----- src
>>> |----- train.py
>>> |----- test.py

You can assign entry_point='src/train.py'.
git_config (dict[str, str]): Git configurations used for cloning files, including 'repo', 'branch'
and 'commit' (default: None).
'branch' and 'commit' are optional. If 'branch' is not specified, 'master' branch will be used. If
'commit' is not specified, the latest commit in the required branch will be used.
Example:

The following config:

>>> git_config = {'repo': 'https://github.com/aws/sagemaker-python-sdk.git',
>>> 'branch': 'test-branch-git-config',
>>> 'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'}

results in cloning the repo specified in 'repo', then checkout the 'master' branch, and checkout
the specified commit.
source_dir (str): Path (absolute or relative) to a directory with any other training
source code dependencies aside from the entry point file (default: None). Structure within this
directory are preserved when training on Amazon SageMaker.
directory are preserved when training on Amazon SageMaker. If 'git_config' is provided,
source_dir should be a relative location to a directory in the Git repo.
Example:

With the following GitHub repo directory structure:

>>> |----- README.md
>>> |----- src
>>> |----- train.py
>>> |----- test.py

and you need 'train.py' as entry point and 'test.py' as training source code as well, you can
assign entry_point='train.py', source_dir='src'.
hyperparameters (dict): Hyperparameters that will be used for training (default: None).
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
For convenience, this accepts other types for keys and values, but ``str()`` will be called
Expand Down Expand Up @@ -1006,6 +1047,7 @@ def __init__(
)
)
self.entry_point = entry_point
self.git_config = git_config
self.source_dir = source_dir
self.dependencies = dependencies or []
if enable_cloudwatch_metrics:
Expand Down Expand Up @@ -1038,6 +1080,14 @@ def _prepare_for_training(self, job_name=None):
"""
super(Framework, self)._prepare_for_training(job_name=job_name)

if self.git_config:
updates = git_utils.git_clone_repo(
self.git_config, self.entry_point, self.source_dir, self.dependencies
)
self.entry_point = updates["entry_point"]
self.source_dir = updates["source_dir"]
self.dependencies = updates["dependencies"]

# validate source dir will raise a ValueError if there is something wrong with the
# source directory. We are intentionally not handling it because this is a critical error.
if self.source_dir and not self.source_dir.lower().startswith("s3://"):
Expand Down
104 changes: 104 additions & 0 deletions src/sagemaker/git_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
import subprocess
import tempfile


def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None):
"""Git clone repo containing the training code and serving code. This method also validate ``git_config``,
and set ``entry_point``, ``source_dir`` and ``dependencies`` to the right file or directory in the repo cloned.

Args:
git_config (dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch``
and ``commit``. ``branch`` and ``commit`` are optional. If ``branch`` is not specified, master branch
will be used. If ``commit`` is not specified, the latest commit in the required branch will be used.
entry_point (str): A relative location to the Python source file which should be executed as the entry point
to training or model hosting in the Git repo.
source_dir (str): A relative location to a directory with other training or model hosting source code
dependencies aside from the entry point file in the Git repo (default: None). Structure within this
directory are preserved when training on Amazon SageMaker.
dependencies (list[str]): A list of relative locations to directories with any additional libraries that will
be exported to the container in the Git repo (default: []).

Raises:
CalledProcessError: If 1. failed to clone git repo
2. failed to checkout the required branch
3. failed to checkout the required commit
ValueError: If 1. entry point specified does not exist in the repo
2. source dir specified does not exist in the repo

Returns:
dict: A dict that contains the updated values of entry_point, source_dir and dependencies
"""
_validate_git_config(git_config)
repo_dir = tempfile.mkdtemp()
subprocess.check_call(["git", "clone", git_config["repo"], repo_dir])

_checkout_branch_and_commit(git_config, repo_dir)

ret = {"entry_point": entry_point, "source_dir": source_dir, "dependencies": dependencies}
# check if the cloned repo contains entry point, source directory and dependencies
if source_dir:
if not os.path.isdir(os.path.join(repo_dir, source_dir)):
raise ValueError("Source directory does not exist in the repo.")
if not os.path.isfile(os.path.join(repo_dir, source_dir, entry_point)):
raise ValueError("Entry point does not exist in the repo.")
ret["source_dir"] = os.path.join(repo_dir, source_dir)
else:
if not os.path.isfile(os.path.join(repo_dir, entry_point)):
raise ValueError("Entry point does not exist in the repo.")
ret["entry_point"] = os.path.join(repo_dir, entry_point)

ret["dependencies"] = []
for path in dependencies:
if not os.path.exists(os.path.join(repo_dir, path)):
raise ValueError("Dependency {} does not exist in the repo.".format(path))
ret["dependencies"].append(os.path.join(repo_dir, path))
return ret


def _validate_git_config(git_config):
"""check if a git_config param is valid

Args:
git_config ((dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch``
and ``commit``.

Raises:
ValueError: If:
1. git_config has no key 'repo'
2. git_config['repo'] is in the wrong format.
"""
if "repo" not in git_config:
raise ValueError("Please provide a repo for git_config.")


def _checkout_branch_and_commit(git_config, repo_dir):
"""Checkout the required branch and commit.

Args:
git_config: (dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch``
and ``commit``.
repo_dir (str): the directory where the repo is cloned

Raises:
ValueError: If 1. entry point specified does not exist in the repo
2. source dir specified does not exist in the repo
"""
if "branch" in git_config:
subprocess.check_call(args=["git", "checkout", git_config["branch"]], cwd=str(repo_dir))
if "commit" in git_config:
subprocess.check_call(args=["git", "checkout", git_config["commit"]], cwd=str(repo_dir))
100 changes: 100 additions & 0 deletions tests/integ/test_git.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os

import numpy
import tempfile

from tests.integ import lock as lock
from sagemaker.mxnet.estimator import MXNet
from sagemaker.pytorch.estimator import PyTorch
from tests.integ import DATA_DIR, PYTHON_VERSION

GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git"
BRANCH = "test-branch-git-config"
COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a"

# endpoint tests all use the same port, so we use this lock to prevent concurrent execution
LOCK_PATH = os.path.join(tempfile.gettempdir(), "sagemaker_test_git_lock")


def test_git_support_with_pytorch(sagemaker_local_session):
script_path = "mnist.py"
data_path = os.path.join(DATA_DIR, "pytorch_mnist")
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
pytorch = PyTorch(
entry_point=script_path,
role="SageMakerRole",
source_dir="pytorch",
framework_version=PyTorch.LATEST_VERSION,
py_version=PYTHON_VERSION,
train_instance_count=1,
train_instance_type="local",
sagemaker_session=sagemaker_local_session,
git_config=git_config,
)

pytorch.fit({"training": "file://" + os.path.join(data_path, "training")})

with lock.lock(LOCK_PATH):
try:
predictor = pytorch.deploy(initial_instance_count=1, instance_type="local")

data = numpy.zeros(shape=(1, 1, 28, 28)).astype(numpy.float32)
result = predictor.predict(data)
assert result is not None
finally:
predictor.delete_endpoint()


def test_git_support_with_mxnet(sagemaker_local_session, mxnet_full_version):
script_path = "mnist.py"
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
dependencies = ["foo/bar.py"]
mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
source_dir="mxnet",
dependencies=dependencies,
framework_version=MXNet.LATEST_VERSION,
py_version=PYTHON_VERSION,
train_instance_count=1,
train_instance_type="local",
sagemaker_session=sagemaker_local_session,
git_config=git_config,
)

mx.fit(
{
"train": "file://" + os.path.join(data_path, "train"),
"test": "file://" + os.path.join(data_path, "test"),
}
)

files = [file for file in os.listdir(mx.source_dir)]
assert "some_file" in files
assert "mnist.py" in files
assert os.path.exists(mx.dependencies[0])

with lock.lock(LOCK_PATH):
try:
predictor = mx.deploy(initial_instance_count=1, instance_type="local")

data = numpy.zeros(shape=(1, 1, 28, 28))
result = predictor.predict(data)
assert result is not None
finally:
predictor.delete_endpoint()
Loading