@@ -148,6 +148,40 @@ def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version):
148
148
assert [security_group_id ] == model_desc ['VpcConfig' ]['SecurityGroupIds' ]
149
149
150
150
151
+ def test_transform_mxnet_logs (sagemaker_session , mxnet_full_version ):
152
+ data_path = os .path .join (DATA_DIR , 'mxnet_mnist' )
153
+ script_path = os .path .join (data_path , 'mnist.py' )
154
+ tags = [{'Key' : 'some-tag' , 'Value' : 'value-for-tag' }]
155
+
156
+ mx = MXNet (entry_point = script_path , role = 'SageMakerRole' , train_instance_count = 1 ,
157
+ train_instance_type = 'ml.c4.xlarge' , sagemaker_session = sagemaker_session ,
158
+ framework_version = mxnet_full_version )
159
+
160
+ train_input = mx .sagemaker_session .upload_data (path = os .path .join (data_path , 'train' ),
161
+ key_prefix = 'integ-test-data/mxnet_mnist/train' )
162
+ test_input = mx .sagemaker_session .upload_data (path = os .path .join (data_path , 'test' ),
163
+ key_prefix = 'integ-test-data/mxnet_mnist/test' )
164
+ job_name = unique_name_from_base ('test-mxnet-transform' )
165
+
166
+ with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
167
+ mx .fit ({'train' : train_input , 'test' : test_input }, job_name = job_name )
168
+
169
+ transform_input_path = os .path .join (data_path , 'transform' , 'data.csv' )
170
+ transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform'
171
+ transform_input = mx .sagemaker_session .upload_data (path = transform_input_path ,
172
+ key_prefix = transform_input_key_prefix )
173
+
174
+ transformer = mx .transformer (1 , 'ml.m4.xlarge' , tags = tags )
175
+ transformer .transform (transform_input , content_type = 'text/csv' )
176
+
177
+ with timeout_and_delete_model_with_transformer (transformer , sagemaker_session ,
178
+ minutes = TRANSFORM_DEFAULT_TIMEOUT_MINUTES ):
179
+ transformer .wait ()
180
+ model_desc = sagemaker_session .sagemaker_client .describe_model (ModelName = transformer .model_name )
181
+ model_tags = sagemaker_session .sagemaker_client .list_tags (ResourceArn = model_desc ['ModelArn' ])['Tags' ]
182
+ assert tags == model_tags
183
+
184
+
151
185
def _create_transformer_and_transform_job (estimator , transform_input , volume_kms_key = None ):
152
186
transformer = estimator .transformer (1 , 'ml.m4.xlarge' , volume_kms_key = volume_kms_key )
153
187
transformer .transform (transform_input , content_type = 'text/csv' )
0 commit comments