Skip to content

Commit ed158e0

Browse files
committed
don't assume create_model takes model_server_workers for airflow
1 parent 64f7095 commit ed158e0

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/sagemaker/workflow/airflow.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -647,12 +647,14 @@ def model_config_from_estimator(
647647
elif isinstance(estimator, sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase):
648648
model = estimator.create_model(vpc_config_override=vpc_config_override)
649649
elif isinstance(estimator, sagemaker.estimator.Framework):
650-
model = estimator.create_model(
651-
model_server_workers=model_server_workers,
652-
role=role,
653-
vpc_config_override=vpc_config_override,
654-
entry_point=estimator.entry_point,
655-
)
650+
model_kwargs = {
651+
"role": role,
652+
"vpc_config_override": vpc_config_override,
653+
"entry_point": estimator.entry_point,
654+
}
655+
if model_server_workers:
656+
model_kwargs["model_server_workers"] = model_server_workers
657+
model = estimator.create_model(**model_kwargs)
656658
else:
657659
raise TypeError(
658660
"Estimator must be one of sagemaker.estimator.Estimator, sagemaker.estimator.Framework"

0 commit comments

Comments
 (0)