Skip to content

Commit e0e4752

Browse files
committed
Update model config from estimator with framework logic
1 parent 3583adb commit e0e4752

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

src/sagemaker/workflow/airflow.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,14 +367,19 @@ def model_config(instance_type, model, role=None, image=None):
367367
return config
368368

369369

370-
def model_config_from_estimator(instance_type, estimator, role=None, image=None,
370+
def model_config_from_estimator(instance_type, estimator, role=None, image=None, model_server_workers=None,
371371
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT):
372372
"""Export Airflow model config from a SageMaker estimator
373373
374374
Args:
375375
instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'
376376
estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to export Airflow config from.
377377
It has to be an estimator associated with a training job.
378+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
379+
image (str): An container image to use for deploying the model
380+
model_server_workers (int): The number of worker processes used by the inference server.
381+
If None, server will use one worker per vCPU. Only effective when estimator is
382+
SageMaker framework.
378383
vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on the model.
379384
Default: use subnets and security groups from this Estimator.
380385
* 'Subnets' (list[str]): List of subnet ids.
@@ -384,9 +389,12 @@ def model_config_from_estimator(instance_type, estimator, role=None, image=None,
384389
Model config that can be directly used by SageMakerModelOperator in Airflow. It can also be part
385390
of the config used by SageMakerEndpointOperator and SageMakerTransformOperator in Airflow.
386391
"""
387-
try:
388-
model = estimator.create_model(role, image, vpc_config_override=vpc_config_override)
389-
except TypeError:
392+
if isinstance(estimator, sagemaker.estimator.Estimator):
393+
model = estimator.create_model(role=role, image=image, vpc_config_override=vpc_config_override)
394+
elif isinstance(estimator, sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase):
390395
model = estimator.create_model(vpc_config_override=vpc_config_override)
396+
elif isinstance(estimator, sagemaker.estimator.Framework):
397+
model = estimator.create_model(model_server_workers=model_server_workers, role=role,
398+
vpc_config_override=vpc_config_override)
391399

392400
return model_config(instance_type, model, role, image)

0 commit comments

Comments
 (0)