@@ -35,7 +35,15 @@ def test_create_training_job(train, LocalSession):
35
35
image = "my-docker-image:1.0"
36
36
37
37
algo_spec = {'TrainingImage' : image }
38
- input_data_config = {}
38
+ input_data_config = [{
39
+ 'ChannelName' : 'a' ,
40
+ 'DataSource' : {
41
+ 'S3DataSource' : {
42
+ 'S3DataDistributionType' : 'FullyReplicated' ,
43
+ 'S3Uri' : 's3://my_bucket/tmp/source1'
44
+ }
45
+ }
46
+ }]
39
47
output_data_config = {}
40
48
resource_config = {'InstanceType' : 'local' , 'InstanceCount' : instance_count }
41
49
hyperparameters = {'a' : 1 , 'b' : 'bee' }
@@ -61,6 +69,67 @@ def test_create_training_job(train, LocalSession):
61
69
assert response ['ModelArtifacts' ]['S3ModelArtifacts' ] == expected ['ModelArtifacts' ]['S3ModelArtifacts' ]
62
70
63
71
72
+ @patch ('sagemaker.local.image._SageMakerContainer.train' , return_value = "/some/path/to/model" )
73
+ @patch ('sagemaker.local.local_session.LocalSession' )
74
+ def test_create_training_job_invalid_data_source (train , LocalSession ):
75
+ local_sagemaker_client = sagemaker .local .local_session .LocalSagemakerClient ()
76
+
77
+ instance_count = 2
78
+ image = "my-docker-image:1.0"
79
+
80
+ algo_spec = {'TrainingImage' : image }
81
+
82
+ # InvalidDataSource is not supported. S3DataSource and FileDataSource are currently the only
83
+ # valid Data Sources. We expect a ValueError if we pass this input data config.
84
+ input_data_config = [{
85
+ 'ChannelName' : 'a' ,
86
+ 'DataSource' : {
87
+ 'InvalidDataSource' : {
88
+ 'FileDataDistributionType' : 'FullyReplicated' ,
89
+ 'FileUri' : 'ftp://myserver.com/tmp/source1'
90
+ }
91
+ }
92
+ }]
93
+
94
+ output_data_config = {}
95
+ resource_config = {'InstanceType' : 'local' , 'InstanceCount' : instance_count }
96
+ hyperparameters = {'a' : 1 , 'b' : 'bee' }
97
+
98
+ with pytest .raises (ValueError ):
99
+ local_sagemaker_client .create_training_job ("my-training-job" , algo_spec , 'arn:my-role' , input_data_config ,
100
+ output_data_config , resource_config , None , hyperparameters )
101
+
102
+
103
+ @patch ('sagemaker.local.image._SageMakerContainer.train' , return_value = "/some/path/to/model" )
104
+ @patch ('sagemaker.local.local_session.LocalSession' )
105
+ def test_create_training_job_not_fully_replicated (train , LocalSession ):
106
+ local_sagemaker_client = sagemaker .local .local_session .LocalSagemakerClient ()
107
+
108
+ instance_count = 2
109
+ image = "my-docker-image:1.0"
110
+
111
+ algo_spec = {'TrainingImage' : image }
112
+
113
+ # Local Mode only supports FullyReplicated as Data Distribution type.
114
+ input_data_config = [{
115
+ 'ChannelName' : 'a' ,
116
+ 'DataSource' : {
117
+ 'S3DataSource' : {
118
+ 'S3DataDistributionType' : 'ShardedByS3Key' ,
119
+ 'S3Uri' : 's3://my_bucket/tmp/source1'
120
+ }
121
+ }
122
+ }]
123
+
124
+ output_data_config = {}
125
+ resource_config = {'InstanceType' : 'local' , 'InstanceCount' : instance_count }
126
+ hyperparameters = {'a' : 1 , 'b' : 'bee' }
127
+
128
+ with pytest .raises (RuntimeError ):
129
+ local_sagemaker_client .create_training_job ("my-training-job" , algo_spec , 'arn:my-role' , input_data_config ,
130
+ output_data_config , resource_config , None , hyperparameters )
131
+
132
+
64
133
@patch ('sagemaker.local.local_session.LocalSession' )
65
134
def test_create_model (LocalSession ):
66
135
local_sagemaker_client = sagemaker .local .local_session .LocalSagemakerClient ()
@@ -130,3 +199,34 @@ def test_create_endpoint_fails(serve, request, LocalSession):
130
199
131
200
with pytest .raises (RuntimeError ):
132
201
local_sagemaker_client .create_endpoint ('my-endpoint' , 'some-endpoint-config' )
202
+
203
+
204
+ def test_file_input_all_defaults ():
205
+ prefix = 'pre'
206
+ actual = sagemaker .local .local_session .file_input (fileUri = prefix )
207
+ expected = \
208
+ {
209
+ 'DataSource' : {
210
+ 'FileDataSource' : {
211
+ 'FileDataDistributionType' : 'FullyReplicated' ,
212
+ 'FileUri' : prefix
213
+ }
214
+ }
215
+ }
216
+ assert actual .config == expected
217
+
218
+
219
+ def test_file_input_content_type ():
220
+ prefix = 'pre'
221
+ actual = sagemaker .local .local_session .file_input (fileUri = prefix , content_type = 'text/csv' )
222
+ expected = \
223
+ {
224
+ 'DataSource' : {
225
+ 'FileDataSource' : {
226
+ 'FileDataDistributionType' : 'FullyReplicated' ,
227
+ 'FileUri' : prefix
228
+ }
229
+ },
230
+ 'ContentType' : 'text/csv'
231
+ }
232
+ assert actual .config == expected
0 commit comments