Skip to content

Commit fed2fe9

Browse files
chuyang-deng (Chuyang Deng) and laurenyu
authored
breaking: rename s3_input to TrainingInput (#1680)
* breaking: rename s3_input to TrainingInput * remove TrainingInput import from session * update docstring Co-authored-by: Chuyang Deng <[email protected]> Co-authored-by: Lauren Yu <[email protected]>
1 parent 8ec7f05 commit fed2fe9

File tree

17 files changed

+94
-104
lines changed

17 files changed

+94
-104
lines changed

bin/sagemaker-submit

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,4 @@ if __name__ == '__main__':
5656
hyperparameters=hyperparameters,
5757
instance_count=args.instance_count,
5858
instance_type=args.instance_type)
59-
estimator.fit(sagemaker.s3_input(args.data))
59+
estimator.fit(sagemaker.TrainingInput(args.data))

doc/frameworks/tensorflow/using_tf.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,9 @@ If your TFRecords are compressed, you can train on Gzipped TF Records by passing
351351

352352
.. code:: python
353353
354-
from sagemaker.session import s3_input
354+
from sagemaker.inputs import TrainingInput
355355
356-
train_s3_input = s3_input('s3://bucket/path/to/training/data', compression='Gzip')
356+
train_s3_input = TrainingInput('s3://bucket/path/to/training/data', compression='Gzip')
357357
tf_estimator.fit(train_s3_input)
358358
359359

src/sagemaker/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
FactorizationMachinesModel,
3030
)
3131
from sagemaker.amazon.factorization_machines import FactorizationMachinesPredictor # noqa: F401
32+
from sagemaker.inputs import TrainingInput # noqa: F401
3233
from sagemaker.amazon.ntm import NTM, NTMModel, NTMPredictor # noqa: F401
3334
from sagemaker.amazon.randomcutforest import ( # noqa: F401
3435
RandomCutForest,
@@ -54,7 +55,6 @@
5455
from sagemaker.session import Session # noqa: F401
5556
from sagemaker.session import container_def, pipeline_container_def # noqa: F401
5657
from sagemaker.session import production_variant # noqa: F401
57-
from sagemaker.session import s3_input # noqa: F401
5858
from sagemaker.session import get_execution_role # noqa: F401
5959

6060
from sagemaker.automl.automl import AutoML, AutoMLJob, AutoMLInput # noqa: F401

src/sagemaker/algorithm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __init__(
8787
the container via a Unix-named pipe.
8888
8989
This argument can be overridden on a per-channel basis using
90-
``sagemaker.session.s3_input.input_mode``.
90+
``sagemaker.inputs.TrainingInput.input_mode``.
9191
9292
output_path (str): S3 location for saving the training result (model artifacts and
9393
output files). If not specified, results are stored to a default bucket. If

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@
2323
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
2424
from sagemaker.amazon.common import write_numpy_to_dense_tensor
2525
from sagemaker.estimator import EstimatorBase, _TrainingJob
26-
from sagemaker.inputs import FileSystemInput
26+
from sagemaker.inputs import FileSystemInput, TrainingInput
2727
from sagemaker.model import NEO_IMAGE_ACCOUNT
28-
from sagemaker.session import s3_input
2928
from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
3029
from sagemaker.xgboost.defaults import (
3130
XGBOOST_1P_VERSIONS,
@@ -341,8 +340,10 @@ def data_channel(self):
341340
return {self.channel: self.records_s3_input()}
342341

343342
def records_s3_input(self):
344-
"""Return a s3_input to represent the training data"""
345-
return s3_input(self.s3_data, distribution="ShardedByS3Key", s3_data_type=self.s3_data_type)
343+
"""Return a TrainingInput to represent the training data"""
344+
return TrainingInput(
345+
self.s3_data, distribution="ShardedByS3Key", s3_data_type=self.s3_data_type
346+
)
346347

347348

348349
class FileSystemRecordSet(object):

src/sagemaker/automl/candidate_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def fit(
103103
self.name = candidate_name or self.name
104104
running_jobs = {}
105105

106-
# convert inputs to s3_input format
106+
# convert inputs to TrainingInput format
107107
if isinstance(inputs, string_types):
108108
if not inputs.startswith("s3://"):
109109
inputs = self.sagemaker_session.upload_data(inputs, key_prefix="auto-ml-input-data")

src/sagemaker/estimator.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
validate_source_dir,
4141
_region_supports_debugger,
4242
)
43+
from sagemaker.inputs import TrainingInput
4344
from sagemaker.job import _Job
4445
from sagemaker.local import LocalSession
4546
from sagemaker.model import Model, NEO_ALLOWED_FRAMEWORKS
@@ -53,7 +54,6 @@
5354
)
5455
from sagemaker.predictor import Predictor
5556
from sagemaker.session import Session
56-
from sagemaker.session import s3_input
5757
from sagemaker.transformer import Transformer
5858
from sagemaker.utils import base_from_name, base_name_from_image, name_from_base, get_config_value
5959
from sagemaker import vpc_utils
@@ -127,7 +127,7 @@ def __init__(
127127
'Pipe' - Amazon SageMaker streams data directly from S3 to the
128128
container via a Unix-named pipe. This argument can be overridden
129129
on a per-channel basis using
130-
``sagemaker.session.s3_input.input_mode``.
130+
``sagemaker.inputs.TrainingInput.input_mode``.
131131
output_path (str): S3 location for saving the training result (model
132132
artifacts and output files). If not specified, results are
133133
stored to a default bucket. If the bucket with the specific name
@@ -472,17 +472,18 @@ def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_conf
472472
model using the Amazon SageMaker hosting services.
473473
474474
Args:
475-
inputs (str or dict or sagemaker.session.s3_input): Information
475+
inputs (str or dict or sagemaker.inputs.TrainingInput): Information
476476
about the training data. This can be one of three types:
477477
478478
* (str) the S3 location where training data is saved, or a file:// path in
479479
local mode.
480-
* (dict[str, str] or dict[str, sagemaker.session.s3_input]) If using multiple
480+
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) If using multiple
481481
channels for training data, you can specify a dict mapping channel names to
482-
strings or :func:`~sagemaker.session.s3_input` objects.
483-
* (sagemaker.session.s3_input) - channel configuration for S3 data sources that can
484-
provide additional information as well as the path to the training dataset.
485-
See :func:`sagemaker.session.s3_input` for full details.
482+
strings or :func:`~sagemaker.inputs.TrainingInput` objects.
483+
* (sagemaker.inputs.TrainingInput) - channel configuration for S3 data sources
484+
that can provide additional information as well as the path to the training
485+
dataset.
486+
See :func:`sagemaker.inputs.TrainingInput` for full details.
486487
* (sagemaker.inputs.FileSystemInput) - channel configuration for
487488
a file system data source that can provide additional information as well as
488489
the path to the training dataset.
@@ -1020,10 +1021,10 @@ def start_new(cls, estimator, inputs, experiment_config):
10201021
train_args["metric_definitions"] = estimator.metric_definitions
10211022
train_args["experiment_config"] = experiment_config
10221023

1023-
if isinstance(inputs, s3_input):
1024+
if isinstance(inputs, TrainingInput):
10241025
if "InputMode" in inputs.config:
10251026
logging.debug(
1026-
"Selecting s3_input's input_mode (%s) for TrainingInputMode.",
1027+
"Selecting TrainingInput's input_mode (%s) for TrainingInputMode.",
10271028
inputs.config["InputMode"],
10281029
)
10291030
train_args["input_mode"] = inputs.config["InputMode"]
@@ -1191,7 +1192,7 @@ def __init__(
11911192
container via a Unix-named pipe.
11921193
11931194
This argument can be overridden on a per-channel basis using
1194-
``sagemaker.session.s3_input.input_mode``.
1195+
``sagemaker.inputs.TrainingInput.input_mode``.
11951196
output_path (str): S3 location for saving the training result (model
11961197
artifacts and output files). If not specified, results are
11971198
stored to a default bucket. If the bucket with the specific name
@@ -2028,7 +2029,7 @@ def _s3_uri_prefix(channel_name, s3_data):
20282029
channel_name:
20292030
s3_data:
20302031
"""
2031-
if isinstance(s3_data, s3_input):
2032+
if isinstance(s3_data, TrainingInput):
20322033
s3_uri = s3_data.config["DataSource"]["S3DataSource"]["S3Uri"]
20332034
else:
20342035
s3_uri = s3_data
@@ -2038,7 +2039,7 @@ def _s3_uri_prefix(channel_name, s3_data):
20382039

20392040

20402041
# E.g. 's3://bucket/data' would return 'bucket/data'.
2041-
# Also accepts other valid input types, e.g. dict and s3_input.
2042+
# Also accepts other valid input types, e.g. dict and TrainingInput.
20422043
def _s3_uri_without_prefix_from_input(input_data):
20432044
# Unpack an input_config object from a dict if a dict was passed in.
20442045
"""
@@ -2052,8 +2053,10 @@ def _s3_uri_without_prefix_from_input(input_data):
20522053
return response
20532054
if isinstance(input_data, str):
20542055
return _s3_uri_prefix("training", input_data)
2055-
if isinstance(input_data, s3_input):
2056+
if isinstance(input_data, TrainingInput):
20562057
return _s3_uri_prefix("training", input_data)
20572058
raise ValueError(
2058-
"Unrecognized type for S3 input data config - not str or s3_input: {}".format(input_data)
2059+
"Unrecognized type for S3 input data config - not str or TrainingInput: {}".format(
2060+
input_data
2061+
)
20592062
)

src/sagemaker/inputs.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,11 @@
1313
"""Amazon SageMaker channel configurations for S3 data sources and file system data sources"""
1414
from __future__ import absolute_import, print_function
1515

16-
import logging
17-
1816
FILE_SYSTEM_TYPES = ["FSxLustre", "EFS"]
1917
FILE_SYSTEM_ACCESS_MODES = ["ro", "rw"]
2018

21-
logger = logging.getLogger("sagemaker")
22-
2319

24-
class s3_input(object):
20+
class TrainingInput(object):
2521
"""Amazon SageMaker channel configurations for S3 data sources.
2622
2723
Attributes:
@@ -80,10 +76,6 @@ def __init__(
8076
this channel. See the SageMaker API documentation for more info:
8177
https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
8278
"""
83-
logger.warning(
84-
"'s3_input' class will be renamed to 'TrainingInput' in SageMaker Python SDK v2."
85-
)
86-
8779
self.config = {
8880
"DataSource": {"S3DataSource": {"S3DataType": s3_data_type, "S3Uri": s3_data}}
8981
}

src/sagemaker/job.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616
from abc import abstractmethod
1717
from six import string_types
1818

19-
from sagemaker.inputs import FileSystemInput
19+
from sagemaker.inputs import FileSystemInput, TrainingInput
2020
from sagemaker.local import file_input
21-
from sagemaker.session import s3_input
2221

2322

2423
class _Job(object):
@@ -142,7 +141,7 @@ def _format_inputs_to_input_config(inputs, validate_uri=True):
142141
input_dict = {}
143142
if isinstance(inputs, string_types):
144143
input_dict["training"] = _Job._format_string_uri_input(inputs, validate_uri)
145-
elif isinstance(inputs, s3_input):
144+
elif isinstance(inputs, TrainingInput):
146145
input_dict["training"] = inputs
147146
elif isinstance(inputs, file_input):
148147
input_dict["training"] = inputs
@@ -154,7 +153,10 @@ def _format_inputs_to_input_config(inputs, validate_uri=True):
154153
elif isinstance(inputs, FileSystemInput):
155154
input_dict["training"] = inputs
156155
else:
157-
msg = "Cannot format input {}. Expecting one of str, dict, s3_input or FileSystemInput"
156+
msg = (
157+
"Cannot format input {}. Expecting one of str, dict, TrainingInput or "
158+
"FileSystemInput"
159+
)
158160
raise ValueError(msg.format(inputs))
159161

160162
channels = [
@@ -193,7 +195,7 @@ def _format_string_uri_input(
193195
target_attribute_name:
194196
"""
195197
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"):
196-
s3_input_result = s3_input(
198+
s3_input_result = TrainingInput(
197199
uri_input,
198200
content_type=content_type,
199201
input_mode=input_mode,
@@ -209,19 +211,19 @@ def _format_string_uri_input(
209211
'"file://"'.format(uri_input)
210212
)
211213
if isinstance(uri_input, str):
212-
s3_input_result = s3_input(
214+
s3_input_result = TrainingInput(
213215
uri_input,
214216
content_type=content_type,
215217
input_mode=input_mode,
216218
compression=compression,
217219
target_attribute_name=target_attribute_name,
218220
)
219221
return s3_input_result
220-
if isinstance(uri_input, (s3_input, file_input, FileSystemInput)):
222+
if isinstance(uri_input, (TrainingInput, file_input, FileSystemInput)):
221223
return uri_input
222224

223225
raise ValueError(
224-
"Cannot format input {}. Expecting one of str, s3_input, file_input or "
226+
"Cannot format input {}. Expecting one of str, TrainingInput, file_input or "
225227
"FileSystemInput".format(uri_input)
226228
)
227229

@@ -270,7 +272,7 @@ def _format_model_uri_input(model_uri, validate_uri=True):
270272
validate_uri:
271273
"""
272274
if isinstance(model_uri, string_types) and validate_uri and model_uri.startswith("s3://"):
273-
return s3_input(
275+
return TrainingInput(
274276
model_uri,
275277
input_mode="File",
276278
distribution="FullyReplicated",
@@ -283,7 +285,7 @@ def _format_model_uri_input(model_uri, validate_uri=True):
283285
'Model URI must be a valid S3 or FILE URI: must start with "s3://" or ' '"file://'
284286
)
285287
if isinstance(model_uri, string_types):
286-
return s3_input(
288+
return TrainingInput(
287289
model_uri,
288290
input_mode="File",
289291
distribution="FullyReplicated",

src/sagemaker/session.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
import sagemaker.logs
3030
from sagemaker import vpc_utils
3131

32-
# import s3_input for backward compatibility
33-
from sagemaker.inputs import s3_input # noqa # pylint: disable=unused-import
3432
from sagemaker.user_agent import prepend_user_agent
3533
from sagemaker.utils import (
3634
name_from_image,

src/sagemaker/tuner.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
2929
from sagemaker.analytics import HyperparameterTuningJobAnalytics
3030
from sagemaker.estimator import Framework
31+
from sagemaker.inputs import TrainingInput
3132
from sagemaker.job import _Job
3233
from sagemaker.parameter import (
3334
CategoricalParameter,
@@ -36,7 +37,6 @@
3637
ParameterRange,
3738
)
3839
from sagemaker.session import Session
39-
from sagemaker.session import s3_input
4040
from sagemaker.utils import base_from_name, base_name_from_image, name_from_base
4141

4242
AMAZON_ESTIMATOR_MODULE = "sagemaker"
@@ -377,13 +377,13 @@ def fit(
377377
any of the following forms:
378378
379379
* (str) - The S3 location where training data is saved.
380-
* (dict[str, str] or dict[str, sagemaker.session.s3_input]) -
380+
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) -
381381
If using multiple channels for training data, you can specify
382382
a dict mapping channel names to strings or
383-
:func:`~sagemaker.session.s3_input` objects.
384-
* (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
385-
provide additional information about the training dataset.
386-
See :func:`sagemaker.session.s3_input` for full details.
383+
:func:`~sagemaker.inputs.TrainingInput` objects.
384+
* (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources
385+
that can provide additional information about the training dataset.
386+
See :func:`sagemaker.inputs.TrainingInput` for full details.
387387
* (sagemaker.inputs.FileSystemInput) - channel configuration for
388388
a file system data source that can provide additional information as well as
389389
the path to the training dataset.
@@ -1500,10 +1500,10 @@ def _prepare_training_config(
15001500
training_config["input_mode"] = estimator.input_mode
15011501
training_config["metric_definitions"] = metric_definitions
15021502

1503-
if isinstance(inputs, s3_input):
1503+
if isinstance(inputs, TrainingInput):
15041504
if "InputMode" in inputs.config:
15051505
logging.debug(
1506-
"Selecting s3_input's input_mode (%s) for TrainingInputMode.",
1506+
"Selecting TrainingInput's input_mode (%s) for TrainingInputMode.",
15071507
inputs.config["InputMode"],
15081508
)
15091509
training_config["input_mode"] = inputs.config["InputMode"]

src/sagemaker/workflow/airflow.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,13 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
115115
116116
* (str) - The S3 location where training data is saved.
117117
118-
* (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
118+
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple
119119
channels for training data, you can specify a dict mapping channel names to
120-
strings or :func:`~sagemaker.session.s3_input` objects.
120+
strings or :func:`~sagemaker.inputs.TrainingInput` objects.
121121
122-
* (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
122+
* (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can
123123
provide additional information about the training dataset. See
124-
:func:`sagemaker.session.s3_input` for full details.
124+
:func:`sagemaker.inputs.TrainingInput` for full details.
125125
126126
* (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
127127
Amazon :class:`~Record` objects serialized and stored in S3.
@@ -208,13 +208,13 @@ def training_config(estimator, inputs=None, job_name=None, mini_batch_size=None)
208208
method of the associated estimator, as this can take any of the following forms:
209209
* (str) - The S3 location where training data is saved.
210210
211-
* (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
211+
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple
212212
channels for training data, you can specify a dict mapping channel names to
213-
strings or :func:`~sagemaker.session.s3_input` objects.
213+
strings or :func:`~sagemaker.inputs.TrainingInput` objects.
214214
215-
* (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
215+
* (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can
216216
provide additional information about the training dataset. See
217-
:func:`sagemaker.session.s3_input` for full details.
217+
:func:`sagemaker.inputs.TrainingInput` for full details.
218218
219219
* (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
220220
Amazon :class:`~Record` objects serialized and stored in S3.
@@ -258,13 +258,13 @@ def tuning_config(tuner, inputs, job_name=None, include_cls_metadata=False, mini
258258
259259
* (str) - The S3 location where training data is saved.
260260
261-
* (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
261+
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple
262262
channels for training data, you can specify a dict mapping channel names to
263-
strings or :func:`~sagemaker.session.s3_input` objects.
263+
strings or :func:`~sagemaker.inputs.TrainingInput` objects.
264264
265-
* (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
265+
* (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can
266266
provide additional information about the training dataset. See
267-
:func:`sagemaker.session.s3_input` for full details.
267+
:func:`sagemaker.inputs.TrainingInput` for full details.
268268
269269
* (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
270270
Amazon :class:`~Record` objects serialized and stored in S3.

0 commit comments

Comments
 (0)