Skip to content

Commit 227394e

Browse files
committed
reformat
1 parent deb9664 commit 227394e

File tree

1 file changed

+9
-28
lines changed

1 file changed

+9
-28
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -133,18 +133,14 @@ def data_location(self, data_location: str):
133133

134134
if not data_location.startswith("s3://"):
135135
raise ValueError(
136-
'Expecting an S3 URL beginning with "s3://". Got "{}"'.format(
137-
data_location
138-
)
136+
'Expecting an S3 URL beginning with "s3://". Got "{}"'.format(data_location)
139137
)
140138
if data_location[-1] != "/":
141139
data_location = data_location + "/"
142140
self._data_location = data_location
143141

144142
@classmethod
145-
def _prepare_init_params_from_job_description(
146-
cls, job_details, model_channel_name=None
147-
):
143+
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
148144
"""Convert the job description to init params that can be handled by the class constructor.
149145
150146
Args:
@@ -172,9 +168,7 @@ def _prepare_init_params_from_job_description(
172168
del init_params["image_uri"]
173169
return init_params
174170

175-
def prepare_workflow_for_training(
176-
self, records=None, mini_batch_size=None, job_name=None
177-
):
171+
def prepare_workflow_for_training(self, records=None, mini_batch_size=None, job_name=None):
178172
"""Calls _prepare_for_training. Used when setting up a workflow.
179173
180174
Args:
@@ -200,9 +194,7 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
200194
specified, one is generated, using the base name given to the
201195
constructor if applicable.
202196
"""
203-
super(AmazonAlgorithmEstimatorBase, self)._prepare_for_training(
204-
job_name=job_name
205-
)
197+
super(AmazonAlgorithmEstimatorBase, self)._prepare_for_training(job_name=job_name)
206198

207199
feature_dim = None
208200

@@ -268,9 +260,7 @@ def fit(
268260
will be unassociated.
269261
* `TrialComponentDisplayName` is used for display in Studio.
270262
"""
271-
self._prepare_for_training(
272-
records, job_name=job_name, mini_batch_size=mini_batch_size
273-
)
263+
self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size)
274264

275265
experiment_config = check_and_get_run_experiment_config(experiment_config)
276266
self.latest_training_job = _TrainingJob.start_new(
@@ -323,9 +313,7 @@ def record_set(
323313
)
324314
parsed_s3_url = urlparse(self.data_location)
325315
bucket, key_prefix = parsed_s3_url.netloc, parsed_s3_url.path
326-
key_prefix = key_prefix + "{}-{}/".format(
327-
type(self).__name__, sagemaker_timestamp()
328-
)
316+
key_prefix = key_prefix + "{}-{}/".format(type(self).__name__, sagemaker_timestamp())
329317
key_prefix = key_prefix.lstrip("/")
330318
logger.debug("Uploading to bucket %s and key_prefix %s", bucket, key_prefix)
331319
manifest_s3_file = upload_numpy_to_s3_shards(
@@ -352,9 +340,7 @@ def _get_default_mini_batch_size(self, num_records: int):
352340
)
353341
return 1
354342

355-
return min(
356-
self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.instance_count))
357-
)
343+
return min(self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.instance_count)))
358344

359345

360346
class RecordSet(object):
@@ -463,10 +449,7 @@ def _build_shards(num_shards, array):
463449
shard_size = int(array.shape[0] / num_shards)
464450
if shard_size == 0:
465451
raise ValueError("Array length is less than num shards")
466-
shards = [
467-
array[i * shard_size : i * shard_size + shard_size]
468-
for i in range(num_shards - 1)
469-
]
452+
shards = [array[i * shard_size : i * shard_size + shard_size] for i in range(num_shards - 1)]
470453
shards.append(array[(num_shards - 1) * shard_size :])
471454
return shards
472455

@@ -513,9 +496,7 @@ def upload_numpy_to_s3_shards(
513496
manifest_str = json.dumps(
514497
[{"prefix": "s3://{}/{}".format(bucket, key_prefix)}] + uploaded_files
515498
)
516-
s3.Object(bucket, manifest_key).put(
517-
Body=manifest_str.encode("utf-8"), **extra_put_kwargs
518-
)
499+
s3.Object(bucket, manifest_key).put(Body=manifest_str.encode("utf-8"), **extra_put_kwargs)
519500
return "s3://{}/{}".format(bucket, manifest_key)
520501
except Exception as ex: # pylint: disable=broad-except
521502
try:

0 commit comments

Comments
 (0)