@@ -128,20 +128,20 @@ def _create_train_job(toolkit, toolkit_version, framework):
128
128
}
129
129
130
130
131
- def test_create_tf_model (sagemaker_session , rl_coach_version ):
131
+ def test_create_tf_model (sagemaker_session , rl_coach_tf_version ):
132
132
container_log_level = '"logging.INFO"'
133
133
source_dir = 's3://mybucket/source'
134
134
rl = RLEstimator (entry_point = SCRIPT_PATH , role = ROLE , sagemaker_session = sagemaker_session ,
135
135
train_instance_count = INSTANCE_COUNT , train_instance_type = INSTANCE_TYPE ,
136
- toolkit = RLToolkit .COACH , toolkit_version = rl_coach_version ,
136
+ toolkit = RLToolkit .COACH , toolkit_version = rl_coach_tf_version ,
137
137
framework = RLFramework .TENSORFLOW , container_log_level = container_log_level ,
138
138
source_dir = source_dir )
139
139
140
140
job_name = 'new_name'
141
141
rl .fit (inputs = 's3://mybucket/train' , job_name = 'new_name' )
142
142
model = rl .create_model ()
143
143
supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP [RLToolkit .COACH .value ]
144
- framework_version = supported_versions [rl_coach_version ][RLFramework .TENSORFLOW .value ]
144
+ framework_version = supported_versions [rl_coach_tf_version ][RLFramework .TENSORFLOW .value ]
145
145
146
146
assert isinstance (model , tfs .Model )
147
147
assert model .sagemaker_session == sagemaker_session
@@ -152,20 +152,20 @@ def test_create_tf_model(sagemaker_session, rl_coach_version):
152
152
assert model .vpc_config is None
153
153
154
154
155
- def test_create_mxnet_model (sagemaker_session , rl_coach_version ):
155
+ def test_create_mxnet_model (sagemaker_session , rl_coach_mxnet_version ):
156
156
container_log_level = '"logging.INFO"'
157
157
source_dir = 's3://mybucket/source'
158
158
rl = RLEstimator (entry_point = SCRIPT_PATH , role = ROLE , sagemaker_session = sagemaker_session ,
159
159
train_instance_count = INSTANCE_COUNT , train_instance_type = INSTANCE_TYPE ,
160
- toolkit = RLToolkit .COACH , toolkit_version = rl_coach_version ,
160
+ toolkit = RLToolkit .COACH , toolkit_version = rl_coach_mxnet_version ,
161
161
framework = RLFramework .MXNET , container_log_level = container_log_level ,
162
162
source_dir = source_dir )
163
163
164
164
job_name = 'new_name'
165
165
rl .fit (inputs = 's3://mybucket/train' , job_name = 'new_name' )
166
166
model = rl .create_model ()
167
167
supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP [RLToolkit .COACH .value ]
168
- framework_version = supported_versions [rl_coach_version ][RLFramework .MXNET .value ]
168
+ framework_version = supported_versions [rl_coach_mxnet_version ][RLFramework .MXNET .value ]
169
169
170
170
assert isinstance (model , MXNetModel )
171
171
assert model .sagemaker_session == sagemaker_session
@@ -179,12 +179,12 @@ def test_create_mxnet_model(sagemaker_session, rl_coach_version):
179
179
assert model .vpc_config is None
180
180
181
181
182
- def test_create_model_with_optional_params (sagemaker_session , rl_coach_version ):
182
+ def test_create_model_with_optional_params (sagemaker_session , rl_coach_mxnet_version ):
183
183
container_log_level = '"logging.INFO"'
184
184
source_dir = 's3://mybucket/source'
185
185
rl = RLEstimator (entry_point = SCRIPT_PATH , role = ROLE , sagemaker_session = sagemaker_session ,
186
186
train_instance_count = INSTANCE_COUNT , train_instance_type = INSTANCE_TYPE ,
187
- toolkit = RLToolkit .COACH , toolkit_version = rl_coach_version ,
187
+ toolkit = RLToolkit .COACH , toolkit_version = rl_coach_mxnet_version ,
188
188
framework = RLFramework .MXNET , container_log_level = container_log_level ,
189
189
source_dir = source_dir )
190
190
@@ -226,10 +226,10 @@ def test_create_model_with_custom_image(sagemaker_session):
226
226
227
227
@patch ('sagemaker.utils.create_tar_file' , MagicMock ())
228
228
@patch ('time.strftime' , return_value = TIMESTAMP )
229
- def test_rl (strftime , sagemaker_session , rl_coach_version ):
229
+ def test_rl (strftime , sagemaker_session , rl_coach_mxnet_version ):
230
230
rl = RLEstimator (entry_point = SCRIPT_PATH , role = ROLE , sagemaker_session = sagemaker_session ,
231
231
train_instance_count = INSTANCE_COUNT , train_instance_type = INSTANCE_TYPE ,
232
- toolkit = RLToolkit .COACH , toolkit_version = rl_coach_version ,
232
+ toolkit = RLToolkit .COACH , toolkit_version = rl_coach_mxnet_version ,
233
233
framework = RLFramework .MXNET )
234
234
235
235
inputs = 's3://mybucket/train'
@@ -241,7 +241,7 @@ def test_rl(strftime, sagemaker_session, rl_coach_version):
241
241
boto_call_names = [c [0 ] for c in sagemaker_session .boto_session .method_calls ]
242
242
assert boto_call_names == ['resource' ]
243
243
244
- expected_train_args = _create_train_job (RLToolkit .COACH .value , rl_coach_version ,
244
+ expected_train_args = _create_train_job (RLToolkit .COACH .value , rl_coach_mxnet_version ,
245
245
RLFramework .MXNET .value )
246
246
expected_train_args ['input_config' ][0 ]['DataSource' ]['S3DataSource' ]['S3Uri' ] = inputs
247
247
@@ -250,7 +250,7 @@ def test_rl(strftime, sagemaker_session, rl_coach_version):
250
250
251
251
model = rl .create_model ()
252
252
supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP [RLToolkit .COACH .value ]
253
- framework_version = supported_versions [rl_coach_version ][RLFramework .MXNET .value ]
253
+ framework_version = supported_versions [rl_coach_mxnet_version ][RLFramework .MXNET .value ]
254
254
255
255
expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-gpu-py3'
256
256
submit_dir = 's3://notmybucket/sagemaker-rl-mxnet-{}/source/sourcedir.tar.gz' .format (TIMESTAMP )
@@ -266,17 +266,17 @@ def test_rl(strftime, sagemaker_session, rl_coach_version):
266
266
267
267
268
268
@patch ('sagemaker.utils.create_tar_file' , MagicMock ())
269
- def test_deploy_mxnet (sagemaker_session , rl_coach_version ):
270
- rl = _rl_estimator (sagemaker_session , RLToolkit .COACH , rl_coach_version , RLFramework .MXNET ,
269
+ def test_deploy_mxnet (sagemaker_session , rl_coach_mxnet_version ):
270
+ rl = _rl_estimator (sagemaker_session , RLToolkit .COACH , rl_coach_mxnet_version , RLFramework .MXNET ,
271
271
train_instance_type = 'ml.g2.2xlarge' )
272
272
rl .fit ()
273
273
predictor = rl .deploy (1 , CPU )
274
274
assert isinstance (predictor , MXNetPredictor )
275
275
276
276
277
277
@patch ('sagemaker.utils.create_tar_file' , MagicMock ())
278
- def test_deploy_tfs (sagemaker_session , rl_coach_version ):
279
- rl = _rl_estimator (sagemaker_session , RLToolkit .COACH , rl_coach_version , RLFramework .TENSORFLOW ,
278
+ def test_deploy_tfs (sagemaker_session , rl_coach_tf_version ):
279
+ rl = _rl_estimator (sagemaker_session , RLToolkit .COACH , rl_coach_tf_version , RLFramework .TENSORFLOW ,
280
280
train_instance_type = 'ml.g2.2xlarge' )
281
281
rl .fit ()
282
282
predictor = rl .deploy (1 , GPU )
@@ -312,25 +312,25 @@ def test_train_image_cpu_instances(sagemaker_session, rl_ray_version):
312
312
framework .value )
313
313
314
314
315
- def test_train_image_gpu_instances (sagemaker_session , rl_coach_version ):
315
+ def test_train_image_gpu_instances (sagemaker_session , rl_coach_mxnet_version ):
316
316
toolkit = RLToolkit .COACH
317
317
framework = RLFramework .MXNET
318
- rl = _rl_estimator (sagemaker_session , toolkit , rl_coach_version , framework ,
318
+ rl = _rl_estimator (sagemaker_session , toolkit , rl_coach_mxnet_version , framework ,
319
319
train_instance_type = 'ml.g2.2xlarge' )
320
- assert rl .train_image () == _get_full_gpu_image_uri (toolkit .value , rl_coach_version ,
320
+ assert rl .train_image () == _get_full_gpu_image_uri (toolkit .value , rl_coach_mxnet_version ,
321
321
framework .value )
322
322
323
- rl = _rl_estimator (sagemaker_session , toolkit , rl_coach_version , framework ,
323
+ rl = _rl_estimator (sagemaker_session , toolkit , rl_coach_mxnet_version , framework ,
324
324
train_instance_type = 'ml.p2.2xlarge' )
325
- assert rl .train_image () == _get_full_gpu_image_uri (toolkit .value , rl_coach_version ,
325
+ assert rl .train_image () == _get_full_gpu_image_uri (toolkit .value , rl_coach_mxnet_version ,
326
326
framework .value )
327
327
328
328
329
- def test_attach (sagemaker_session , rl_coach_version ):
329
+ def test_attach (sagemaker_session , rl_coach_mxnet_version ):
330
330
training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-rl-{}:{}{}-cpu-py3' \
331
- .format (RLFramework .MXNET .value , RLToolkit .COACH .value , rl_coach_version )
331
+ .format (RLFramework .MXNET .value , RLToolkit .COACH .value , rl_coach_mxnet_version )
332
332
supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP [RLToolkit .COACH .value ]
333
- framework_version = supported_versions [rl_coach_version ][RLFramework .MXNET .value ]
333
+ framework_version = supported_versions [rl_coach_mxnet_version ][RLFramework .MXNET .value ]
334
334
returned_job_description = {'AlgorithmSpecification' : {'TrainingInputMode' : 'File' ,
335
335
'TrainingImage' : training_image },
336
336
'HyperParameters' :
@@ -361,7 +361,7 @@ def test_attach(sagemaker_session, rl_coach_version):
361
361
assert estimator .framework == RLFramework .MXNET .value
362
362
assert estimator .toolkit == RLToolkit .COACH .value
363
363
assert estimator .framework_version == framework_version
364
- assert estimator .toolkit_version == rl_coach_version
364
+ assert estimator .toolkit_version == rl_coach_mxnet_version
365
365
assert estimator .role == 'arn:aws:iam::366:role/SageMakerRole'
366
366
assert estimator .train_instance_count == 1
367
367
assert estimator .train_max_run == 24 * 60 * 60
0 commit comments