|
18 | 18 | import pytest
|
19 | 19 | import numpy
|
20 | 20 |
|
| 21 | +from sagemaker.chainer.defaults import CHAINER_VERSION |
21 | 22 | from sagemaker.chainer.estimator import Chainer
|
22 | 23 | from sagemaker.chainer.model import ChainerModel
|
23 | 24 | from sagemaker.utils import sagemaker_timestamp
|
@@ -70,25 +71,23 @@ def test_attach_deploy(chainer_training_job, sagemaker_session):
|
70 | 71 | _predict_and_assert(predictor)
|
71 | 72 |
|
72 | 73 |
|
73 |
| -def test_deploy_model(chainer_training_job, sagemaker_session, chainer_full_version): |
| 74 | +def test_deploy_model(chainer_training_job, sagemaker_session): |
74 | 75 | endpoint_name = 'test-chainer-deploy-model-{}'.format(sagemaker_timestamp())
|
75 | 76 | with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20):
|
76 | 77 | desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=chainer_training_job)
|
77 | 78 | model_data = desc['ModelArtifacts']['S3ModelArtifacts']
|
78 | 79 | script_path = os.path.join(DATA_DIR, 'chainer_mnist', 'mnist.py')
|
79 |
| - model = ChainerModel(model_data, 'SageMakerRole', entry_point=script_path, |
80 |
| - framework_version=chainer_full_version, |
81 |
| - sagemaker_session=sagemaker_session) |
| 80 | + model = ChainerModel(model_data, 'SageMakerRole', entry_point=script_path, sagemaker_session=sagemaker_session) |
82 | 81 | predictor = model.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name)
|
83 | 82 | _predict_and_assert(predictor)
|
84 | 83 |
|
85 | 84 |
|
86 |
| -def test_async_fit(sagemaker_session, chainer_full_version): |
| 85 | +def test_async_fit(sagemaker_session): |
87 | 86 | endpoint_name = 'test-chainer-attach-deploy-{}'.format(sagemaker_timestamp())
|
88 | 87 |
|
89 | 88 | with timeout(minutes=5):
|
90 | 89 | training_job_name = _run_mnist_training_job(sagemaker_session, "ml.c4.xlarge", 1,
|
91 |
| - chainer_full_version, wait=False) |
| 90 | + chainer_full_version=CHAINER_VERSION, wait=False) |
92 | 91 |
|
93 | 92 | print("Waiting to re-attach to the training job: %s" % training_job_name)
|
94 | 93 | time.sleep(20)
|
|
0 commit comments