Skip to content

Commit 9218095

Browse files
author
Nikhil Kulkarni
committed
Add a condition to retrieve correct image URI for xgboost
1 parent 7e2ee02 commit 9218095

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/sagemaker/model.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,17 @@ def _compilation_image_uri(self, region, target_instance_type, framework, framew
274274
framework (str): The framework name.
275275
framework_version (str): The framework version.
276276
"""
277-
framework_prefix = "inferentia-" if target_instance_type.startswith("ml_inf") else "neo-"
277+
framework_prefix = ""
278+
framework_suffix = ""
279+
280+
if framework == "xgboost":
281+
framework_suffix = "-neo"
282+
else:
283+
framework_prefix = "inferentia-" if target_instance_type.startswith("ml_inf") else "neo-"
284+
285+
278286
return image_uris.retrieve(
279-
"{}{}".format(framework_prefix, framework),
287+
"{}{}".format(framework_prefix, framework, framework_suffix),
280288
region,
281289
instance_type=target_instance_type,
282290
version=framework_version,

0 commit comments

Comments
 (0)