@@ -133,18 +133,14 @@ def data_location(self, data_location: str):
133
133
134
134
if not data_location .startswith ("s3://" ):
135
135
raise ValueError (
136
- 'Expecting an S3 URL beginning with "s3://". Got "{}"' .format (
137
- data_location
138
- )
136
+ 'Expecting an S3 URL beginning with "s3://". Got "{}"' .format (data_location )
139
137
)
140
138
if data_location [- 1 ] != "/" :
141
139
data_location = data_location + "/"
142
140
self ._data_location = data_location
143
141
144
142
@classmethod
145
- def _prepare_init_params_from_job_description (
146
- cls , job_details , model_channel_name = None
147
- ):
143
+ def _prepare_init_params_from_job_description (cls , job_details , model_channel_name = None ):
148
144
"""Convert the job description to init params that can be handled by the class constructor.
149
145
150
146
Args:
@@ -172,9 +168,7 @@ def _prepare_init_params_from_job_description(
172
168
del init_params ["image_uri" ]
173
169
return init_params
174
170
175
- def prepare_workflow_for_training (
176
- self , records = None , mini_batch_size = None , job_name = None
177
- ):
171
+ def prepare_workflow_for_training (self , records = None , mini_batch_size = None , job_name = None ):
178
172
"""Calls _prepare_for_training. Used when setting up a workflow.
179
173
180
174
Args:
@@ -200,9 +194,7 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
200
194
specified, one is generated, using the base name given to the
201
195
constructor if applicable.
202
196
"""
203
- super (AmazonAlgorithmEstimatorBase , self )._prepare_for_training (
204
- job_name = job_name
205
- )
197
+ super (AmazonAlgorithmEstimatorBase , self )._prepare_for_training (job_name = job_name )
206
198
207
199
feature_dim = None
208
200
@@ -268,9 +260,7 @@ def fit(
268
260
will be unassociated.
269
261
* `TrialComponentDisplayName` is used for display in Studio.
270
262
"""
271
- self ._prepare_for_training (
272
- records , job_name = job_name , mini_batch_size = mini_batch_size
273
- )
263
+ self ._prepare_for_training (records , job_name = job_name , mini_batch_size = mini_batch_size )
274
264
275
265
experiment_config = check_and_get_run_experiment_config (experiment_config )
276
266
self .latest_training_job = _TrainingJob .start_new (
@@ -323,9 +313,7 @@ def record_set(
323
313
)
324
314
parsed_s3_url = urlparse (self .data_location )
325
315
bucket , key_prefix = parsed_s3_url .netloc , parsed_s3_url .path
326
- key_prefix = key_prefix + "{}-{}/" .format (
327
- type (self ).__name__ , sagemaker_timestamp ()
328
- )
316
+ key_prefix = key_prefix + "{}-{}/" .format (type (self ).__name__ , sagemaker_timestamp ())
329
317
key_prefix = key_prefix .lstrip ("/" )
330
318
logger .debug ("Uploading to bucket %s and key_prefix %s" , bucket , key_prefix )
331
319
manifest_s3_file = upload_numpy_to_s3_shards (
@@ -352,9 +340,7 @@ def _get_default_mini_batch_size(self, num_records: int):
352
340
)
353
341
return 1
354
342
355
- return min (
356
- self .DEFAULT_MINI_BATCH_SIZE , max (1 , int (num_records / self .instance_count ))
357
- )
343
+ return min (self .DEFAULT_MINI_BATCH_SIZE , max (1 , int (num_records / self .instance_count )))
358
344
359
345
360
346
class RecordSet (object ):
@@ -463,10 +449,7 @@ def _build_shards(num_shards, array):
463
449
shard_size = int (array .shape [0 ] / num_shards )
464
450
if shard_size == 0 :
465
451
raise ValueError ("Array length is less than num shards" )
466
- shards = [
467
- array [i * shard_size : i * shard_size + shard_size ]
468
- for i in range (num_shards - 1 )
469
- ]
452
+ shards = [array [i * shard_size : i * shard_size + shard_size ] for i in range (num_shards - 1 )]
470
453
shards .append (array [(num_shards - 1 ) * shard_size :])
471
454
return shards
472
455
@@ -513,9 +496,7 @@ def upload_numpy_to_s3_shards(
513
496
manifest_str = json .dumps (
514
497
[{"prefix" : "s3://{}/{}" .format (bucket , key_prefix )}] + uploaded_files
515
498
)
516
- s3 .Object (bucket , manifest_key ).put (
517
- Body = manifest_str .encode ("utf-8" ), ** extra_put_kwargs
518
- )
499
+ s3 .Object (bucket , manifest_key ).put (Body = manifest_str .encode ("utf-8" ), ** extra_put_kwargs )
519
500
return "s3://{}/{}" .format (bucket , manifest_key )
520
501
except Exception as ex : # pylint: disable=broad-except
521
502
try :
0 commit comments