@@ -15,7 +15,6 @@
 import os
 
 import boto3
-import pytest
 from sagemaker.tensorflow import TensorFlow
 from six.moves.urllib.parse import urlparse
 
@@ -81,23 +80,21 @@ def test_distributed_mnist_ps(sagemaker_session, ecr_image, instance_type, frame
     _assert_s3_file_exists(sagemaker_session.boto_region_name, estimator.model_data)
 
 
-# TODO: Enable this test when new binary fixing the s3 plugin released
-@pytest.mark.skip(reason='Skip the test until new binary released')
 def test_s3_plugin(sagemaker_session, ecr_image, instance_type, region, framework_version):
     resource_path = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
     script = os.path.join(resource_path, 'mnist', 'mnist_estimator.py')
     estimator = TensorFlow(entry_point=script,
                            role='SageMakerRole',
                            hyperparameters={
-                               # Saving a checkpoint after every step to hammer the S3 plugin
-                               'save-checkpoint-steps': 1,
+                               # Saving a checkpoint after every 10 steps to hammer the S3 plugin
+                               'save-checkpoint-steps': 10,
                                # Disable throttling for checkpoint and model saving
                                'throttle-secs': 0,
                                # Without the patch training jobs would fail around 100th to
                                # 150th step
                                'max-steps': 200,
                                # Large batch size would result in a larger checkpoint file
-                               'batch-size': 2048,
+                               'batch-size': 1024,
                                # This makes the training job exporting model during training.
                                # Stale model garbage collection will also be performed.
                                'export-model-during-training': True