Skip to content

Commit 495e289

Browse files
author
Chuyang Deng
committed
breaking: rename s3_input to TrainingInput
1 parent 1487b22 commit 495e289

File tree

17 files changed

+95
-100
lines changed

17 files changed

+95
-100
lines changed

bin/sagemaker-submit

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,4 @@ if __name__ == '__main__':
5656
hyperparameters=hyperparameters,
5757
instance_count=args.instance_count,
5858
instance_type=args.instance_type)
59-
estimator.fit(sagemaker.s3_input(args.data))
59+
estimator.fit(sagemaker.TrainingInput(args.data))

doc/frameworks/tensorflow/using_tf.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,9 @@ If your TFRecords are compressed, you can train on Gzipped TF Records by passing
329329

330330
.. code:: python
331331
332-
from sagemaker.session import s3_input
332+
from sagemaker.session import TrainingInput
333333
334-
train_s3_input = s3_input('s3://bucket/path/to/training/data', compression='Gzip')
334+
train_s3_input = TrainingInput('s3://bucket/path/to/training/data', compression='Gzip')
335335
tf_estimator.fit(train_s3_input)
336336
337337

src/sagemaker/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from sagemaker.session import Session # noqa: F401
5555
from sagemaker.session import container_def, pipeline_container_def # noqa: F401
5656
from sagemaker.session import production_variant # noqa: F401
57-
from sagemaker.session import s3_input # noqa: F401
57+
from sagemaker.session import TrainingInput # noqa: F401
5858
from sagemaker.session import get_execution_role # noqa: F401
5959

6060
from sagemaker.automl.automl import AutoML, AutoMLJob, AutoMLInput # noqa: F401

src/sagemaker/algorithm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585
the container via a Unix-named pipe.
8686
8787
This argument can be overriden on a per-channel basis using
88-
``sagemaker.session.s3_input.input_mode``.
88+
``sagemaker.session.TrainingInput.input_mode``.
8989
9090
output_path (str): S3 location for saving the training result (model artifacts and
9191
output files). If not specified, results are stored to a default bucket. If

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from sagemaker.estimator import EstimatorBase, _TrainingJob
2626
from sagemaker.inputs import FileSystemInput
2727
from sagemaker.model import NEO_IMAGE_ACCOUNT
28-
from sagemaker.session import s3_input
28+
from sagemaker.session import TrainingInput
2929
from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
3030
from sagemaker.xgboost.defaults import (
3131
XGBOOST_1P_VERSIONS,
@@ -341,8 +341,10 @@ def data_channel(self):
341341
return {self.channel: self.records_s3_input()}
342342

343343
def records_s3_input(self):
344-
"""Return a s3_input to represent the training data"""
345-
return s3_input(self.s3_data, distribution="ShardedByS3Key", s3_data_type=self.s3_data_type)
344+
"""Return a TrainingInput to represent the training data"""
345+
return TrainingInput(
346+
self.s3_data, distribution="ShardedByS3Key", s3_data_type=self.s3_data_type
347+
)
346348

347349

348350
class FileSystemRecordSet(object):

src/sagemaker/automl/candidate_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def fit(
103103
self.name = candidate_name or self.name
104104
running_jobs = {}
105105

106-
# convert inputs to s3_input format
106+
# convert inputs to TrainingInput format
107107
if isinstance(inputs, string_types):
108108
if not inputs.startswith("s3://"):
109109
inputs = self.sagemaker_session.upload_data(inputs, key_prefix="auto-ml-input-data")

src/sagemaker/estimator.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
)
5454
from sagemaker.predictor import Predictor
5555
from sagemaker.session import Session
56-
from sagemaker.session import s3_input
56+
from sagemaker.session import TrainingInput
5757
from sagemaker.transformer import Transformer
5858
from sagemaker.utils import base_from_name, base_name_from_image, name_from_base, get_config_value
5959
from sagemaker import vpc_utils
@@ -127,7 +127,7 @@ def __init__(
127127
'Pipe' - Amazon SageMaker streams data directly from S3 to the
128128
container via a Unix-named pipe. This argument can be overriden
129129
on a per-channel basis using
130-
``sagemaker.session.s3_input.input_mode``.
130+
``sagemaker.session.TrainingInput.input_mode``.
131131
output_path (str): S3 location for saving the training result (model
132132
artifacts and output files). If not specified, results are
133133
stored to a default bucket. If the bucket with the specific name
@@ -472,17 +472,18 @@ def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_conf
472472
model using the Amazon SageMaker hosting services.
473473
474474
Args:
475-
inputs (str or dict or sagemaker.session.s3_input): Information
475+
inputs (str or dict or sagemaker.session.TrainingInput): Information
476476
about the training data. This can be one of three types:
477477
478478
* (str) the S3 location where training data is saved, or a file:// path in
479479
local mode.
480-
* (dict[str, str] or dict[str, sagemaker.session.s3_input]) If using multiple
480+
* (dict[str, str] or dict[str, sagemaker.session.TrainingInput]) If using multiple
481481
channels for training data, you can specify a dict mapping channel names to
482-
strings or :func:`~sagemaker.session.s3_input` objects.
483-
* (sagemaker.session.s3_input) - channel configuration for S3 data sources that can
484-
provide additional information as well as the path to the training dataset.
485-
See :func:`sagemaker.session.s3_input` for full details.
482+
strings or :func:`~sagemaker.session.TrainingInput` objects.
483+
* (sagemaker.session.TrainingInput) - channel configuration for S3 data sources
484+
that can provide additional information as well as the path to the training
485+
dataset.
486+
See :func:`sagemaker.session.TrainingInput` for full details.
486487
* (sagemaker.session.FileSystemInput) - channel configuration for
487488
a file system data source that can provide additional information as well as
488489
the path to the training dataset.
@@ -1020,10 +1021,10 @@ def start_new(cls, estimator, inputs, experiment_config):
10201021
train_args["metric_definitions"] = estimator.metric_definitions
10211022
train_args["experiment_config"] = experiment_config
10221023

1023-
if isinstance(inputs, s3_input):
1024+
if isinstance(inputs, TrainingInput):
10241025
if "InputMode" in inputs.config:
10251026
logging.debug(
1026-
"Selecting s3_input's input_mode (%s) for TrainingInputMode.",
1027+
"Selecting TrainingInput's input_mode (%s) for TrainingInputMode.",
10271028
inputs.config["InputMode"],
10281029
)
10291030
train_args["input_mode"] = inputs.config["InputMode"]
@@ -1191,7 +1192,7 @@ def __init__(
11911192
container via a Unix-named pipe.
11921193
11931194
This argument can be overriden on a per-channel basis using
1194-
``sagemaker.session.s3_input.input_mode``.
1195+
``sagemaker.session.TrainingInput.input_mode``.
11951196
output_path (str): S3 location for saving the training result (model
11961197
artifacts and output files). If not specified, results are
11971198
stored to a default bucket. If the bucket with the specific name
@@ -2028,7 +2029,7 @@ def _s3_uri_prefix(channel_name, s3_data):
20282029
channel_name:
20292030
s3_data:
20302031
"""
2031-
if isinstance(s3_data, s3_input):
2032+
if isinstance(s3_data, TrainingInput):
20322033
s3_uri = s3_data.config["DataSource"]["S3DataSource"]["S3Uri"]
20332034
else:
20342035
s3_uri = s3_data
@@ -2038,7 +2039,7 @@ def _s3_uri_prefix(channel_name, s3_data):
20382039

20392040

20402041
# E.g. 's3://bucket/data' would return 'bucket/data'.
2041-
# Also accepts other valid input types, e.g. dict and s3_input.
2042+
# Also accepts other valid input types, e.g. dict and TrainingInput.
20422043
def _s3_uri_without_prefix_from_input(input_data):
20432044
# Unpack an input_config object from a dict if a dict was passed in.
20442045
"""
@@ -2052,8 +2053,10 @@ def _s3_uri_without_prefix_from_input(input_data):
20522053
return response
20532054
if isinstance(input_data, str):
20542055
return _s3_uri_prefix("training", input_data)
2055-
if isinstance(input_data, s3_input):
2056+
if isinstance(input_data, TrainingInput):
20562057
return _s3_uri_prefix("training", input_data)
20572058
raise ValueError(
2058-
"Unrecognized type for S3 input data config - not str or s3_input: {}".format(input_data)
2059+
"Unrecognized type for S3 input data config - not str or TrainingInput: {}".format(
2060+
input_data
2061+
)
20592062
)

src/sagemaker/inputs.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,11 @@
1313
"""Amazon SageMaker channel configurations for S3 data sources and file system data sources"""
1414
from __future__ import absolute_import, print_function
1515

16-
import logging
17-
1816
FILE_SYSTEM_TYPES = ["FSxLustre", "EFS"]
1917
FILE_SYSTEM_ACCESS_MODES = ["ro", "rw"]
2018

21-
logger = logging.getLogger("sagemaker")
22-
2319

24-
class s3_input(object):
20+
class TrainingInput(object):
2521
"""Amazon SageMaker channel configurations for S3 data sources.
2622
2723
Attributes:
@@ -80,10 +76,6 @@ def __init__(
8076
this channel. See the SageMaker API documentation for more info:
8177
https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
8278
"""
83-
logger.warning(
84-
"'s3_input' class will be renamed to 'TrainingInput' in SageMaker Python SDK v2."
85-
)
86-
8779
self.config = {
8880
"DataSource": {"S3DataSource": {"S3DataType": s3_data_type, "S3Uri": s3_data}}
8981
}

src/sagemaker/job.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from sagemaker.inputs import FileSystemInput
2020
from sagemaker.local import file_input
21-
from sagemaker.session import s3_input
21+
from sagemaker.session import TrainingInput
2222

2323

2424
class _Job(object):
@@ -144,7 +144,7 @@ def _format_inputs_to_input_config(inputs, validate_uri=True):
144144
input_dict = {}
145145
if isinstance(inputs, string_types):
146146
input_dict["training"] = _Job._format_string_uri_input(inputs, validate_uri)
147-
elif isinstance(inputs, s3_input):
147+
elif isinstance(inputs, TrainingInput):
148148
input_dict["training"] = inputs
149149
elif isinstance(inputs, file_input):
150150
input_dict["training"] = inputs
@@ -156,7 +156,10 @@ def _format_inputs_to_input_config(inputs, validate_uri=True):
156156
elif isinstance(inputs, FileSystemInput):
157157
input_dict["training"] = inputs
158158
else:
159-
msg = "Cannot format input {}. Expecting one of str, dict, s3_input or FileSystemInput"
159+
msg = (
160+
"Cannot format input {}. Expecting one of str, dict, TrainingInput or "
161+
"FileSystemInput"
162+
)
160163
raise ValueError(msg.format(inputs))
161164

162165
channels = [
@@ -195,7 +198,7 @@ def _format_string_uri_input(
195198
target_attribute_name:
196199
"""
197200
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"):
198-
s3_input_result = s3_input(
201+
s3_input_result = TrainingInput(
199202
uri_input,
200203
content_type=content_type,
201204
input_mode=input_mode,
@@ -211,19 +214,19 @@ def _format_string_uri_input(
211214
'"file://"'.format(uri_input)
212215
)
213216
if isinstance(uri_input, str):
214-
s3_input_result = s3_input(
217+
s3_input_result = TrainingInput(
215218
uri_input,
216219
content_type=content_type,
217220
input_mode=input_mode,
218221
compression=compression,
219222
target_attribute_name=target_attribute_name,
220223
)
221224
return s3_input_result
222-
if isinstance(uri_input, (s3_input, file_input, FileSystemInput)):
225+
if isinstance(uri_input, (TrainingInput, file_input, FileSystemInput)):
223226
return uri_input
224227

225228
raise ValueError(
226-
"Cannot format input {}. Expecting one of str, s3_input, file_input or "
229+
"Cannot format input {}. Expecting one of str, TrainingInput, file_input or "
227230
"FileSystemInput".format(uri_input)
228231
)
229232

@@ -272,7 +275,7 @@ def _format_model_uri_input(model_uri, validate_uri=True):
272275
validate_uri:
273276
"""
274277
if isinstance(model_uri, string_types) and validate_uri and model_uri.startswith("s3://"):
275-
return s3_input(
278+
return TrainingInput(
276279
model_uri,
277280
input_mode="File",
278281
distribution="FullyReplicated",
@@ -285,7 +288,7 @@ def _format_model_uri_input(model_uri, validate_uri=True):
285288
'Model URI must be a valid S3 or FILE URI: must start with "s3://" or ' '"file://'
286289
)
287290
if isinstance(model_uri, string_types):
288-
return s3_input(
291+
return TrainingInput(
289292
model_uri,
290293
input_mode="File",
291294
distribution="FullyReplicated",

src/sagemaker/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
import sagemaker.logs
3030
from sagemaker import vpc_utils
3131

32-
# import s3_input for backward compatibility
33-
from sagemaker.inputs import s3_input # noqa # pylint: disable=unused-import
32+
# import TrainingInput for backward compatibility
33+
from sagemaker.inputs import TrainingInput # noqa # pylint: disable=unused-import
3434
from sagemaker.user_agent import prepend_user_agent
3535
from sagemaker.utils import (
3636
name_from_image,

src/sagemaker/tuner.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
ParameterRange,
3737
)
3838
from sagemaker.session import Session
39-
from sagemaker.session import s3_input
39+
from sagemaker.session import TrainingInput
4040
from sagemaker.utils import base_from_name, base_name_from_image, name_from_base
4141

4242
AMAZON_ESTIMATOR_MODULE = "sagemaker"
@@ -377,13 +377,13 @@ def fit(
377377
any of the following forms:
378378
379379
* (str) - The S3 location where training data is saved.
380-
* (dict[str, str] or dict[str, sagemaker.session.s3_input]) -
380+
* (dict[str, str] or dict[str, sagemaker.session.TrainingInput]) -
381381
If using multiple channels for training data, you can specify
382382
a dict mapping channel names to strings or
383-
:func:`~sagemaker.session.s3_input` objects.
384-
* (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
385-
provide additional information about the training dataset.
386-
See :func:`sagemaker.session.s3_input` for full details.
383+
:func:`~sagemaker.session.TrainingInput` objects.
384+
* (sagemaker.session.TrainingInput) - Channel configuration for S3 data sources
385+
that can provide additional information about the training dataset.
386+
See :func:`sagemaker.session.TrainingInput` for full details.
387387
* (sagemaker.session.FileSystemInput) - channel configuration for
388388
a file system data source that can provide additional information as well as
389389
the path to the training dataset.
@@ -1500,10 +1500,10 @@ def _prepare_training_config(
15001500
training_config["input_mode"] = estimator.input_mode
15011501
training_config["metric_definitions"] = metric_definitions
15021502

1503-
if isinstance(inputs, s3_input):
1503+
if isinstance(inputs, TrainingInput):
15041504
if "InputMode" in inputs.config:
15051505
logging.debug(
1506-
"Selecting s3_input's input_mode (%s) for TrainingInputMode.",
1506+
"Selecting TrainingInput's input_mode (%s) for TrainingInputMode.",
15071507
inputs.config["InputMode"],
15081508
)
15091509
training_config["input_mode"] = inputs.config["InputMode"]

src/sagemaker/workflow/airflow.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,13 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
115115
116116
* (str) - The S3 location where training data is saved.
117117
118-
* (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
118+
* (dict[str, str] or dict[str, sagemaker.session.TrainingInput]) - If using multiple
119119
channels for training data, you can specify a dict mapping channel names to
120-
strings or :func:`~sagemaker.session.s3_input` objects.
120+
strings or :func:`~sagemaker.session.TrainingInput` objects.
121121
122-
* (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
122+
* (sagemaker.session.TrainingInput) - Channel configuration for S3 data sources that can
123123
provide additional information about the training dataset. See
124-
:func:`sagemaker.session.s3_input` for full details.
124+
:func:`sagemaker.session.TrainingInput` for full details.
125125
126126
* (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
127127
Amazon :class:~`Record` objects serialized and stored in S3.
@@ -208,13 +208,13 @@ def training_config(estimator, inputs=None, job_name=None, mini_batch_size=None)
208208
method of the associated estimator, as this can take any of the following forms:
209209
* (str) - The S3 location where training data is saved.
210210
211-
* (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
211+
* (dict[str, str] or dict[str, sagemaker.session.TrainingInput]) - If using multiple
212212
channels for training data, you can specify a dict mapping channel names to
213-
strings or :func:`~sagemaker.session.s3_input` objects.
213+
strings or :func:`~sagemaker.session.TrainingInput` objects.
214214
215-
* (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
215+
* (sagemaker.session.TrainingInput) - Channel configuration for S3 data sources that can
216216
provide additional information about the training dataset. See
217-
:func:`sagemaker.session.s3_input` for full details.
217+
:func:`sagemaker.session.TrainingInput` for full details.
218218
219219
* (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
220220
Amazon :class:~`Record` objects serialized and stored in S3.
@@ -258,13 +258,13 @@ def tuning_config(tuner, inputs, job_name=None, include_cls_metadata=False, mini
258258
259259
* (str) - The S3 location where training data is saved.
260260
261-
* (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
261+
* (dict[str, str] or dict[str, sagemaker.session.TrainingInput]) - If using multiple
262262
channels for training data, you can specify a dict mapping channel names to
263-
strings or :func:`~sagemaker.session.s3_input` objects.
263+
strings or :func:`~sagemaker.session.TrainingInput` objects.
264264
265-
* (sagemaker.session.s3_input) - Channel configuration for S3 data sources that can
265+
* (sagemaker.session.TrainingInput) - Channel configuration for S3 data sources that can
266266
provide additional information about the training dataset. See
267-
:func:`sagemaker.session.s3_input` for full details.
267+
:func:`sagemaker.session.TrainingInput` for full details.
268268
269269
* (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
270270
Amazon :class:~`Record` objects serialized and stored in S3.

0 commit comments

Comments
 (0)