Skip to content

Commit 13e4040

Browse files
ishaaqjesterhazy
authored andcommitted
Add AugmentedManifestFile & ShuffleConfig support (#528)
1 parent c042a6c commit 13e4040

File tree

3 files changed

+56
-7
lines changed

3 files changed

+56
-7
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ CHANGELOG
55
1.16.2.dev
66
==========
77

8+
* feature: Add support for AugmentedManifestFile and ShuffleConfig
89
* bug-fix: add version bound for requests module to avoid version conflicts between docker-compose and docker-py
910
* bug-fix: Remove unnecessary dependency tensorflow
1011
* doc-fix: Change ``distribution`` to ``distributions``

src/sagemaker/session.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,7 +1222,7 @@ class s3_input(object):
12221222

12231223
def __init__(self, s3_data, distribution='FullyReplicated', compression=None,
12241224
content_type=None, record_wrapping=None, s3_data_type='S3Prefix',
1225-
input_mode=None):
1225+
input_mode=None, attribute_names=None, shuffle_config=None):
12261226
"""Create a definition for input data used by an SageMaker training job.
12271227
12281228
See AWS documentation on the ``CreateTrainingJob`` API for more details on the parameters.
@@ -1234,17 +1234,23 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None,
12341234
compression (str): Valid values: 'Gzip', None (default: None). This is used only in Pipe input mode.
12351235
content_type (str): MIME type of the input data (default: None).
12361236
record_wrapping (str): Valid values: 'RecordIO' (default: None).
1237-
s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile'. If 'S3Prefix', ``s3_data`` defines
1238-
a prefix of s3 objects to train on. All objects with s3 keys beginning with ``s3_data`` will
1239-
be used to train. If 'ManifestFile', then ``s3_data`` defines a single s3 manifest file, listing
1240-
each s3 object to train on. The Manifest file format is described in the SageMaker API documentation:
1241-
https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html
1237+
s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile', 'AugmentedManifestFile'. If 'S3Prefix',
1238+
``s3_data`` defines a prefix of s3 objects to train on. All objects with s3 keys beginning with
1239+
``s3_data`` will be used to train. If 'ManifestFile' or 'AugmentedManifestFile', then ``s3_data``
1240+
defines a single s3 manifest file or augmented manifest file (respectively), listing the s3 data to
1241+
train on. Both the ManifestFile and AugmentedManifestFile formats are described in the SageMaker API
1242+
documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html
12421243
input_mode (str): Optional override for this channel's input mode (default: None). By default, channels will
12431244
use the input mode defined on ``sagemaker.estimator.EstimatorBase.input_mode``, but they will ignore
12441245
that setting if this parameter is set.
12451246
* None - Amazon SageMaker will use the input mode specified in the ``Estimator``.
12461247
* 'File' - Amazon SageMaker copies the training dataset from the S3 location to a local directory.
12471248
* 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix-named pipe.
1249+
attribute_names (list[str]): A list of one or more attribute names to use that are found in a specified
1250+
AugmentedManifestFile.
1251+
shuffle_config (ShuffleConfig): If specified this configuration enables shuffling on this channel. See the
1252+
SageMaker API documentation for more info:
1253+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
12481254
"""
12491255
self.config = {
12501256
'DataSource': {
@@ -1264,6 +1270,24 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None,
12641270
self.config['RecordWrapperType'] = record_wrapping
12651271
if input_mode is not None:
12661272
self.config['InputMode'] = input_mode
1273+
if attribute_names is not None:
1274+
self.config['DataSource']['S3DataSource']['AttributeNames'] = attribute_names
1275+
if shuffle_config is not None:
1276+
self.config['ShuffleConfig'] = {'Seed': shuffle_config.seed}
1277+
1278+
1279+
class ShuffleConfig(object):
1280+
"""
1281+
Used to configure channel shuffling using a seed. See SageMaker
1282+
documentation for more detail: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
1283+
"""
1284+
def __init__(self, seed):
1285+
"""
1286+
Create a ShuffleConfig.
1287+
Args:
1288+
seed (long): the long value used to seed the shuffled sequence.
1289+
"""
1290+
self.seed = seed
12671291

12681292

12691293
class ModelContainer(object):

tests/unit/test_estimator.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sagemaker.estimator import Estimator, EstimatorBase, Framework, _TrainingJob
2525
from sagemaker.model import FrameworkModel
2626
from sagemaker.predictor import RealTimePredictor
27-
from sagemaker.session import s3_input
27+
from sagemaker.session import s3_input, ShuffleConfig
2828
from sagemaker.transformer import Transformer
2929

3030
MODEL_DATA = "s3://bucket/model.tar.gz"
@@ -277,6 +277,30 @@ def test_invalid_custom_code_bucket(sagemaker_session):
277277
assert "Expecting 's3' scheme" in str(error)
278278

279279

280+
def test_augmented_manifest(sagemaker_session):
281+
fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
282+
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
283+
enable_cloudwatch_metrics=True)
284+
fw.fit(inputs=s3_input('s3://mybucket/train_manifest', s3_data_type='AugmentedManifestFile',
285+
attribute_names=['foo', 'bar']))
286+
287+
_, _, train_kwargs = sagemaker_session.train.mock_calls[0]
288+
s3_data_source = train_kwargs['input_config'][0]['DataSource']['S3DataSource']
289+
assert s3_data_source['S3Uri'] == 's3://mybucket/train_manifest'
290+
assert s3_data_source['S3DataType'] == 'AugmentedManifestFile'
291+
assert s3_data_source['AttributeNames'] == ['foo', 'bar']
292+
293+
294+
def test_shuffle_config(sagemaker_session):
295+
fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
296+
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
297+
enable_cloudwatch_metrics=True)
298+
fw.fit(inputs=s3_input('s3://mybucket/train_manifest', shuffle_config=ShuffleConfig(100)))
299+
_, _, train_kwargs = sagemaker_session.train.mock_calls[0]
300+
channel = train_kwargs['input_config'][0]
301+
assert channel['ShuffleConfig']['Seed'] == 100
302+
303+
280304
BASE_HP = {
281305
'sagemaker_program': json.dumps(SCRIPT_NAME),
282306
'sagemaker_submit_directory': json.dumps('s3://mybucket/{}/source/sourcedir.tar.gz'.format(JOB_NAME)),

0 commit comments

Comments
 (0)