Skip to content

breaking: rename s3_input to TrainingInput #1680

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bin/sagemaker-submit
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ if __name__ == '__main__':
hyperparameters=hyperparameters,
instance_count=args.instance_count,
instance_type=args.instance_type)
estimator.fit(sagemaker.s3_input(args.data))
estimator.fit(sagemaker.TrainingInput(args.data))
4 changes: 2 additions & 2 deletions doc/frameworks/tensorflow/using_tf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,9 @@ If your TFRecords are compressed, you can train on Gzipped TF Records by passing

.. code:: python

from sagemaker.session import s3_input
from sagemaker.inputs import TrainingInput

train_s3_input = s3_input('s3://bucket/path/to/training/data', compression='Gzip')
train_s3_input = TrainingInput('s3://bucket/path/to/training/data', compression='Gzip')
tf_estimator.fit(train_s3_input)


Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
FactorizationMachinesModel,
)
from sagemaker.amazon.factorization_machines import FactorizationMachinesPredictor # noqa: F401
from sagemaker.inputs import TrainingInput # noqa: F401
from sagemaker.amazon.ntm import NTM, NTMModel, NTMPredictor # noqa: F401
from sagemaker.amazon.randomcutforest import ( # noqa: F401
RandomCutForest,
Expand All @@ -54,7 +55,6 @@
from sagemaker.session import Session # noqa: F401
from sagemaker.session import container_def, pipeline_container_def # noqa: F401
from sagemaker.session import production_variant # noqa: F401
from sagemaker.session import s3_input # noqa: F401
from sagemaker.session import get_execution_role # noqa: F401

from sagemaker.automl.automl import AutoML, AutoMLJob, AutoMLInput # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
the container via a Unix-named pipe.

This argument can be overridden on a per-channel basis using
``sagemaker.session.s3_input.input_mode``.
``sagemaker.inputs.TrainingInput.input_mode``.

output_path (str): S3 location for saving the training result (model artifacts and
output files). If not specified, results are stored to a default bucket. If
Expand Down
9 changes: 5 additions & 4 deletions src/sagemaker/amazon/amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
from sagemaker.amazon.common import write_numpy_to_dense_tensor
from sagemaker.estimator import EstimatorBase, _TrainingJob
from sagemaker.inputs import FileSystemInput
from sagemaker.inputs import FileSystemInput, TrainingInput
from sagemaker.model import NEO_IMAGE_ACCOUNT
from sagemaker.session import s3_input
from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
from sagemaker.xgboost.defaults import (
XGBOOST_1P_VERSIONS,
Expand Down Expand Up @@ -341,8 +340,10 @@ def data_channel(self):
return {self.channel: self.records_s3_input()}

def records_s3_input(self):
"""Return a s3_input to represent the training data"""
return s3_input(self.s3_data, distribution="ShardedByS3Key", s3_data_type=self.s3_data_type)
"""Return a TrainingInput to represent the training data"""
return TrainingInput(
self.s3_data, distribution="ShardedByS3Key", s3_data_type=self.s3_data_type
)


class FileSystemRecordSet(object):
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/automl/candidate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def fit(
self.name = candidate_name or self.name
running_jobs = {}

# convert inputs to s3_input format
# convert inputs to TrainingInput format
if isinstance(inputs, string_types):
if not inputs.startswith("s3://"):
inputs = self.sagemaker_session.upload_data(inputs, key_prefix="auto-ml-input-data")
Expand Down
33 changes: 18 additions & 15 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
validate_source_dir,
_region_supports_debugger,
)
from sagemaker.inputs import TrainingInput
from sagemaker.job import _Job
from sagemaker.local import LocalSession
from sagemaker.model import Model, NEO_ALLOWED_FRAMEWORKS
Expand All @@ -53,7 +54,6 @@
)
from sagemaker.predictor import Predictor
from sagemaker.session import Session
from sagemaker.session import s3_input
from sagemaker.transformer import Transformer
from sagemaker.utils import base_from_name, base_name_from_image, name_from_base, get_config_value
from sagemaker import vpc_utils
Expand Down Expand Up @@ -127,7 +127,7 @@ def __init__(
'Pipe' - Amazon SageMaker streams data directly from S3 to the
container via a Unix-named pipe. This argument can be overridden
on a per-channel basis using
``sagemaker.session.s3_input.input_mode``.
``sagemaker.inputs.TrainingInput.input_mode``.
output_path (str): S3 location for saving the training result (model
artifacts and output files). If not specified, results are
stored to a default bucket. If the bucket with the specific name
Expand Down Expand Up @@ -472,17 +472,18 @@ def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_conf
model using the Amazon SageMaker hosting services.

Args:
inputs (str or dict or sagemaker.session.s3_input): Information
inputs (str or dict or sagemaker.inputs.TrainingInput): Information
about the training data. This can be one of three types:

* (str) the S3 location where training data is saved, or a file:// path in
local mode.
* (dict[str, str] or dict[str, sagemaker.session.s3_input]) If using multiple
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) If using multiple
channels for training data, you can specify a dict mapping channel names to
strings or :func:`~sagemaker.session.s3_input` objects.
* (sagemaker.session.s3_input) - channel configuration for S3 data sources that can
provide additional information as well as the path to the training dataset.
See :func:`sagemaker.session.s3_input` for full details.
strings or :class:`~sagemaker.inputs.TrainingInput` objects.
* (sagemaker.inputs.TrainingInput) - channel configuration for S3 data sources
that can provide additional information as well as the path to the training
dataset.
See :class:`sagemaker.inputs.TrainingInput` for full details.
* (sagemaker.inputs.FileSystemInput) - channel configuration for
a file system data source that can provide additional information as well as
the path to the training dataset.
Expand Down Expand Up @@ -1020,10 +1021,10 @@ def start_new(cls, estimator, inputs, experiment_config):
train_args["metric_definitions"] = estimator.metric_definitions
train_args["experiment_config"] = experiment_config

if isinstance(inputs, s3_input):
if isinstance(inputs, TrainingInput):
if "InputMode" in inputs.config:
logging.debug(
"Selecting s3_input's input_mode (%s) for TrainingInputMode.",
"Selecting TrainingInput's input_mode (%s) for TrainingInputMode.",
inputs.config["InputMode"],
)
train_args["input_mode"] = inputs.config["InputMode"]
Expand Down Expand Up @@ -1191,7 +1192,7 @@ def __init__(
container via a Unix-named pipe.

This argument can be overridden on a per-channel basis using
``sagemaker.session.s3_input.input_mode``.
``sagemaker.inputs.TrainingInput.input_mode``.
output_path (str): S3 location for saving the training result (model
artifacts and output files). If not specified, results are
stored to a default bucket. If the bucket with the specific name
Expand Down Expand Up @@ -2028,7 +2029,7 @@ def _s3_uri_prefix(channel_name, s3_data):
channel_name:
s3_data:
"""
if isinstance(s3_data, s3_input):
if isinstance(s3_data, TrainingInput):
s3_uri = s3_data.config["DataSource"]["S3DataSource"]["S3Uri"]
else:
s3_uri = s3_data
Expand All @@ -2038,7 +2039,7 @@ def _s3_uri_prefix(channel_name, s3_data):


# E.g. 's3://bucket/data' would return 'bucket/data'.
# Also accepts other valid input types, e.g. dict and s3_input.
# Also accepts other valid input types, e.g. dict and TrainingInput.
def _s3_uri_without_prefix_from_input(input_data):
# Unpack an input_config object from a dict if a dict was passed in.
"""
Expand All @@ -2052,8 +2053,10 @@ def _s3_uri_without_prefix_from_input(input_data):
return response
if isinstance(input_data, str):
return _s3_uri_prefix("training", input_data)
if isinstance(input_data, s3_input):
if isinstance(input_data, TrainingInput):
return _s3_uri_prefix("training", input_data)
raise ValueError(
"Unrecognized type for S3 input data config - not str or s3_input: {}".format(input_data)
"Unrecognized type for S3 input data config - not str or TrainingInput: {}".format(
input_data
)
)
10 changes: 1 addition & 9 deletions src/sagemaker/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,11 @@
"""Amazon SageMaker channel configurations for S3 data sources and file system data sources"""
from __future__ import absolute_import, print_function

import logging

FILE_SYSTEM_TYPES = ["FSxLustre", "EFS"]
FILE_SYSTEM_ACCESS_MODES = ["ro", "rw"]

logger = logging.getLogger("sagemaker")


class s3_input(object):
class TrainingInput(object):
"""Amazon SageMaker channel configurations for S3 data sources.

Attributes:
Expand Down Expand Up @@ -80,10 +76,6 @@ def __init__(
this channel. See the SageMaker API documentation for more info:
https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
"""
logger.warning(
"'s3_input' class will be renamed to 'TrainingInput' in SageMaker Python SDK v2."
)

self.config = {
"DataSource": {"S3DataSource": {"S3DataType": s3_data_type, "S3Uri": s3_data}}
}
Expand Down
22 changes: 12 additions & 10 deletions src/sagemaker/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@
from abc import abstractmethod
from six import string_types

from sagemaker.inputs import FileSystemInput
from sagemaker.inputs import FileSystemInput, TrainingInput
from sagemaker.local import file_input
from sagemaker.session import s3_input


class _Job(object):
Expand Down Expand Up @@ -142,7 +141,7 @@ def _format_inputs_to_input_config(inputs, validate_uri=True):
input_dict = {}
if isinstance(inputs, string_types):
input_dict["training"] = _Job._format_string_uri_input(inputs, validate_uri)
elif isinstance(inputs, s3_input):
elif isinstance(inputs, TrainingInput):
input_dict["training"] = inputs
elif isinstance(inputs, file_input):
input_dict["training"] = inputs
Expand All @@ -154,7 +153,10 @@ def _format_inputs_to_input_config(inputs, validate_uri=True):
elif isinstance(inputs, FileSystemInput):
input_dict["training"] = inputs
else:
msg = "Cannot format input {}. Expecting one of str, dict, s3_input or FileSystemInput"
msg = (
"Cannot format input {}. Expecting one of str, dict, TrainingInput or "
"FileSystemInput"
)
raise ValueError(msg.format(inputs))

channels = [
Expand Down Expand Up @@ -193,7 +195,7 @@ def _format_string_uri_input(
target_attribute_name:
"""
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"):
s3_input_result = s3_input(
s3_input_result = TrainingInput(
uri_input,
content_type=content_type,
input_mode=input_mode,
Expand All @@ -209,19 +211,19 @@ def _format_string_uri_input(
'"file://"'.format(uri_input)
)
if isinstance(uri_input, str):
s3_input_result = s3_input(
s3_input_result = TrainingInput(
uri_input,
content_type=content_type,
input_mode=input_mode,
compression=compression,
target_attribute_name=target_attribute_name,
)
return s3_input_result
if isinstance(uri_input, (s3_input, file_input, FileSystemInput)):
if isinstance(uri_input, (TrainingInput, file_input, FileSystemInput)):
return uri_input

raise ValueError(
"Cannot format input {}. Expecting one of str, s3_input, file_input or "
"Cannot format input {}. Expecting one of str, TrainingInput, file_input or "
"FileSystemInput".format(uri_input)
)

Expand Down Expand Up @@ -270,7 +272,7 @@ def _format_model_uri_input(model_uri, validate_uri=True):
validate_uri:
"""
if isinstance(model_uri, string_types) and validate_uri and model_uri.startswith("s3://"):
return s3_input(
return TrainingInput(
model_uri,
input_mode="File",
distribution="FullyReplicated",
Expand All @@ -283,7 +285,7 @@ def _format_model_uri_input(model_uri, validate_uri=True):
'Model URI must be a valid S3 or FILE URI: must start with "s3://" or ' '"file://'
)
if isinstance(model_uri, string_types):
return s3_input(
return TrainingInput(
model_uri,
input_mode="File",
distribution="FullyReplicated",
Expand Down
2 changes: 0 additions & 2 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
import sagemaker.logs
from sagemaker import vpc_utils

# import s3_input for backward compatibility
from sagemaker.inputs import s3_input # noqa # pylint: disable=unused-import
from sagemaker.user_agent import prepend_user_agent
from sagemaker.utils import (
name_from_image,
Expand Down
16 changes: 8 additions & 8 deletions src/sagemaker/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
from sagemaker.analytics import HyperparameterTuningJobAnalytics
from sagemaker.estimator import Framework
from sagemaker.inputs import TrainingInput
from sagemaker.job import _Job
from sagemaker.parameter import (
CategoricalParameter,
Expand All @@ -36,7 +37,6 @@
ParameterRange,
)
from sagemaker.session import Session
from sagemaker.session import s3_input
from sagemaker.utils import base_from_name, base_name_from_image, name_from_base

AMAZON_ESTIMATOR_MODULE = "sagemaker"
Expand Down Expand Up @@ -377,13 +377,13 @@ def fit(
any of the following forms:

* (str) - The S3 location where training data is saved.
* (dict[str, str] or dict[str, sagemaker.session.s3_input]) -
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) -
If using multiple channels for training data, you can specify
a dict mapping channel names to strings or
:func:`~sagemaker.session.s3_input` objects.
* (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
provide additional information about the training dataset.
See :func:`sagemaker.session.s3_input` for full details.
:class:`~sagemaker.inputs.TrainingInput` objects.
* (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources
that can provide additional information about the training dataset.
See :class:`sagemaker.inputs.TrainingInput` for full details.
* (sagemaker.inputs.FileSystemInput) - channel configuration for
a file system data source that can provide additional information as well as
the path to the training dataset.
Expand Down Expand Up @@ -1500,10 +1500,10 @@ def _prepare_training_config(
training_config["input_mode"] = estimator.input_mode
training_config["metric_definitions"] = metric_definitions

if isinstance(inputs, s3_input):
if isinstance(inputs, TrainingInput):
if "InputMode" in inputs.config:
logging.debug(
"Selecting s3_input's input_mode (%s) for TrainingInputMode.",
"Selecting TrainingInput's input_mode (%s) for TrainingInputMode.",
inputs.config["InputMode"],
)
training_config["input_mode"] = inputs.config["InputMode"]
Expand Down
24 changes: 12 additions & 12 deletions src/sagemaker/workflow/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,13 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=

* (str) - The S3 location where training data is saved.

* (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple
channels for training data, you can specify a dict mapping channel names to
strings or :func:`~sagemaker.session.s3_input` objects.
strings or :class:`~sagemaker.inputs.TrainingInput` objects.

* (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
* (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can
provide additional information about the training dataset. See
:func:`sagemaker.session.s3_input` for full details.
:class:`sagemaker.inputs.TrainingInput` for full details.

* (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
Amazon :class:`~Record` objects serialized and stored in S3.
Expand Down Expand Up @@ -208,13 +208,13 @@ def training_config(estimator, inputs=None, job_name=None, mini_batch_size=None)
method of the associated estimator, as this can take any of the following forms:
* (str) - The S3 location where training data is saved.

* (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple
channels for training data, you can specify a dict mapping channel names to
strings or :func:`~sagemaker.session.s3_input` objects.
strings or :class:`~sagemaker.inputs.TrainingInput` objects.

* (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
* (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can
provide additional information about the training dataset. See
:func:`sagemaker.session.s3_input` for full details.
:class:`sagemaker.inputs.TrainingInput` for full details.

* (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
Amazon :class:`~Record` objects serialized and stored in S3.
Expand Down Expand Up @@ -258,13 +258,13 @@ def tuning_config(tuner, inputs, job_name=None, include_cls_metadata=False, mini

* (str) - The S3 location where training data is saved.

* (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple
channels for training data, you can specify a dict mapping channel names to
strings or :func:`~sagemaker.session.s3_input` objects.
strings or :class:`~sagemaker.inputs.TrainingInput` objects.

* (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
* (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can
provide additional information about the training dataset. See
:func:`sagemaker.session.s3_input` for full details.
:class:`sagemaker.inputs.TrainingInput` for full details.

* (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
Amazon :class:`~Record` objects serialized and stored in S3.
Expand Down
Loading