@@ -133,14 +133,18 @@ def data_location(self, data_location: str):

         if not data_location.startswith("s3://"):
             raise ValueError(
-                'Expecting an S3 URL beginning with "s3://". Got "{}"'.format(data_location)
+                'Expecting an S3 URL beginning with "s3://". Got "{}"'.format(
+                    data_location
+                )
             )
         if data_location[-1] != "/":
             data_location = data_location + "/"
         self._data_location = data_location

     @classmethod
-    def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
+    def _prepare_init_params_from_job_description(
+        cls, job_details, model_channel_name=None
+    ):
         """Convert the job description to init params that can be handled by the class constructor.

         Args:
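
The setter in this hunk validates that `data_location` is an S3 URL and normalizes it to end with a trailing slash. A minimal standalone sketch of that behavior (the function name here is illustrative, not part of the SDK):

```python
def normalize_data_location(data_location: str) -> str:
    """Validate an S3 URL and ensure it ends with a trailing slash."""
    if not data_location.startswith("s3://"):
        raise ValueError(
            'Expecting an S3 URL beginning with "s3://". Got "{}"'.format(data_location)
        )
    if data_location[-1] != "/":
        data_location = data_location + "/"
    return data_location

print(normalize_data_location("s3://my-bucket/training-data"))
# s3://my-bucket/training-data/
```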
@@ -168,7 +172,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
         del init_params["image_uri"]
         return init_params

-    def prepare_workflow_for_training(self, records=None, mini_batch_size=None, job_name=None):
+    def prepare_workflow_for_training(
+        self, records=None, mini_batch_size=None, job_name=None
+    ):
         """Calls _prepare_for_training. Used when setting up a workflow.

         Args:
@@ -194,7 +200,9 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
                 specified, one is generated, using the base name given to the
                 constructor if applicable.
         """
-        super(AmazonAlgorithmEstimatorBase, self)._prepare_for_training(job_name=job_name)
+        super(AmazonAlgorithmEstimatorBase, self)._prepare_for_training(
+            job_name=job_name
+        )

         feature_dim = None

@@ -260,7 +268,9 @@ def fit(
                 will be unassociated.
                 * `TrialComponentDisplayName` is used for display in Studio.
         """
-        self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size)
+        self._prepare_for_training(
+            records, job_name=job_name, mini_batch_size=mini_batch_size
+        )

         experiment_config = check_and_get_run_experiment_config(experiment_config)
         self.latest_training_job = _TrainingJob.start_new(
@@ -269,12 +279,14 @@ def fit(
         if wait:
             self.latest_training_job.wait(logs=logs)

-    def record_set(self,
-                   train,
-                   labels=None,
-                   channel="train",
-                   encrypt=False,
-                   distribution="ShardedByS3Key"):
+    def record_set(
+        self,
+        train,
+        labels=None,
+        channel="train",
+        encrypt=False,
+        distribution="ShardedByS3Key",
+    ):
         """Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector.

         For the 2D ``ndarray`` ``train``, each row is converted to a
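
For context, `record_set` converts an in-memory numpy matrix into sharded training data on S3 and returns a `RecordSet` that `fit` can consume. A hedged usage sketch (the `KMeans` settings, role ARN, and bucket are illustrative placeholders):

```python
import numpy as np
from sagemaker import KMeans

# Any AmazonAlgorithmEstimatorBase subclass exposes record_set the same way.
kmeans = KMeans(
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder role
    instance_count=2,
    instance_type="ml.c5.xlarge",
    k=10,
    data_location="s3://my-bucket/kmeans-input/",  # must begin with "s3://"
)

train = np.random.rand(1000, 50).astype("float32")  # one row per record

# Uploads sharded data under data_location and returns a RecordSet.
records = kmeans.record_set(train, channel="train")
kmeans.fit(records)
```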
@@ -311,7 +323,9 @@ def record_set(self,
         )
         parsed_s3_url = urlparse(self.data_location)
         bucket, key_prefix = parsed_s3_url.netloc, parsed_s3_url.path
-        key_prefix = key_prefix + "{}-{}/".format(type(self).__name__, sagemaker_timestamp())
+        key_prefix = key_prefix + "{}-{}/".format(
+            type(self).__name__, sagemaker_timestamp()
+        )
         key_prefix = key_prefix.lstrip("/")
         logger.debug("Uploading to bucket %s and key_prefix %s", bucket, key_prefix)
         manifest_s3_file = upload_numpy_to_s3_shards(
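
The `key_prefix` built in this hunk comes from splitting `data_location` with `urlparse` and appending a `ClassName-timestamp/` directory. A small sketch of the same construction (the timestamp helper is a stand-in, and its exact format is an assumption):

```python
from datetime import datetime, timezone
from urllib.parse import urlparse

def fake_timestamp():
    # Stand-in for sagemaker.utils.sagemaker_timestamp(); format assumed.
    return datetime.now(timezone.utc).strftime("%Y-%m-%d-%H-%M-%S-%f")[:-3]

parsed_s3_url = urlparse("s3://my-bucket/kmeans-input/")  # illustrative location
bucket, key_prefix = parsed_s3_url.netloc, parsed_s3_url.path
key_prefix = key_prefix + "{}-{}/".format("KMeans", fake_timestamp())
key_prefix = key_prefix.lstrip("/")
print(bucket)      # my-bucket
print(key_prefix)  # kmeans-input/KMeans-<timestamp>/
```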
@@ -338,7 +352,9 @@ def _get_default_mini_batch_size(self, num_records: int):
             )
             return 1

-        return min(self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.instance_count)))
+        return min(
+            self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.instance_count))
+        )


 class RecordSet(object):
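
The wrapped expression in this hunk is the whole default mini-batch heuristic: the per-instance record count, clamped to at least 1 and capped at `DEFAULT_MINI_BATCH_SIZE`. A quick worked sketch (the constant's value of 500 is assumed for illustration):

```python
DEFAULT_MINI_BATCH_SIZE = 500  # assumed value; the real constant is per estimator

def default_mini_batch_size(num_records: int, instance_count: int) -> int:
    # Records are sharded across instances, so size the batch per instance.
    return min(DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / instance_count)))

print(default_mini_batch_size(10_000, 4))  # 500: capped at the default
print(default_mini_batch_size(1_200, 4))   # 300: num_records / instance_count
print(default_mini_batch_size(3, 4))       # 1: floor of 0 clamped up to 1
```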
@@ -447,7 +463,10 @@ def _build_shards(num_shards, array):
     shard_size = int(array.shape[0] / num_shards)
     if shard_size == 0:
         raise ValueError("Array length is less than num shards")
-    shards = [array[i * shard_size : i * shard_size + shard_size] for i in range(num_shards - 1)]
+    shards = [
+        array[i * shard_size : i * shard_size + shard_size]
+        for i in range(num_shards - 1)
+    ]
     shards.append(array[(num_shards - 1) * shard_size :])
     return shards

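
`_build_shards` splits the leading axis of the array into `num_shards` pieces: the first `num_shards - 1` shards are equal-sized and the last shard absorbs the remainder. A self-contained sketch of that behavior (a standalone copy for illustration, not the SDK import):

```python
import numpy as np

def build_shards(num_shards, array):
    shard_size = int(array.shape[0] / num_shards)
    if shard_size == 0:
        raise ValueError("Array length is less than num shards")
    shards = [
        array[i * shard_size : i * shard_size + shard_size]
        for i in range(num_shards - 1)
    ]
    shards.append(array[(num_shards - 1) * shard_size :])  # remainder goes here
    return shards

rows = np.arange(10).reshape(10, 1)
print([s.shape[0] for s in build_shards(3, rows)])  # [3, 3, 4]
```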
@@ -494,7 +513,9 @@ def upload_numpy_to_s3_shards(
         manifest_str = json.dumps(
             [{"prefix": "s3://{}/{}".format(bucket, key_prefix)}] + uploaded_files
         )
-        s3.Object(bucket, manifest_key).put(Body=manifest_str.encode("utf-8"), **extra_put_kwargs)
+        s3.Object(bucket, manifest_key).put(
+            Body=manifest_str.encode("utf-8"), **extra_put_kwargs
+        )
         return "s3://{}/{}".format(bucket, manifest_key)
     except Exception as ex:  # pylint: disable=broad-except
         try:
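
For reference, the manifest serialized in this hunk follows the S3 manifest-file convention: a JSON array whose first element maps `"prefix"` to the common S3 prefix, followed by the relative object names. A hedged sketch of the resulting body (bucket, prefix, and shard names are made up):

```python
import json

bucket, key_prefix = "my-bucket", "kmeans-input/KMeans-2024-01-01/"
uploaded_files = ["matrix_0.pbr", "matrix_1.pbr"]  # illustrative shard names

manifest_str = json.dumps(
    [{"prefix": "s3://{}/{}".format(bucket, key_prefix)}] + uploaded_files
)
print(manifest_str)
# [{"prefix": "s3://my-bucket/kmeans-input/KMeans-2024-01-01/"}, "matrix_0.pbr", "matrix_1.pbr"]
```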