@@ -367,14 +367,19 @@ def model_config(instance_type, model, role=None, image=None):
367
367
return config
368
368
369
369
370
- def model_config_from_estimator (instance_type , estimator , role = None , image = None ,
370
+ def model_config_from_estimator (instance_type , estimator , role = None , image = None , model_server_workers = None ,
371
371
vpc_config_override = vpc_utils .VPC_CONFIG_DEFAULT ):
372
372
"""Export Airflow model config from a SageMaker estimator
373
373
374
374
Args:
375
375
instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'
376
376
estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to export Airflow config from.
377
377
It has to be an estimator associated with a training job.
378
+ role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
379
+ image (str): An container image to use for deploying the model
380
+ model_server_workers (int): The number of worker processes used by the inference server.
381
+ If None, server will use one worker per vCPU. Only effective when estimator is
382
+ SageMaker framework.
378
383
vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on the model.
379
384
Default: use subnets and security groups from this Estimator.
380
385
* 'Subnets' (list[str]): List of subnet ids.
@@ -384,9 +389,12 @@ def model_config_from_estimator(instance_type, estimator, role=None, image=None,
384
389
Model config that can be directly used by SageMakerModelOperator in Airflow. It can also be part
385
390
of the config used by SageMakerEndpointOperator and SageMakerTransformOperator in Airflow.
386
391
"""
387
- try :
388
- model = estimator .create_model (role , image , vpc_config_override = vpc_config_override )
389
- except TypeError :
392
+ if isinstance ( estimator , sagemaker . estimator . Estimator ) :
393
+ model = estimator .create_model (role = role , image = image , vpc_config_override = vpc_config_override )
394
+ elif isinstance ( estimator , sagemaker . amazon . amazon_estimator . AmazonAlgorithmEstimatorBase ) :
390
395
model = estimator .create_model (vpc_config_override = vpc_config_override )
396
+ elif isinstance (estimator , sagemaker .estimator .Framework ):
397
+ model = estimator .create_model (model_server_workers = model_server_workers , role = role ,
398
+ vpc_config_override = vpc_config_override )
391
399
392
400
return model_config (instance_type , model , role , image )
0 commit comments