Allow Local Mode to work with a local training script. #178
Changes from 2 commits
sagemaker/estimator.py
@@ -14,13 +14,12 @@

import json
import logging
import os
from abc import ABCMeta
from abc import abstractmethod
from six import with_metaclass, string_types

from sagemaker.fw_utils import tar_and_upload_dir
from sagemaker.fw_utils import parse_s3_url
from sagemaker.fw_utils import UploadedCode
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
from sagemaker.local.local_session import LocalSession, file_input

from sagemaker.model import Model
@@ -30,7 +29,7 @@

from sagemaker.predictor import RealTimePredictor
from sagemaker.session import Session
from sagemaker.session import s3_input
from sagemaker.utils import base_name_from_image, name_from_base
from sagemaker.utils import base_name_from_image, name_from_base, get_config_value


class EstimatorBase(with_metaclass(ABCMeta, object)):
@@ -83,13 +82,10 @@ def __init__(self, role, train_instance_count, train_instance_type,
        self.input_mode = input_mode

        if self.train_instance_type in ('local', 'local_gpu'):
            self.local_mode = True
            if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
                raise RuntimeError("Distributed Training in Local GPU is not supported")

            self.sagemaker_session = sagemaker_session or LocalSession()
        else:
            self.local_mode = False
            self.sagemaker_session = sagemaker_session or Session()

        self.base_job_name = base_job_name
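As context for the branch above, here is a rough usage sketch (not part of this diff; the MXNet estimator and its arguments are only illustrative): picking 'local' or 'local_gpu' as the instance type makes the estimator use a LocalSession, and asking for more than one 'local_gpu' instance hits the RuntimeError shown here.

```python
# Illustrative sketch only -- estimator class and arguments are placeholders.
from sagemaker.mxnet import MXNet

# Runs in Local Mode (Docker on this machine) via a LocalSession.
local_estimator = MXNet(entry_point='train.py',
                        role='SageMakerRole',
                        train_instance_count=1,
                        train_instance_type='local')

# Asking for distributed local GPU training raises the RuntimeError above:
# MXNet(entry_point='train.py', role='SageMakerRole',
#       train_instance_count=2, train_instance_type='local_gpu')
```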
@@ -158,9 +154,14 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
        base_name = self.base_job_name or base_name_from_image(self.train_image())
        self._current_job_name = name_from_base(base_name)

        # if output_path was specified we use it otherwise initialize here
        # if output_path was specified we use it otherwise initialize here.
        # For Local Mode with no_internet=True we don't need an explicit output_path
        if self.output_path is None:
            self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
            no_internet = get_config_value('local.no_internet', self.sagemaker_session.config)
            if self.sagemaker_session.local_mode and no_internet:
                self.output_path = ''
            else:
                self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())

        self.latest_training_job = _TrainingJob.start_new(self, inputs)
        if wait:

Review comment (on the local_mode/no_internet check): If we start throwing more configuration into local_mode, then we should have a local_mode configuration object in place of the local_mode boolean variable. I think this is okay now, but once we get one more piece of config, let's refactor this.

Reply: sounds good.
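To make the refactor suggested in the comment concrete, here is a rough sketch of what a grouped local-mode configuration could look like; the class and attribute names below are hypothetical and not part of this PR.

```python
# Hypothetical sketch only -- names are invented, not part of this change.
class LocalModeConfig(object):
    """Groups local-mode settings instead of a boolean plus ad-hoc config keys."""

    def __init__(self, enabled=False, no_internet=False):
        self.enabled = enabled          # would replace the local_mode boolean
        self.no_internet = no_internet  # would replace get_config_value('local.no_internet', ...)

# The check in fit() could then read:
#     if self.sagemaker_session.local_config.enabled and self.sagemaker_session.local_config.no_internet:
#         self.output_path = ''
```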
@@ -323,7 +324,7 @@ def start_new(cls, estimator, inputs):
            sagemaker.estimator.Framework: Constructed object that captures all information about the started job.
        """

        local_mode = estimator.local_mode
        local_mode = estimator.sagemaker_session.local_mode

        # Allow file:// input only in local mode
        if isinstance(inputs, str) and inputs.startswith('file://'):
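As a usage note rather than part of the diff: file:// channels are only accepted when the session is in local mode. Reusing the hypothetical local_estimator from the earlier sketch, a call might look like this.

```python
# Only valid in local mode; outside local mode a file:// channel is rejected.
local_estimator.fit({'train': 'file:///tmp/my_training_data'})

# A SageMaker-hosted job would use an S3 channel instead:
# estimator.fit({'train': 's3://my-bucket/my-training-data'})
```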
@@ -604,27 +605,54 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
        base_name = self.base_job_name or base_name_from_image(self.train_image())
        self._current_job_name = name_from_base(base_name)

        # validate source dir will raise a ValueError if there is something wrong with the
        # source directory. We are intentionally not handling it because this is a critical error.
        if self.source_dir and not self.source_dir.lower().startswith('s3://'):
            validate_source_dir(self.entry_point, self.source_dir)

        # if we are in local mode with no_internet=True. We want the container to just
        # mount the source dir instead of uploading to S3.
        no_internet = get_config_value('local.no_internet', self.sagemaker_session.config)
        if self.sagemaker_session.local_mode and no_internet:
            # if there is no source dir, use the directory containing the entry point.
            if self.source_dir is None:
                self.source_dir = os.path.dirname(self.entry_point)
            self.entry_point = os.path.basename(self.entry_point)

            code_dir = 'file://' + self.source_dir
            script = self.entry_point
        else:
            self.uploaded_code = self._stage_user_code_in_s3()
            code_dir = self.uploaded_code.s3_prefix
            script = self.uploaded_code.script_name

        # Modify hyperparameters in-place to point to the right code directory and script URIs
        self._hyperparameters[DIR_PARAM_NAME] = code_dir
        self._hyperparameters[SCRIPT_PARAM_NAME] = script
        self._hyperparameters[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
        self._hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
        self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
        self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.region_name
        super(Framework, self).fit(inputs, wait, logs, self._current_job_name)
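To trace the local-code branch above with concrete values (the paths are made up), this is roughly how the entry point and source directory resolve when no source_dir was given:

```python
import os

# Hypothetical inputs: local mode with no_internet=True and source_dir=None.
entry_point = '/home/user/project/train.py'

source_dir = os.path.dirname(entry_point)    # '/home/user/project'
entry_point = os.path.basename(entry_point)  # 'train.py'

code_dir = 'file://' + source_dir            # 'file:///home/user/project'
script = entry_point                         # 'train.py'

# code_dir and script then land in the DIR_PARAM_NAME and SCRIPT_PARAM_NAME
# hyperparameters, so the container mounts the directory instead of pulling from S3.
```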

    def _stage_user_code_in_s3(self):
        """ Upload the user training script to s3 and return the location.

        Returns: s3 uri

        """
        if self.code_location is None:
            code_bucket = self.sagemaker_session.default_bucket()
            code_s3_prefix = '{}/source'.format(self._current_job_name)
        else:
            code_bucket, key_prefix = parse_s3_url(self.code_location)
            code_s3_prefix = '{}/{}/source'.format(key_prefix, self._current_job_name)

        self.uploaded_code = tar_and_upload_dir(session=self.sagemaker_session.boto_session,
                                                bucket=code_bucket,
                                                s3_key_prefix=code_s3_prefix,
                                                script=self.entry_point,
                                                directory=self.source_dir)

        # Modify hyperparameters in-place to add the URLs to the uploaded code.
        self._hyperparameters[DIR_PARAM_NAME] = self.uploaded_code.s3_prefix
        self._hyperparameters[SCRIPT_PARAM_NAME] = self.uploaded_code.script_name
        self._hyperparameters[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
        self._hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
        self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
        self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_session.region_name
        super(Framework, self).fit(inputs, wait, logs, self._current_job_name)
        return tar_and_upload_dir(session=self.sagemaker_session.boto_session,
                                  bucket=code_bucket,
                                  s3_key_prefix=code_s3_prefix,
                                  script=self.entry_point,
                                  directory=self.source_dir)
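For reference, a small sketch of what the staging step produces; the bucket and job name below are placeholders. tar_and_upload_dir returns an UploadedCode value whose s3_prefix and script_name feed the code-directory and script hyperparameters.

```python
# Placeholder values, for illustration only.
code_bucket = 'my-default-bucket'                        # sagemaker_session.default_bucket()
current_job_name = 'tensorflow-2018-01-01-00-00-00-000'

# With code_location=None, the key prefix becomes '<job name>/source':
code_s3_prefix = '{}/source'.format(current_job_name)

# tar_and_upload_dir(...) packs the source directory (or just the entry point)
# into a tarball under that prefix and returns something shaped like:
#   UploadedCode(s3_prefix='s3://my-default-bucket/tensorflow-2018-01-01-00-00-00-000/source/...',
#                script_name='train.py')
```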
    def hyperparameters(self):
        """Return the hyperparameters as a dictionary to use for training.
sagemaker/fw_utils.py
@@ -68,6 +68,27 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
        .format(account, region, framework, tag)


def validate_source_dir(script, directory):
    """Validate that the source directory exists and it contains the user script

    Args:
        script (str): Script filename.
        directory (str): Directory containing the source file.

    Raises:
        ValueError: If ``directory`` does not exist, is not a directory, or does not contain ``script``.
    """
    if directory:
        if not os.path.exists(directory):
            raise ValueError('"{}" does not exist.'.format(directory))
        if not os.path.isdir(directory):
            raise ValueError('"{}" is not a directory.'.format(directory))
        if script not in os.listdir(directory):
            raise ValueError('No file named "{}" was found in directory "{}".'.format(script, directory))

    return True

Review comment (on the `if directory:` guard): remove this if statement.

Review comment (on the `script not in os.listdir(directory)` check): This is a very small point - but doing a stat on os.path.join(directory, script) would be better, because it avoids listing the entire directory.

Review comment: I'd be tempted to rewrite this entire thing as:
    if not os.path.isfile(os.path.join(directory, script)):
        raise ValueError('...')

Reply: Good idea. This was just taken from the original tar_and_upload_dir() but this is a great time to improve it 👍

Reply: just realized I didn't do this. Let me do it now.
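Pulling the two suggestions above together, the helper might end up looking roughly like the sketch below; this is not what is in this commit, only the direction the review comments point to.

```python
import os

def validate_source_dir(script, directory):
    """Validate that the source directory exists and contains the user script."""
    # The 'if directory:' guard is dropped, per the first comment above.
    if not os.path.exists(directory):
        raise ValueError('"{}" does not exist.'.format(directory))
    if not os.path.isdir(directory):
        raise ValueError('"{}" is not a directory.'.format(directory))
    # A single stat instead of listing the whole directory, per the second comment.
    if not os.path.isfile(os.path.join(directory, script)):
        raise ValueError('No file named "{}" was found in directory "{}".'.format(script, directory))
    return True
```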


def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory):
    """Pack and upload source files to S3 only if directory is empty or local.

@@ -83,21 +104,13 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory):

    Returns:
        sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and script name.

    Raises:
        ValueError: If ``directory`` does not exist, is not a directory, or does not contain ``script``.
    """
    if directory:
        if directory.lower().startswith("s3://"):
            return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
        if not os.path.exists(directory):
            raise ValueError('"{}" does not exist.'.format(directory))
        if not os.path.isdir(directory):
            raise ValueError('"{}" is not a directory.'.format(directory))
        if script not in os.listdir(directory):
            raise ValueError('No file named "{}" was found in directory "{}".'.format(script, directory))
        script_name = script
        source_files = [os.path.join(directory, name) for name in os.listdir(directory)]
        else:
            script_name = script
            source_files = [os.path.join(directory, name) for name in os.listdir(directory)]
    else:
        # If no directory is specified, the script parameter needs to be a valid relative path.
        os.path.exists(script)
Review comment (on the no_internet configuration property): This property should be local.local_code and it should default to true. Customers can still call out to the internet if 'no_internet' is true.

Reply: much better, I was not a fan of no_internet.
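Assuming the same nested-key lookup this PR already uses in estimator.py (get_config_value('local.no_internet', session.config)), the renamed, default-true setting could be read roughly as below; the key name comes from the comment above and the default handling is only a sketch.

```python
# Sketch only: 'local.local_code' is the name proposed in the review comment.
from sagemaker.utils import get_config_value

def use_local_code(session):
    """Read the proposed local.local_code flag, defaulting to True when unset."""
    value = get_config_value('local.local_code', session.config)
    return True if value is None else value

# The branch shown earlier in Framework.fit() would then become something like:
#     if self.sagemaker_session.local_mode and use_local_code(self.sagemaker_session):
#         code_dir = 'file://' + self.source_dir
```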