
Commit 3968264

Author: Mike Schneider

Merge pull request #3 from ShiboXing/add-pt1.13.1-training

Add pt1.13.1 training

2 parents: a6ad8ba + d3a7822

4 files changed: +35 additions, -10 deletions

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 1 addition & 0 deletions
@@ -264,6 +264,7 @@
             },
             "1.4.0": {
                 "py_versions": [
+                    "py2",
                     "py3"
                 ],
                 "registries": {

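The py_versions list in this config file is what sagemaker.image_uris.retrieve consults when it resolves a requested Python version into an image URI. A minimal sketch of how the amended entry would be exercised, assuming this hunk sits under the config's inference section; the region and instance type below are illustrative values, not part of the commit:

from sagemaker import image_uris

# Illustrative lookup only: after this change, "py2" is an accepted py_version
# for PyTorch 1.4.0 in this section of the config. Region and instance type
# are example values.
uri = image_uris.retrieve(
    framework="pytorch",
    region="us-west-2",
    version="1.4.0",
    py_version="py2",
    instance_type="ml.c4.xlarge",
    image_scope="inference",
)
print(uri)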
tests/conftest.py

Lines changed: 16 additions & 0 deletions
@@ -86,6 +86,7 @@
     "huggingface_training_compiler",
 )
 
+PYTORCH_RENEWED_GPU = "ml.g4dn.xlarge"
 
 def pytest_addoption(parser):
     parser.addoption("--sagemaker-client-config", action="store", default=None)

@@ -514,6 +515,21 @@ def gpu_instance_type(sagemaker_session, request):
     else:
         return "ml.p3.2xlarge"
 
+@pytest.fixture()
+def gpu_pytorch_instance_type(sagemaker_session, request):
+    if "pytorch_inference_version" in request.fixturenames:
+        fw_version = request.getfixturevalue("pytorch_inference_version")
+    else:
+        fw_version = request.param
+
+    region = sagemaker_session.boto_session.region_name
+    if region in NO_P3_REGIONS:
+        if Version(fw_version) >= Version("1.13"):
+            return PYTORCH_RENEWED_GPU
+        else:
+            return "ml.p2.xlarge"
+    else:
+        return "ml.p3.2xlarge"
 
 @pytest.fixture(scope="session")
 def gpu_instance_type_list(sagemaker_session, request):

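For tests that do not already request pytorch_inference_version, the new fixture falls back to request.param, so a framework version has to be supplied via indirect parametrization (the same pattern the test_mms_model change below uses). A minimal, illustrative sketch of that usage; the test name and version string are not part of the diff:

import pytest

@pytest.mark.parametrize("gpu_pytorch_instance_type", ["1.13.1"], indirect=True)
def test_example_uses_fixture(gpu_pytorch_instance_type):
    # With indirect=True, pytest hands "1.13.1" to the fixture as request.param.
    # In NO_P3_REGIONS the fixture then returns ml.g4dn.xlarge (P2 is unavailable
    # there and the version is >= 1.13); in other regions it returns ml.p3.2xlarge.
    assert gpu_pytorch_instance_type in ("ml.g4dn.xlarge", "ml.p3.2xlarge")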
tests/unit/sagemaker/image_uris/test_dlc_frameworks.py

Lines changed: 7 additions & 1 deletion
@@ -18,6 +18,7 @@
 from tests.unit.sagemaker.image_uris import expected_uris
 
 INSTANCE_TYPES_AND_PROCESSORS = (("ml.c4.xlarge", "cpu"), ("ml.p2.xlarge", "gpu"))
+RENEWED_PYTORCH_INSTANCE_TYPES_AND_PROCESSORS = (("ml.c4.xlarge", "cpu"), ("ml.g4dn.xlarge", "gpu"))
 REGION = "us-west-2"
 
 DLC_ACCOUNT = "763104351884"

@@ -70,7 +71,12 @@ def _test_image_uris(
         "image_scope": scope,
     }
 
-    for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS:
+    TYPES_AND_PROCESSORS = INSTANCE_TYPES_AND_PROCESSORS
+    if framework == "pytorch" and Version(fw_version) >= Version("1.13"):
+        '''Handle P2 deprecation'''
+        TYPES_AND_PROCESSORS = RENEWED_PYTORCH_INSTANCE_TYPES_AND_PROCESSORS
+
+    for instance_type, processor in TYPES_AND_PROCESSORS:
         uri = image_uris.retrieve(region=REGION, instance_type=instance_type, **base_args)
 
         expected = expected_fn(processor=processor, **expected_fn_args)

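The P2 deprecation branch above relies on a semantic version comparison rather than a string comparison, so patch releases such as 1.13.1 also take the g4dn path. A small sketch of the gate in isolation, assuming Version is packaging.version.Version as used in the fixture above:

from packaging.version import Version

# Patch releases of 1.13 satisfy the gate; older versions keep the P2 pairing.
assert Version("1.13.1") >= Version("1.13")
assert not (Version("1.12.1") >= Version("1.13"))
# A plain string comparison would mis-order double-digit minor versions,
# which is why the tests compare Version objects instead of raw strings.
assert not ("1.9" <= "1.13")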
tests/unit/test_pytorch.py

Lines changed: 11 additions & 9 deletions
@@ -302,7 +302,7 @@ def test_create_model_with_custom_image(name_from_base, sagemaker_session):
 @patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
 @patch("time.time", return_value=TIME)
 def test_pytorch(
-    time, name_from_base, sagemaker_session, pytorch_inference_version, pytorch_inference_py_version
+    time, name_from_base, sagemaker_session, pytorch_inference_version, pytorch_inference_py_version, gpu_pytorch_instance_type
 ):
     pytorch = PyTorch(
         entry_point=SCRIPT_PATH,

@@ -339,24 +339,24 @@ def test_pytorch(
         REGION,
         version=pytorch_inference_version,
         py_version=pytorch_inference_py_version,
-        instance_type=GPU,
+        instance_type=gpu_pytorch_instance_type,
         image_scope="inference",
     )
 
-    actual_environment = model.prepare_container_def(GPU)
+    actual_environment = model.prepare_container_def(gpu_pytorch_instance_type)
     submit_directory = actual_environment["Environment"]["SAGEMAKER_SUBMIT_DIRECTORY"]
     model_url = actual_environment["ModelDataUrl"]
     expected_environment = _get_environment(submit_directory, model_url, expected_image_uri)
     assert actual_environment == expected_environment
 
     assert "cpu" in model.prepare_container_def(CPU)["Image"]
-    predictor = pytorch.deploy(1, GPU)
+    predictor = pytorch.deploy(1, gpu_pytorch_instance_type)
     assert isinstance(predictor, PyTorchPredictor)
 
 
 @patch("sagemaker.utils.repack_model", MagicMock())
 @patch("sagemaker.utils.create_tar_file", MagicMock())
-def test_model(sagemaker_session, pytorch_inference_version, pytorch_inference_py_version):
+def test_model(sagemaker_session, pytorch_inference_version, pytorch_inference_py_version, gpu_pytorch_instance_type):
     model = PyTorchModel(
         MODEL_DATA,
         role=ROLE,

@@ -365,21 +365,22 @@ def test_model(sagemaker_session, pytorch_inference_version, pytorch_inference_py_version):
         py_version=pytorch_inference_py_version,
         sagemaker_session=sagemaker_session,
     )
-    predictor = model.deploy(1, GPU)
+    predictor = model.deploy(1, gpu_pytorch_instance_type)
     assert isinstance(predictor, PyTorchPredictor)
 
 
 @patch("sagemaker.utils.create_tar_file", MagicMock())
 @patch("sagemaker.utils.repack_model")
-def test_mms_model(repack_model, sagemaker_session):
+@pytest.mark.parametrize("gpu_pytorch_instance_type", ["1.2"], indirect=True)
+def test_mms_model(repack_model, sagemaker_session, gpu_pytorch_instance_type):
     PyTorchModel(
         MODEL_DATA,
         role=ROLE,
         entry_point=SCRIPT_PATH,
         sagemaker_session=sagemaker_session,
         framework_version="1.2",
         py_version="py3",
-    ).deploy(1, GPU)
+    ).deploy(1, gpu_pytorch_instance_type)
 
     repack_model.assert_called_with(
         dependencies=[],

@@ -428,6 +429,7 @@ def test_model_custom_serialization(
     sagemaker_session,
     pytorch_inference_version,
     pytorch_inference_py_version,
+    gpu_pytorch_instance_type
 ):
     model = PyTorchModel(
         MODEL_DATA,

@@ -441,7 +443,7 @@
     custom_deserializer = Mock()
     predictor = model.deploy(
         1,
-        GPU,
+        gpu_pytorch_instance_type,
         serializer=custom_serializer,
         deserializer=custom_deserializer,
     )
