@@ -89,7 +89,7 @@ def test_graviton_pytorch(graviton_pytorch_version):
89
89
_test_graviton_framework_uris ("pytorch" , graviton_pytorch_version )
90
90
91
91
92
- def test_graviton_xgboost (graviton_xgboost_versions ):
92
+ def test_graviton_xgboost_instance_type_specified (graviton_xgboost_versions ):
93
93
for xgboost_version in graviton_xgboost_versions :
94
94
for instance_type in GRAVITON_INSTANCE_TYPES :
95
95
uri = image_uris .retrieve (
@@ -102,6 +102,33 @@ def test_graviton_xgboost(graviton_xgboost_versions):
102
102
assert expected == uri
103
103
104
104
105
+ def test_graviton_xgboost_image_scope_specified (graviton_xgboost_versions ):
106
+ for xgboost_version in graviton_xgboost_versions :
107
+ for instance_type in GRAVITON_INSTANCE_TYPES :
108
+ uri = image_uris .retrieve (
109
+ "xgboost" , "us-west-2" , version = xgboost_version , image_scope = "inference_graviton"
110
+ )
111
+ expected = (
112
+ "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:"
113
+ f"{ xgboost_version } -arm64"
114
+ )
115
+ assert expected == uri
116
+
117
+
118
+ def test_graviton_xgboost_image_scope_specified_x86_instance (graviton_xgboost_versions ):
119
+ for xgboost_version in graviton_xgboost_versions :
120
+ for instance_type in GRAVITON_INSTANCE_TYPES :
121
+ with pytest .raises (ValueError ) as error :
122
+ image_uris .retrieve (
123
+ "xgboost" ,
124
+ "us-west-2" ,
125
+ version = xgboost_version ,
126
+ image_scope = "inference_graviton" ,
127
+ instance_type = "ml.m5.xlarge" ,
128
+ )
129
+ assert "Unsupported instance type: m5." in str (error )
130
+
131
+
105
132
def test_graviton_xgboost_unsupported_version (graviton_xgboost_unsupported_versions ):
106
133
for xgboost_version in graviton_xgboost_unsupported_versions :
107
134
for instance_type in GRAVITON_INSTANCE_TYPES :
@@ -112,7 +139,7 @@ def test_graviton_xgboost_unsupported_version(graviton_xgboost_unsupported_versi
112
139
assert f"Unsupported xgboost version: { xgboost_version } ." in str (error )
113
140
114
141
115
- def test_graviton_sklearn (graviton_sklearn_versions ):
142
+ def test_graviton_sklearn_instance_type_specified (graviton_sklearn_versions ):
116
143
for sklearn_version in graviton_sklearn_versions :
117
144
for instance_type in GRAVITON_INSTANCE_TYPES :
118
145
uri = image_uris .retrieve (
@@ -125,6 +152,19 @@ def test_graviton_sklearn(graviton_sklearn_versions):
125
152
assert expected == uri
126
153
127
154
155
+ def test_graviton_sklearn_image_scope_specified (graviton_sklearn_versions ):
156
+ for sklearn_version in graviton_sklearn_versions :
157
+ for instance_type in GRAVITON_INSTANCE_TYPES :
158
+ uri = image_uris .retrieve (
159
+ "sklearn" , "us-west-2" , version = sklearn_version , image_scope = "inference_graviton"
160
+ )
161
+ expected = (
162
+ "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:"
163
+ f"{ sklearn_version } -arm64-cpu-py3"
164
+ )
165
+ assert expected == uri
166
+
167
+
128
168
def test_graviton_sklearn_unsupported_version (graviton_sklearn_unsupported_versions ):
129
169
for sklearn_version in graviton_sklearn_unsupported_versions :
130
170
for instance_type in GRAVITON_INSTANCE_TYPES :
@@ -138,6 +178,20 @@ def test_graviton_sklearn_unsupported_version(graviton_sklearn_unsupported_versi
138
178
assert expected == uri
139
179
140
180
181
+ def test_graviton_sklearn_image_scope_specified_x86_instance (graviton_sklearn_unsupported_versions ):
182
+ for sklearn_version in graviton_sklearn_unsupported_versions :
183
+ for instance_type in GRAVITON_INSTANCE_TYPES :
184
+ with pytest .raises (ValueError ) as error :
185
+ image_uris .retrieve (
186
+ "sklearn" ,
187
+ "us-west-2" ,
188
+ version = sklearn_version ,
189
+ image_scope = "inference_graviton" ,
190
+ instance_type = "ml.m5.xlarge" ,
191
+ )
192
+ assert "Unsupported instance type: m5." in str (error )
193
+
194
+
141
195
def _expected_graviton_framework_uri (framework , version , region ):
142
196
return expected_uris .graviton_framework_uri (
143
197
"{}-inference-graviton" .format (framework ),
0 commit comments