Skip to content

feature: add git_config and git_clone, validate method #832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Jun 24, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
10d27c5
add git_config and validate method
Jun 6, 2019
db8652c
Merge branch 'master' of github.com:aws/sagemaker-python-sdk into clo…
Jun 6, 2019
6b78ed4
modify the order of git_config, add tests
Jun 6, 2019
e59bb79
move validate_git_config, add integ test
Jun 8, 2019
7808faa
modify location _git_clone_code called
Jun 10, 2019
2783c4a
add documentation
Jun 11, 2019
db3b69f
Merge branch 'master' of github.com:aws/sagemaker-python-sdk into clo…
Jun 11, 2019
f397850
Update doc/overview.rst
GaryTu1020 Jun 12, 2019
5b8d684
Update doc/overview.rst
GaryTu1020 Jun 12, 2019
a9e2932
add more integ tests
Jun 13, 2019
241ac92
write unit tests for git_utils
Jun 15, 2019
a81859a
fix conflict on overview.rst
Jun 15, 2019
c39c344
delete a line
Jun 15, 2019
068a7b1
modify an assertion in test_with_mxnet
Jun 15, 2019
2b1622b
add assertion to some test functions
Jun 17, 2019
28a5c58
remove deploy part in test_git
Jun 17, 2019
0797060
change testing git repo
Jun 17, 2019
e2e5c20
change the testing repo
Jun 17, 2019
c6daa5d
correct an error message
Jun 18, 2019
e8bede0
pull master
Jun 18, 2019
e5bd806
stop patching private methods
Jun 18, 2019
c1bae10
modified overview.rst, add lock for tests
Jun 19, 2019
2af9b24
slight change to overview.rst
Jun 19, 2019
e15a22d
Merge branch 'master' into clone_from_github
chuyang-deng Jun 19, 2019
b102563
add a comment for lock
Jun 19, 2019
9ae910e
merge with remote branch
Jun 19, 2019
3383bfc
Merge branch 'master' into clone_from_github
GaryTu1020 Jun 20, 2019
9a7f4e1
Merge branch 'master' into clone_from_github
chuyang-deng Jun 20, 2019
d4bb0bb
merge with master
Jun 21, 2019
e6a01f0
merge with master
Jun 21, 2019
0c5e32b
merge aws master
Jun 21, 2019
b6e75d0
merge with master
Jun 21, 2019
3621bd4
merge with master
Jun 21, 2019
c7af978
merge with aws master
Jun 23, 2019
c2f7a43
merge with aws master
Jun 24, 2019
0790f41
Merge branch 'master' into clone_from_github
mvsusp Jun 24, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 51 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import json
import logging
import os
import subprocess
import tempfile
import warnings
from abc import ABCMeta
from abc import abstractmethod
Expand All @@ -24,7 +26,7 @@
import sagemaker
from sagemaker.analytics import TrainingJobAnalytics
from sagemaker.fw_utils import (create_image_uri, tar_and_upload_dir, parse_s3_url, UploadedCode,
validate_source_dir)
validate_source_dir, validate_git_config)
from sagemaker.job import _Job
from sagemaker.local import LocalSession
from sagemaker.model import Model, NEO_ALLOWED_TARGET_INSTANCE_FAMILY, NEO_ALLOWED_FRAMEWORKS
Expand Down Expand Up @@ -771,13 +773,17 @@ class Framework(EstimatorBase):
MPI_NUM_PROCESSES_PER_HOST = 'sagemaker_mpi_num_of_processes_per_host'
MPI_CUSTOM_MPI_OPTIONS = 'sagemaker_mpi_custom_mpi_options'

def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
container_log_level=logging.INFO, code_location=None, image_name=None, dependencies=None, **kwargs):
def __init__(self, entry_point, git_config=None, source_dir=None, hyperparameters=None,
enable_cloudwatch_metrics=False, container_log_level=logging.INFO, code_location=None,
image_name=None, dependencies=None, **kwargs):
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``

Args:
entry_point (str): Path (absolute or relative) to the local Python source file which should be executed
entry_point (str): Path (absolute or relative) to either: 1. the local Python source file if ``git_config``
is not provided, or 2. the Python source file in the Git repo if ``git_config`` is provided. This file is
executed as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5.
git_config (dict[str, str]): Git configurations used for cloning files, including 'repo', 'branch'
and 'commit' for now (default: None).
source_dir (str): Path (absolute or relative) to a directory with any other training
source code dependencies aside from the entry point file (default: None). Structure within this
directory are preserved when training on Amazon SageMaker.
Expand Down Expand Up @@ -815,9 +821,11 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
"""
super(Framework, self).__init__(**kwargs)

if entry_point.startswith('s3://'):
raise ValueError('Invalid entry point script: {}. Must be a path to a local file.'.format(entry_point))
self.entry_point = entry_point
self.git_config = git_config
self.source_dir = source_dir
self.dependencies = dependencies or []
if enable_cloudwatch_metrics:
Expand All @@ -830,6 +838,45 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl

self._hyperparameters = hyperparameters or {}

def _git_clone_code(self):
    """Clone the Git repo containing the training scripts into a temporary directory.

    ``git_config`` is validated first. The repo is then cloned, the branch and/or commit
    named in ``git_config`` (when present) is checked out, and ``entry_point`` and
    ``source_dir`` are rewritten to point at the matching paths inside the clone.

    Raises:
        ValueError: If ``git_config`` is invalid, if cloning or a checkout fails, or if
            ``entry_point`` / ``source_dir`` does not exist in the cloned repo.
    """
    validate_git_config(self.git_config)

    # The clone has to outlive this call because ``entry_point`` and ``source_dir`` end up
    # pointing into it, so the temporary directory is deliberately not removed here.
    repo_dir = tempfile.mkdtemp()
    try:
        subprocess.check_call(['git', 'clone', self.git_config['repo'], repo_dir])
    except subprocess.CalledProcessError:
        raise ValueError('Failed to clone git repo.')

    # Run the checkouts with ``cwd=repo_dir`` instead of ``os.chdir`` so the working
    # directory of the calling process is left untouched. 'branch' and 'commit' are
    # optional: validation only requires 'repo', so a missing key is skipped rather than
    # raising KeyError. The branch is checked out before the commit, as before.
    for ref_kind in ('branch', 'commit'):
        ref = self.git_config.get(ref_kind)
        if not ref:
            continue
        try:
            subprocess.check_call(['git', 'checkout', ref], cwd=repo_dir)
        except subprocess.CalledProcessError:
            raise ValueError('Failed to checkout the required {}.'.format(ref_kind))

    # Point ``entry_point`` (and ``source_dir`` when given) at the files inside the clone,
    # failing early if the repo does not actually contain them.
    entry_point_path = os.path.join(repo_dir, self.entry_point)
    if not os.path.isfile(entry_point_path):
        raise ValueError('Entry point does not exist in the repo.')
    self.entry_point = entry_point_path

    if self.source_dir:
        source_dir_path = os.path.join(repo_dir, self.source_dir)
        if not os.path.isdir(source_dir_path):
            raise ValueError('Source does not exist in the repo.')
        self.source_dir = source_dir_path

def _prepare_for_training(self, job_name=None):
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.

Expand Down
22 changes: 22 additions & 0 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,28 @@ def _accelerator_type_valid_for_framework(framework, accelerator_type=None, opti
return True


def validate_git_config(git_config):
    """Check that a ``git_config`` argument can be used for cloning.

    Args:
        git_config (dict[str, str]): Git configurations used for cloning files, including 'repo', 'branch',
            and 'commit' for now.

    Raises:
        ValueError: If:
            1. git_config is None, empty, or has no key 'repo'
            2. git_config['repo'] does not start with one of the accepted
               CodeCommit or GitHub URL prefixes.
    """
    # ``None`` (the default when no git_config is passed) and an empty dict are reported the
    # same way as a missing 'repo' key instead of failing with a TypeError on the ``in`` test.
    if not git_config or 'repo' not in git_config:
        raise ValueError('Please provide a repo for git_config.')

    accepted_prefixes = (
        'https://git-codecommit',
        'ssh://git-codecommit',
        'https://github',
        'git@github',
    )
    # ``str.startswith`` accepts a tuple, so one call covers every accepted prefix.
    if not git_config['repo'].startswith(accepted_prefixes):
        raise ValueError('Please provide a valid git repo url.')


def validate_source_dir(script, directory):
"""Validate that the source directory exists and it contains the user script

Expand Down
13 changes: 13 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,19 @@ def test_prepare_for_training_force_name_generation(strftime, sagemaker_session)
assert JOB_NAME == fw._current_job_name


def test_git_clone_code_succeed(sagemaker_session):
    """Clone a real test repo and check the estimator's paths point into the clone."""
    # Reviewer note carried over from the PR thread: "Same here (global variables.)"
    clone_settings = {
        'repo': 'https://github.com/GaryTu1020/python-sdk-testing.git',
        'branch': 'branch1',
        'commit': 'aea6f3acef9619f77f94772d9d654f041e16bf49',
    }
    framework = DummyFramework(
        entry_point='source_dir/entry_point',
        git_config=clone_settings,
        source_dir='source_dir',
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        enable_cloudwatch_metrics=True,
    )

    framework._git_clone_code()

    # After cloning, both attributes must refer to paths that exist on the local file system.
    assert os.path.isfile(framework.entry_point)
    assert os.path.isdir(framework.source_dir)


@patch('time.strftime', return_value=TIMESTAMP)
def test_init_with_source_dir_s3(strftime, sagemaker_session):
fw = DummyFramework(entry_point=SCRIPT_PATH, source_dir='s3://location', role=ROLE,
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,20 @@ def test_create_image_uri_local_sagemaker_notebook_accelerator():
assert image_uri == '520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-eia:1.0rc-gpu-py3'


def test_validate_git_config_repo_not_provided():
    """A git_config without a 'repo' key must be rejected with a ValueError."""
    git_config = {'branch': 'master', 'username': 'User1', 'password': 'passw0rd'}
    with pytest.raises(ValueError) as error:
        fw_utils.validate_git_config(git_config)
    # Assert after the ``with`` block (inside it the line would never run) and against the
    # raised exception itself (``error.value``) rather than the ExceptionInfo wrapper.
    assert 'Please provide a repo for git_config.' in str(error.value)


def test_validate_git_config_bad_repo_url():
    """A repo URL that does not start with an accepted prefix must be rejected with a ValueError."""
    git_config = {'repo': 'hhttps://github.com/user/repo.git', 'branch': 'master', 'password': 'passw0rd'}
    with pytest.raises(ValueError) as error:
        fw_utils.validate_git_config(git_config)
    # Assert after the ``with`` block (inside it the line would never run) and against the
    # raised exception itself (``error.value``) rather than the ExceptionInfo wrapper.
    assert 'Please provide a valid git repo url.' in str(error.value)


def test_invalid_accelerator():
error_message = '{} is not a valid SageMaker Elastic Inference accelerator type.'.format(MOCK_ACCELERATOR)
# accelerator type is missing 'ml.' prefix
Expand Down