Skip to content

Commit 3be2a11

Browse files
author
Chuyang Deng
committed
use 'pytorch_inference_latest_version' for tests that deploys pytorch model
1 parent e47a516 commit 3be2a11

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

tests/integ/test_airflow_config.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -574,15 +574,15 @@ def test_xgboost_airflow_config_uploads_data_source_to_s3(
574574
def test_pytorch_airflow_config_uploads_data_source_to_s3_when_inputs_not_provided(
575575
sagemaker_session,
576576
cpu_instance_type,
577-
pytorch_training_latest_version,
578-
pytorch_training_latest_py_version,
577+
pytorch_inference_latest_version,
578+
pytorch_inference_latest_py_version,
579579
):
580580
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
581581
estimator = PyTorch(
582582
entry_point=PYTORCH_MNIST_SCRIPT,
583583
role=ROLE,
584-
framework_version=pytorch_training_latest_version,
585-
py_version=pytorch_training_latest_py_version,
584+
framework_version=pytorch_inference_latest_version,
585+
py_version=pytorch_inference_latest_py_version,
586586
instance_count=2,
587587
instance_type=cpu_instance_type,
588588
hyperparameters={"epochs": 6, "backend": "gloo"},

tests/integ/test_pytorch.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,26 @@ def fixture_training_job(
5555
return pytorch.latest_training_job.name
5656

5757

58+
@pytest.fixture(scope="module", name="pytorch_training_job_with_latest_infernce_version")
59+
def fixture_training_job_with_latest_inference_version(
60+
sagemaker_session,
61+
pytorch_inference_latest_version,
62+
pytorch_inference_latest_py_version,
63+
cpu_instance_type,
64+
):
65+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
66+
pytorch = _get_pytorch_estimator(
67+
sagemaker_session,
68+
pytorch_inference_latest_version,
69+
pytorch_inference_latest_py_version,
70+
cpu_instance_type,
71+
)
72+
pytorch.fit({"training": _upload_training_data(pytorch)})
73+
return pytorch.latest_training_job.name
74+
75+
5876
@pytest.mark.canary_quick
59-
def test_fit_deploy(pytorch_training_job, sagemaker_session, cpu_instance_type):
77+
def test_fit_deploy(pytorch_training_job_with_latest_infernce_version, sagemaker_session, cpu_instance_type):
6078
endpoint_name = "test-pytorch-sync-fit-attach-deploy{}".format(sagemaker_timestamp())
6179
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
6280
estimator = PyTorch.attach(pytorch_training_job, sagemaker_session=sagemaker_session)

0 commit comments

Comments
 (0)