@@ -1222,7 +1222,7 @@ class s3_input(object):
1222
1222
1223
1223
def __init__ (self , s3_data , distribution = 'FullyReplicated' , compression = None ,
1224
1224
content_type = None , record_wrapping = None , s3_data_type = 'S3Prefix' ,
1225
- input_mode = None ):
1225
+ input_mode = None , attribute_names = None , shuffle_config = None ):
1226
1226
"""Create a definition for input data used by an SageMaker training job.
1227
1227
1228
1228
See AWS documentation on the ``CreateTrainingJob`` API for more details on the parameters.
@@ -1234,17 +1234,23 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None,
1234
1234
compression (str): Valid values: 'Gzip', None (default: None). This is used only in Pipe input mode.
1235
1235
content_type (str): MIME type of the input data (default: None).
1236
1236
record_wrapping (str): Valid values: 'RecordIO' (default: None).
1237
- s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile'. If 'S3Prefix', ``s3_data`` defines
1238
- a prefix of s3 objects to train on. All objects with s3 keys beginning with ``s3_data`` will
1239
- be used to train. If 'ManifestFile', then ``s3_data`` defines a single s3 manifest file, listing
1240
- each s3 object to train on. The Manifest file format is described in the SageMaker API documentation:
1241
- https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html
1237
+ s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile', 'AugmentedManifestFile'. If 'S3Prefix',
1238
+ ``s3_data`` defines a prefix of s3 objects to train on. All objects with s3 keys beginning with
1239
+ ``s3_data`` will be used to train. If 'ManifestFile' or 'AugmentedManifestFile', then ``s3_data``
1240
+ defines a single s3 manifest file or augmented manifest file (respectively), listing the s3 data to
1241
+ train on. Both the ManifestFile and AugmentedManifestFile formats are described in the SageMaker API
1242
+ documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html
1242
1243
input_mode (str): Optional override for this channel's input mode (default: None). By default, channels will
1243
1244
use the input mode defined on ``sagemaker.estimator.EstimatorBase.input_mode``, but they will ignore
1244
1245
that setting if this parameter is set.
1245
1246
* None - Amazon SageMaker will use the input mode specified in the ``Estimator``.
1246
1247
* 'File' - Amazon SageMaker copies the training dataset from the S3 location to a local directory.
1247
1248
* 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix-named pipe.
1249
+ attribute_names (list[str]): A list of one or more attribute names to use that are found in a specified
1250
+ AugmentedManifestFile.
1251
+ shuffle_config (ShuffleConfig): If specified this configuration enables shuffling on this channel. See the
1252
+ SageMaker API documentation for more info:
1253
+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
1248
1254
"""
1249
1255
self .config = {
1250
1256
'DataSource' : {
@@ -1264,6 +1270,24 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None,
1264
1270
self .config ['RecordWrapperType' ] = record_wrapping
1265
1271
if input_mode is not None :
1266
1272
self .config ['InputMode' ] = input_mode
1273
+ if attribute_names is not None :
1274
+ self .config ['DataSource' ]['S3DataSource' ]['AttributeNames' ] = attribute_names
1275
+ if shuffle_config is not None :
1276
+ self .config ['ShuffleConfig' ] = {'Seed' : shuffle_config .seed }
1277
+
1278
+
1279
+ class ShuffleConfig (object ):
1280
+ """
1281
+ Used to configure channel shuffling using a seed. See SageMaker
1282
+ documentation for more detail: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
1283
+ """
1284
+ def __init__ (self , seed ):
1285
+ """
1286
+ Create a ShuffleConfig.
1287
+ Args:
1288
+ seed (long): the long value used to seed the shuffled sequence.
1289
+ """
1290
+ self .seed = seed
1267
1291
1268
1292
1269
1293
class ModelContainer (object ):
0 commit comments