|
24 | 24 | MODEL_DATA = "s3://bucket/model.tar.gz"
|
25 | 25 | MODEL_IMAGE = "mi"
|
26 | 26 | ENTRY_POINT = "blah.py"
|
27 |
| -INSTANCE_TYPE = "p2.xlarge" |
28 | 27 | ROLE = "some-role"
|
29 | 28 |
|
30 | 29 | DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
|
@@ -172,145 +171,6 @@ def test_create_no_defaults(sagemaker_session, tmpdir):
|
172 | 171 | }
|
173 | 172 |
|
174 | 173 |
|
175 |
| -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
176 |
| -@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) |
177 |
| -def test_deploy(sagemaker_session, tmpdir): |
178 |
| - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) |
179 |
| - model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) |
180 |
| - sagemaker_session.endpoint_from_production_variants.assert_called_with( |
181 |
| - name=MODEL_NAME, |
182 |
| - production_variants=[ |
183 |
| - { |
184 |
| - "InitialVariantWeight": 1, |
185 |
| - "ModelName": MODEL_NAME, |
186 |
| - "InstanceType": INSTANCE_TYPE, |
187 |
| - "InitialInstanceCount": 1, |
188 |
| - "VariantName": "AllTraffic", |
189 |
| - } |
190 |
| - ], |
191 |
| - tags=None, |
192 |
| - kms_key=None, |
193 |
| - wait=True, |
194 |
| - data_capture_config_dict=None, |
195 |
| - ) |
196 |
| - |
197 |
| - |
198 |
| -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
199 |
| -@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) |
200 |
| -def test_deploy_endpoint_name(sagemaker_session, tmpdir): |
201 |
| - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) |
202 |
| - model.deploy(endpoint_name="blah", instance_type=INSTANCE_TYPE, initial_instance_count=55) |
203 |
| - sagemaker_session.endpoint_from_production_variants.assert_called_with( |
204 |
| - name="blah", |
205 |
| - production_variants=[ |
206 |
| - { |
207 |
| - "InitialVariantWeight": 1, |
208 |
| - "ModelName": MODEL_NAME, |
209 |
| - "InstanceType": INSTANCE_TYPE, |
210 |
| - "InitialInstanceCount": 55, |
211 |
| - "VariantName": "AllTraffic", |
212 |
| - } |
213 |
| - ], |
214 |
| - tags=None, |
215 |
| - kms_key=None, |
216 |
| - wait=True, |
217 |
| - data_capture_config_dict=None, |
218 |
| - ) |
219 |
| - |
220 |
| - |
221 |
| -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
222 |
| -@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) |
223 |
| -def test_deploy_tags(sagemaker_session, tmpdir): |
224 |
| - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) |
225 |
| - tags = [{"ModelName": "TestModel"}] |
226 |
| - model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, tags=tags) |
227 |
| - sagemaker_session.endpoint_from_production_variants.assert_called_with( |
228 |
| - name=MODEL_NAME, |
229 |
| - production_variants=[ |
230 |
| - { |
231 |
| - "InitialVariantWeight": 1, |
232 |
| - "ModelName": MODEL_NAME, |
233 |
| - "InstanceType": INSTANCE_TYPE, |
234 |
| - "InitialInstanceCount": 1, |
235 |
| - "VariantName": "AllTraffic", |
236 |
| - } |
237 |
| - ], |
238 |
| - tags=tags, |
239 |
| - kms_key=None, |
240 |
| - wait=True, |
241 |
| - data_capture_config_dict=None, |
242 |
| - ) |
243 |
| - |
244 |
| - |
245 |
| -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
246 |
| -@patch("tarfile.open") |
247 |
| -@patch("time.strftime", return_value=TIMESTAMP) |
248 |
| -def test_deploy_accelerator_type(tfo, time, sagemaker_session): |
249 |
| - model = DummyFrameworkModel(sagemaker_session) |
250 |
| - model.deploy( |
251 |
| - instance_type=INSTANCE_TYPE, initial_instance_count=1, accelerator_type=ACCELERATOR_TYPE |
252 |
| - ) |
253 |
| - sagemaker_session.endpoint_from_production_variants.assert_called_with( |
254 |
| - name=MODEL_NAME, |
255 |
| - production_variants=[ |
256 |
| - { |
257 |
| - "InitialVariantWeight": 1, |
258 |
| - "ModelName": MODEL_NAME, |
259 |
| - "InstanceType": INSTANCE_TYPE, |
260 |
| - "InitialInstanceCount": 1, |
261 |
| - "VariantName": "AllTraffic", |
262 |
| - "AcceleratorType": ACCELERATOR_TYPE, |
263 |
| - } |
264 |
| - ], |
265 |
| - tags=None, |
266 |
| - kms_key=None, |
267 |
| - wait=True, |
268 |
| - data_capture_config_dict=None, |
269 |
| - ) |
270 |
| - |
271 |
| - |
272 |
| -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
273 |
| -@patch("tarfile.open") |
274 |
| -@patch("time.strftime", return_value=TIMESTAMP) |
275 |
| -def test_deploy_kms_key(tfo, time, sagemaker_session): |
276 |
| - key = "some-key-arn" |
277 |
| - model = DummyFrameworkModel(sagemaker_session) |
278 |
| - model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, kms_key=key) |
279 |
| - sagemaker_session.endpoint_from_production_variants.assert_called_with( |
280 |
| - name=MODEL_NAME, |
281 |
| - production_variants=[ |
282 |
| - { |
283 |
| - "InitialVariantWeight": 1, |
284 |
| - "ModelName": MODEL_NAME, |
285 |
| - "InstanceType": INSTANCE_TYPE, |
286 |
| - "InitialInstanceCount": 1, |
287 |
| - "VariantName": "AllTraffic", |
288 |
| - } |
289 |
| - ], |
290 |
| - tags=None, |
291 |
| - kms_key=key, |
292 |
| - wait=True, |
293 |
| - data_capture_config_dict=None, |
294 |
| - ) |
295 |
| - |
296 |
| - |
297 |
| -@patch("sagemaker.session.Session") |
298 |
| -@patch("sagemaker.local.LocalSession") |
299 |
| -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
300 |
| -def test_deploy_creates_correct_session(local_session, session, tmpdir): |
301 |
| - # We expect a LocalSession when deploying to instance_type = 'local' |
302 |
| - model = DummyFrameworkModel(sagemaker_session=None, source_dir=str(tmpdir)) |
303 |
| - model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1) |
304 |
| - assert model.sagemaker_session == local_session.return_value |
305 |
| - |
306 |
| - # We expect a real Session when deploying to instance_type != local/local_gpu |
307 |
| - model = DummyFrameworkModel(sagemaker_session=None, source_dir=str(tmpdir)) |
308 |
| - model.deploy( |
309 |
| - endpoint_name="remote_endpoint", instance_type="ml.m4.4xlarge", initial_instance_count=2 |
310 |
| - ) |
311 |
| - assert model.sagemaker_session == session.return_value |
312 |
| - |
313 |
| - |
314 | 174 | @patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
|
315 | 175 | def test_deploy_update_endpoint(sagemaker_session, tmpdir):
|
316 | 176 | model = DummyFrameworkModel(sagemaker_session, source_dir=tmpdir)
|
|
0 commit comments