|
22 | 22 | from sagemaker import KMeans
|
23 | 23 | from sagemaker.mxnet import MXNet
|
24 | 24 | from sagemaker.transformer import Transformer
|
| 25 | +from sagemaker.estimator import Estimator |
25 | 26 | from sagemaker.utils import unique_name_from_base
|
26 | 27 | from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, TRANSFORM_DEFAULT_TIMEOUT_MINUTES
|
27 | 28 | from tests.integ.kms_utils import get_or_create_kms_key
|
@@ -148,6 +149,89 @@ def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version):
|
148 | 149 | assert [security_group_id] == model_desc['VpcConfig']['SecurityGroupIds']
|
149 | 150 |
|
150 | 151 |
|
def test_transform_mxnet_tags(sagemaker_session, mxnet_full_version):
    """Check that tags given to ``MXNet.transformer`` are attached to the model.

    Trains an MXNet estimator on the ``mxnet_mnist`` fixture data, runs a batch
    transform created with ``tags``, and asserts that the tags listed for the
    transformer's model equal the requested ones.

    Args:
        sagemaker_session: session fixture used for uploads, training and the
            ``describe_model`` / ``list_tags`` calls.
        mxnet_full_version: MXNet framework version fixture passed to the estimator.
    """
    data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
    script_path = os.path.join(data_path, 'mnist.py')
    tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]

    mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1,
               train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session,
               framework_version=mxnet_full_version)

    train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
                                                   key_prefix='integ-test-data/mxnet_mnist/train')
    test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
                                                  key_prefix='integ-test-data/mxnet_mnist/test')
    job_name = unique_name_from_base('test-mxnet-transform')

    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        mx.fit({'train': train_input, 'test': test_input}, job_name=job_name)

    transform_input_path = os.path.join(data_path, 'transform', 'data.csv')
    transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform'
    transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
                                                       key_prefix=transform_input_key_prefix)

    transformer = mx.transformer(1, 'ml.m4.xlarge', tags=tags)

    with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
                                                   minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
        # Start the job inside the cleanup context so that a failure while
        # launching the transform does not skip the context manager's cleanup
        # (per its name, it deletes the transformer's model on exit).
        transformer.transform(transform_input, content_type='text/csv')
        transformer.wait()

        client = sagemaker_session.sagemaker_client
        model_desc = client.describe_model(ModelName=transformer.model_name)
        model_tags = client.list_tags(ResourceArn=model_desc['ModelArn'])['Tags']
        assert tags == model_tags
| 184 | + |
| 185 | + |
def test_transform_byo_estimator(sagemaker_session):
    """Check batch transform with tags through a generic, attached ``Estimator``.

    Trains a KMeans job on the first 100 records of the ``one_p_mnist`` fixture,
    re-attaches to that training job with ``Estimator.attach``, runs a batch
    transform created with ``tags``, and asserts that the tags listed for the
    transformer's model equal the requested ones.

    Args:
        sagemaker_session: session fixture used for uploads, training and the
            ``describe_model`` / ``list_tags`` calls.
    """
    data_path = os.path.join(DATA_DIR, 'one_p_mnist')
    # The fixture pickle needs an explicit encoding on Python 3.
    pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
    tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]

    # Load the data into memory as numpy arrays.
    # NOTE: pickle is only acceptable here because the file is a repo-local test fixture.
    train_set_path = os.path.join(data_path, 'mnist.pkl.gz')
    with gzip.open(train_set_path, 'rb') as f:
        train_set, _, _ = pickle.load(f, **pickle_args)

    kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
                    train_instance_type='ml.c4.xlarge', k=10, sagemaker_session=sagemaker_session,
                    output_path='s3://{}/'.format(sagemaker_session.default_bucket()))

    # Set KMeans-specific hyperparameters, kept minimal so training is quick.
    kmeans.init_method = 'random'
    # The hyperparameter is ``max_iterations``; the previous ``max_iterators``
    # spelling only created an unused attribute and left the value unset.
    kmeans.max_iterations = 1
    kmeans.tol = 1
    kmeans.num_trials = 1
    kmeans.local_init_method = 'kmeans++'
    kmeans.half_life_time_size = 1
    kmeans.epochs = 1

    records = kmeans.record_set(train_set[0][:100])

    job_name = unique_name_from_base('test-kmeans-attach')

    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        kmeans.fit(records, job_name=job_name)

    transform_input_path = os.path.join(data_path, 'transform_input.csv')
    transform_input_key_prefix = 'integ-test-data/one_p_mnist/transform'
    transform_input = kmeans.sagemaker_session.upload_data(path=transform_input_path,
                                                           key_prefix=transform_input_key_prefix)

    # Re-attach through the generic Estimator so the transform runs without
    # the KMeans-specific estimator class.
    estimator = Estimator.attach(training_job_name=job_name,
                                 sagemaker_session=sagemaker_session)

    transformer = estimator.transformer(1, 'ml.m4.xlarge', tags=tags)

    with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
                                                   minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
        # Start the job inside the cleanup context so that a failure while
        # launching the transform does not skip the context manager's cleanup
        # (per its name, it deletes the transformer's model on exit).
        transformer.transform(transform_input, content_type='text/csv')
        transformer.wait()

        client = sagemaker_session.sagemaker_client
        model_desc = client.describe_model(ModelName=transformer.model_name)
        model_tags = client.list_tags(ResourceArn=model_desc['ModelArn'])['Tags']
        assert tags == model_tags
| 233 | + |
| 234 | + |
151 | 235 | def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
|
152 | 236 | transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
|
153 | 237 | transformer.transform(transform_input, content_type='text/csv')
|
|
0 commit comments