Skip to content

Commit e57034a

Browse files
committed
Address comments before test
1 parent e42c800 commit e57034a

File tree

1 file changed

+17
-19
lines changed

1 file changed

+17
-19
lines changed

src/sagemaker/workflow/airflow.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020

2121

2222
def prepare_framework(estimator, s3_operations):
23-
"""Prepare S3 operations (specify where to upload source_dir) and environment variables
23+
"""Prepare S3 operations (specify where to upload `source_dir`) and environment variables
2424
related to framework.
2525
2626
Args:
2727
estimator (sagemaker.estimator.Estimator): The framework estimator to get information from and update.
28-
s3_operations (dict): The dict to specify s3 operations (upload source_dir).
28+
s3_operations (dict): The dict to specify s3 operations (upload `source_dir`).
2929
"""
3030
bucket = estimator.code_location if estimator.code_location else estimator.sagemaker_session._default_bucket
3131
key = '{}/source/sourcedir.tar.gz'.format(estimator._current_job_name)
@@ -106,8 +106,8 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
106106
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an
107107
Amazon algorithm. For other estimators, batch size should be specified in the estimator.
108108
109-
Returns (dict):
110-
Training config that can be directly used by SageMakerTrainingOperator in Airflow.
109+
Returns:
110+
dict: Training config that can be directly used by SageMakerTrainingOperator in Airflow.
111111
"""
112112
default_bucket = estimator.sagemaker_session.default_bucket()
113113
s3_operations = {}
@@ -185,8 +185,8 @@ def training_config(estimator, inputs=None, job_name=None, mini_batch_size=None)
185185
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an
186186
Amazon algorithm. For other estimators, batch size should be specified in the estimator.
187187
188-
Returns (dict):
189-
Training config that can be directly used by SageMakerTrainingOperator in Airflow.
188+
Returns:
189+
dict: Training config that can be directly used by SageMakerTrainingOperator in Airflow.
190190
"""
191191

192192
train_config = training_base_config(estimator, inputs, job_name, mini_batch_size)
@@ -223,8 +223,8 @@ def tuning_config(tuner, inputs, job_name=None):
223223
224224
job_name (str): Specify a tuning job name if needed.
225225
226-
Returns (dict):
227-
Tuning config that can be directly used by SageMakerTuningOperator in Airflow.
226+
Returns:
227+
dict: Tuning config that can be directly used by SageMakerTuningOperator in Airflow.
228228
"""
229229
train_config = training_base_config(tuner.estimator, inputs)
230230
hyperparameters = train_config.pop('HyperParameters', None)
@@ -277,15 +277,15 @@ def tuning_config(tuner, inputs, job_name=None):
277277

278278
def prepare_framework_container_def(model, instance_type, s3_operations):
279279
"""Prepare the framework model container information. Specify related S3 operations for Airflow to perform.
280-
(Upload source_dir)
280+
(Upload `source_dir`)
281281
282282
Args:
283283
model (sagemaker.model.FrameworkModel): The framework model
284284
instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
285-
s3_operations (dict): The dict to specify S3 operations (upload source_dir).
285+
s3_operations (dict): The dict to specify S3 operations (upload `source_dir`).
286286
287-
Returns (dict):
288-
The container information of this framework model.
287+
Returns:
288+
dict: The container information of this framework model.
289289
"""
290290
deploy_image = model.image
291291
if not deploy_image:
@@ -322,9 +322,7 @@ def prepare_framework_container_def(model, instance_type, s3_operations):
322322
# This applies to a FrameworkModel which is not SageMaker Deep Learning Framework Model
323323
pass
324324

325-
container_def = sagemaker.container_def(deploy_image, model.model_data, deploy_env)
326-
327-
return container_def
325+
return sagemaker.container_def(deploy_image, model.model_data, deploy_env)
328326

329327

330328
def model_config(instance_type, model, role=None, image=None):
@@ -336,8 +334,8 @@ def model_config(instance_type, model, role=None, image=None):
336334
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
337335
image (str): An container image to use for deploying the model
338336
339-
Returns (dict):
340-
Model config that can be directly used by SageMakerModelOperator in Airflow. It can also be part
337+
Returns:
338+
dict: Model config that can be directly used by SageMakerModelOperator in Airflow. It can also be part
341339
of the config used by SageMakerEndpointOperator and SageMakerTransformOperator in Airflow.
342340
"""
343341
s3_operations = {}
@@ -385,8 +383,8 @@ def model_config_from_estimator(instance_type, estimator, role=None, image=None,
385383
* 'Subnets' (list[str]): List of subnet ids.
386384
* 'SecurityGroupIds' (list[str]): List of security group ids.
387385
388-
Returns (dict):
389-
Model config that can be directly used by SageMakerModelOperator in Airflow. It can also be part
386+
Returns:
387+
dict: Model config that can be directly used by SageMakerModelOperator in Airflow. It can also be part
390388
of the config used by SageMakerEndpointOperator and SageMakerTransformOperator in Airflow.
391389
"""
392390
if isinstance(estimator, sagemaker.estimator.Estimator):

0 commit comments

Comments
 (0)