Skip to content

Commit 8729c8a

Browse files
committed
breaking: make instance_type optional and change parameter order for prepare_container_def in Model/FrameworkModel
1 parent 97cd594 commit 8729c8a

File tree

3 files changed

+19
-19
lines changed

3 files changed

+19
-19
lines changed

src/sagemaker/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def _init_sagemaker_session_if_does_not_exist(self, instance_type):
138138
self.sagemaker_session = session.Session()
139139

140140
def prepare_container_def(
141-
self, instance_type, accelerator_type=None
141+
self, instance_type=None, accelerator_type=None
142142
): # pylint: disable=unused-argument
143143
"""Return a dict created by ``sagemaker.container_def()`` for deploying
144144
this model to a specified instance type.
@@ -166,7 +166,7 @@ def enable_network_isolation(self):
166166
"""
167167
return self._enable_network_isolation
168168

169-
def _create_sagemaker_model(self, instance_type, accelerator_type=None, tags=None):
169+
def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tags=None):
170170
"""Create a SageMaker Model Entity
171171
172172
Args:
@@ -808,7 +808,7 @@ def __init__(
808808
self.repacked_model_data = None
809809

810810
def prepare_container_def(
811-
self, instance_type, accelerator_type=None
811+
self, instance_type=None, accelerator_type=None
812812
): # pylint disable=unused-argument
813813
"""Return a container definition with framework configuration set in
814814
model environment variables.

src/sagemaker/workflow/airflow.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -555,14 +555,14 @@ def prepare_framework_container_def(model, instance_type, s3_operations):
555555
return sagemaker.container_def(deploy_image, model.model_data, deploy_env)
556556

557557

558-
def model_config(instance_type, model, role=None, image=None):
558+
def model_config(model, instance_type=None, role=None, image=None):
559559
"""Export Airflow model config from a SageMaker model
560560
561561
Args:
562-
instance_type (str): The EC2 instance type to deploy this Model to. For
563-
example, 'ml.p2.xlarge'
564562
model (sagemaker.model.FrameworkModel): The SageMaker model to export
565563
Airflow config from
564+
instance_type (str): The EC2 instance type to deploy this Model to. For
565+
example, 'ml.p2.xlarge'
566566
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
567567
image (str): An container image to use for deploying the model
568568
@@ -577,7 +577,7 @@ def model_config(instance_type, model, role=None, image=None):
577577
if isinstance(model, sagemaker.model.FrameworkModel):
578578
container_def = prepare_framework_container_def(model, instance_type, s3_operations)
579579
else:
580-
container_def = model.prepare_container_def(instance_type)
580+
container_def = model.prepare_container_def()
581581
base_name = utils.base_name_from_image(container_def["Image"])
582582
model.name = model.name or utils.name_from_base(base_name)
583583

@@ -599,10 +599,10 @@ def model_config(instance_type, model, role=None, image=None):
599599

600600

601601
def model_config_from_estimator(
602-
instance_type,
603602
estimator,
604603
task_id,
605604
task_type,
605+
instance_type=None,
606606
role=None,
607607
image=None,
608608
name=None,
@@ -612,8 +612,6 @@ def model_config_from_estimator(
612612
"""Export Airflow model config from a SageMaker estimator
613613
614614
Args:
615-
instance_type (str): The EC2 instance type to deploy this Model to. For
616-
example, 'ml.p2.xlarge'
617615
estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to
618616
export Airflow config from. It has to be an estimator associated
619617
with a training job.
@@ -625,6 +623,8 @@ def model_config_from_estimator(
625623
task_type (str): Whether the task is from SageMakerTrainingOperator or
626624
SageMakerTuningOperator. Values can be 'training', 'tuning' or None
627625
(which means training job is not from any task).
626+
instance_type (str): The EC2 instance type to deploy this Model to. For
627+
example, 'ml.p2.xlarge'
628628
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
629629
image (str): An container image to use for deploying the model
630630
name (str): Name of the model
@@ -665,7 +665,7 @@ def model_config_from_estimator(
665665
)
666666
model.name = name
667667

668-
return model_config(instance_type, model, role, image)
668+
return model_config(model, instance_type, role, image)
669669

670670

671671
def transform_config(
@@ -912,10 +912,10 @@ def transform_config_from_estimator(
912912
SageMakerTransformOperator in Airflow.
913913
"""
914914
model_base_config = model_config_from_estimator(
915-
instance_type=instance_type,
916915
estimator=estimator,
917916
task_id=task_id,
918917
task_type=task_type,
918+
instance_type=instance_type,
919919
role=role,
920920
image=image,
921921
name=model_name,
@@ -995,7 +995,7 @@ def deploy_config(model, initial_instance_count, instance_type, endpoint_name=No
995995
dict: Deploy config that can be directly used by
996996
SageMakerEndpointOperator in Airflow.
997997
"""
998-
model_base_config = model_config(instance_type, model)
998+
model_base_config = model_config(model, instance_type)
999999

10001000
production_variant = sagemaker.production_variant(
10011001
model.name, instance_type, initial_instance_count

tests/unit/test_airflow.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -900,7 +900,7 @@ def test_byo_model_config(sagemaker_session):
900900
sagemaker_session=sagemaker_session,
901901
)
902902

903-
config = airflow.model_config(instance_type="ml.c4.xlarge", model=byo_model)
903+
config = airflow.model_config(model=byo_model)
904904
expected_config = {
905905
"ModelName": "model",
906906
"PrimaryContainer": {
@@ -926,7 +926,7 @@ def test_byo_framework_model_config(sagemaker_session):
926926
sagemaker_session=sagemaker_session,
927927
)
928928

929-
config = airflow.model_config(instance_type="ml.c4.xlarge", model=byo_model)
929+
config = airflow.model_config(model=byo_model, instance_type="ml.c4.xlarge")
930930
expected_config = {
931931
"ModelName": "model",
932932
"PrimaryContainer": {
@@ -971,7 +971,7 @@ def test_framework_model_config(sagemaker_session):
971971
sagemaker_session=sagemaker_session,
972972
)
973973

974-
config = airflow.model_config(instance_type="ml.c4.xlarge", model=chainer_model)
974+
config = airflow.model_config(model=chainer_model, instance_type="ml.c4.xlarge")
975975
expected_config = {
976976
"ModelName": "sagemaker-chainer-%s" % TIME_STAMP,
977977
"PrimaryContainer": {
@@ -1009,7 +1009,7 @@ def test_amazon_alg_model_config(sagemaker_session):
10091009
model_data="{{ model_data }}", role="{{ role }}", sagemaker_session=sagemaker_session
10101010
)
10111011

1012-
config = airflow.model_config(instance_type="ml.c4.xlarge", model=pca_model)
1012+
config = airflow.model_config(model=pca_model)
10131013
expected_config = {
10141014
"ModelName": "pca-%s" % TIME_STAMP,
10151015
"PrimaryContainer": {
@@ -1059,10 +1059,10 @@ def test_model_config_from_framework_estimator(ecr_prefix, sagemaker_session):
10591059
airflow.training_config(mxnet_estimator, data)
10601060

10611061
config = airflow.model_config_from_estimator(
1062-
instance_type="ml.c4.xlarge",
10631062
estimator=mxnet_estimator,
10641063
task_id="task_id",
10651064
task_type="training",
1065+
instance_type="ml.c4.xlarge",
10661066
)
10671067
expected_config = {
10681068
"ModelName": "mxnet-inference-%s" % TIME_STAMP,
@@ -1103,7 +1103,7 @@ def test_model_config_from_amazon_alg_estimator(sagemaker_session):
11031103
airflow.training_config(knn_estimator, record, mini_batch_size=256)
11041104

11051105
config = airflow.model_config_from_estimator(
1106-
instance_type="ml.c4.xlarge", estimator=knn_estimator, task_id="task_id", task_type="tuning"
1106+
estimator=knn_estimator, task_id="task_id", task_type="tuning"
11071107
)
11081108
expected_config = {
11091109
"ModelName": "knn-%s" % TIME_STAMP,

0 commit comments

Comments
 (0)