Skip to content

Commit deb9664

Browse files
committed
reformat
1 parent 6c037a6 commit deb9664

File tree

1 file changed

+36
-15
lines changed

1 file changed

+36
-15
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,18 @@ def data_location(self, data_location: str):
133133

134134
if not data_location.startswith("s3://"):
135135
raise ValueError(
136-
'Expecting an S3 URL beginning with "s3://". Got "{}"'.format(data_location)
136+
'Expecting an S3 URL beginning with "s3://". Got "{}"'.format(
137+
data_location
138+
)
137139
)
138140
if data_location[-1] != "/":
139141
data_location = data_location + "/"
140142
self._data_location = data_location
141143

142144
@classmethod
143-
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
145+
def _prepare_init_params_from_job_description(
146+
cls, job_details, model_channel_name=None
147+
):
144148
"""Convert the job description to init params that can be handled by the class constructor.
145149
146150
Args:
@@ -168,7 +172,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
168172
del init_params["image_uri"]
169173
return init_params
170174

171-
def prepare_workflow_for_training(self, records=None, mini_batch_size=None, job_name=None):
175+
def prepare_workflow_for_training(
176+
self, records=None, mini_batch_size=None, job_name=None
177+
):
172178
"""Calls _prepare_for_training. Used when setting up a workflow.
173179
174180
Args:
@@ -194,7 +200,9 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
194200
specified, one is generated, using the base name given to the
195201
constructor if applicable.
196202
"""
197-
super(AmazonAlgorithmEstimatorBase, self)._prepare_for_training(job_name=job_name)
203+
super(AmazonAlgorithmEstimatorBase, self)._prepare_for_training(
204+
job_name=job_name
205+
)
198206

199207
feature_dim = None
200208

@@ -260,7 +268,9 @@ def fit(
260268
will be unassociated.
261269
* `TrialComponentDisplayName` is used for display in Studio.
262270
"""
263-
self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size)
271+
self._prepare_for_training(
272+
records, job_name=job_name, mini_batch_size=mini_batch_size
273+
)
264274

265275
experiment_config = check_and_get_run_experiment_config(experiment_config)
266276
self.latest_training_job = _TrainingJob.start_new(
@@ -269,12 +279,14 @@ def fit(
269279
if wait:
270280
self.latest_training_job.wait(logs=logs)
271281

272-
def record_set(self,
273-
train,
274-
labels=None,
275-
channel="train",
276-
encrypt=False,
277-
distribution="ShardedByS3Key"):
282+
def record_set(
283+
self,
284+
train,
285+
labels=None,
286+
channel="train",
287+
encrypt=False,
288+
distribution="ShardedByS3Key",
289+
):
278290
"""Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector.
279291
280292
For the 2D ``ndarray`` ``train``, each row is converted to a
@@ -311,7 +323,9 @@ def record_set(self,
311323
)
312324
parsed_s3_url = urlparse(self.data_location)
313325
bucket, key_prefix = parsed_s3_url.netloc, parsed_s3_url.path
314-
key_prefix = key_prefix + "{}-{}/".format(type(self).__name__, sagemaker_timestamp())
326+
key_prefix = key_prefix + "{}-{}/".format(
327+
type(self).__name__, sagemaker_timestamp()
328+
)
315329
key_prefix = key_prefix.lstrip("/")
316330
logger.debug("Uploading to bucket %s and key_prefix %s", bucket, key_prefix)
317331
manifest_s3_file = upload_numpy_to_s3_shards(
@@ -338,7 +352,9 @@ def _get_default_mini_batch_size(self, num_records: int):
338352
)
339353
return 1
340354

341-
return min(self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.instance_count)))
355+
return min(
356+
self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.instance_count))
357+
)
342358

343359

344360
class RecordSet(object):
@@ -447,7 +463,10 @@ def _build_shards(num_shards, array):
447463
shard_size = int(array.shape[0] / num_shards)
448464
if shard_size == 0:
449465
raise ValueError("Array length is less than num shards")
450-
shards = [array[i * shard_size : i * shard_size + shard_size] for i in range(num_shards - 1)]
466+
shards = [
467+
array[i * shard_size : i * shard_size + shard_size]
468+
for i in range(num_shards - 1)
469+
]
451470
shards.append(array[(num_shards - 1) * shard_size :])
452471
return shards
453472

@@ -494,7 +513,9 @@ def upload_numpy_to_s3_shards(
494513
manifest_str = json.dumps(
495514
[{"prefix": "s3://{}/{}".format(bucket, key_prefix)}] + uploaded_files
496515
)
497-
s3.Object(bucket, manifest_key).put(Body=manifest_str.encode("utf-8"), **extra_put_kwargs)
516+
s3.Object(bucket, manifest_key).put(
517+
Body=manifest_str.encode("utf-8"), **extra_put_kwargs
518+
)
498519
return "s3://{}/{}".format(bucket, manifest_key)
499520
except Exception as ex: # pylint: disable=broad-except
500521
try:

0 commit comments

Comments
 (0)