19
19
import pytest
20
20
21
21
from sagemaker .tensorflow import TensorFlow
22
- from sagemaker .utils import unique_name_from_base
22
+ from sagemaker .utils import unique_name_from_base , sagemaker_timestamp
23
23
24
24
import tests .integ
25
25
from tests .integ import timeout
39
39
TAGS = [{"Key" : "some-key" , "Value" : "some-value" }]
40
40
41
41
42
- def test_mnist (sagemaker_session , instance_type ):
42
+ def test_mnist_with_checkpoint_config (sagemaker_session , instance_type ):
43
+ checkpoint_s3_uri = "s3://{}/checkpoints/tf-{}" .format (
44
+ sagemaker_session .default_bucket (), sagemaker_timestamp ()
45
+ )
46
+ checkpoint_local_path = "/test/checkpoint/path"
43
47
estimator = TensorFlow (
44
48
entry_point = SCRIPT ,
45
49
role = "SageMakerRole" ,
@@ -50,13 +54,16 @@ def test_mnist(sagemaker_session, instance_type):
50
54
framework_version = TensorFlow .LATEST_VERSION ,
51
55
py_version = tests .integ .PYTHON_VERSION ,
52
56
metric_definitions = [{"Name" : "train:global_steps" , "Regex" : r"global_step\/sec:\s(.*)" }],
57
+ checkpoint_s3_uri = checkpoint_s3_uri ,
58
+ checkpoint_local_path = checkpoint_local_path ,
53
59
)
54
60
inputs = estimator .sagemaker_session .upload_data (
55
61
path = os .path .join (MNIST_RESOURCE_PATH , "data" ), key_prefix = "scriptmode/mnist"
56
62
)
57
63
64
+ training_job_name = unique_name_from_base ("test-tf-sm-mnist" )
58
65
with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
59
- estimator .fit (inputs = inputs , job_name = unique_name_from_base ( "test-tf-sm-mnist" ) )
66
+ estimator .fit (inputs = inputs , job_name = training_job_name )
60
67
assert_s3_files_exist (
61
68
sagemaker_session ,
62
69
estimator .model_dir ,
@@ -65,33 +72,6 @@ def test_mnist(sagemaker_session, instance_type):
65
72
df = estimator .training_job_analytics .dataframe ()
66
73
assert df .size > 0
67
74
68
-
69
- @pytest .mark .skipif (
70
- tests .integ .test_region () != "us-east-1" ,
71
- reason = "checkpoint s3 bucket is in us-east-1, ListObjectsV2 will fail in other regions" ,
72
- )
73
- def test_checkpoint_config (sagemaker_session , instance_type ):
74
- checkpoint_s3_uri = "s3://142577830533-us-east-1-sagemaker-checkpoint"
75
- checkpoint_local_path = "/test/checkpoint/path"
76
- estimator = TensorFlow (
77
- entry_point = SCRIPT ,
78
- role = "SageMakerRole" ,
79
- train_instance_count = 1 ,
80
- train_instance_type = instance_type ,
81
- sagemaker_session = sagemaker_session ,
82
- script_mode = True ,
83
- framework_version = TensorFlow .LATEST_VERSION ,
84
- py_version = tests .integ .PYTHON_VERSION ,
85
- checkpoint_s3_uri = checkpoint_s3_uri ,
86
- checkpoint_local_path = checkpoint_local_path ,
87
- )
88
- inputs = estimator .sagemaker_session .upload_data (
89
- path = os .path .join (MNIST_RESOURCE_PATH , "data" ), key_prefix = "script/mnist"
90
- )
91
- training_job_name = unique_name_from_base ("test-tf-sm-checkpoint" )
92
- with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
93
- estimator .fit (inputs = inputs , job_name = training_job_name )
94
-
95
75
expected_training_checkpoint_config = {
96
76
"S3Uri" : checkpoint_s3_uri ,
97
77
"LocalPath" : checkpoint_local_path ,
0 commit comments