Skip to content

Commit 8b9a216

Browse files
nadiayaJonathan Esterhazy
authored andcommitted
add 0.10.1 coach version
1 parent ccd71e6 commit 8b9a216

File tree

3 files changed

+37
-26
lines changed

3 files changed

+37
-26
lines changed

src/sagemaker/rl/estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@
3131
PYTHON_VERSION = 'py3'
3232
TOOLKIT_FRAMEWORK_VERSION_MAP = {
3333
'coach': {
34+
'0.10.1': {
35+
'tensorflow': '1.11'
36+
},
37+
'0.10': {
38+
'tensorflow': '1.11'
39+
},
3440
'0.11.0': {
3541
'tensorflow': '1.11',
3642
'mxnet': '1.3'

tests/conftest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,13 @@ def tf_version(request):
112112
return request.param
113113

114114

115+
@pytest.fixture(scope='module', params=['0.10.1', '0.10.1', '0.11', '0.11.0'])
116+
def rl_coach_tf_version(request):
117+
return request.param
118+
119+
115120
@pytest.fixture(scope='module', params=['0.11', '0.11.0'])
116-
def rl_coach_version(request):
121+
def rl_coach_mxnet_version(request):
117122
return request.param
118123

119124

tests/unit/test_rl.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -128,20 +128,20 @@ def _create_train_job(toolkit, toolkit_version, framework):
128128
}
129129

130130

131-
def test_create_tf_model(sagemaker_session, rl_coach_version):
131+
def test_create_tf_model(sagemaker_session, rl_coach_tf_version):
132132
container_log_level = '"logging.INFO"'
133133
source_dir = 's3://mybucket/source'
134134
rl = RLEstimator(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
135135
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
136-
toolkit=RLToolkit.COACH, toolkit_version=rl_coach_version,
136+
toolkit=RLToolkit.COACH, toolkit_version=rl_coach_tf_version,
137137
framework=RLFramework.TENSORFLOW, container_log_level=container_log_level,
138138
source_dir=source_dir)
139139

140140
job_name = 'new_name'
141141
rl.fit(inputs='s3://mybucket/train', job_name='new_name')
142142
model = rl.create_model()
143143
supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP[RLToolkit.COACH.value]
144-
framework_version = supported_versions[rl_coach_version][RLFramework.TENSORFLOW.value]
144+
framework_version = supported_versions[rl_coach_tf_version][RLFramework.TENSORFLOW.value]
145145

146146
assert isinstance(model, tfs.Model)
147147
assert model.sagemaker_session == sagemaker_session
@@ -152,20 +152,20 @@ def test_create_tf_model(sagemaker_session, rl_coach_version):
152152
assert model.vpc_config is None
153153

154154

155-
def test_create_mxnet_model(sagemaker_session, rl_coach_version):
155+
def test_create_mxnet_model(sagemaker_session, rl_coach_mxnet_version):
156156
container_log_level = '"logging.INFO"'
157157
source_dir = 's3://mybucket/source'
158158
rl = RLEstimator(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
159159
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
160-
toolkit=RLToolkit.COACH, toolkit_version=rl_coach_version,
160+
toolkit=RLToolkit.COACH, toolkit_version=rl_coach_mxnet_version,
161161
framework=RLFramework.MXNET, container_log_level=container_log_level,
162162
source_dir=source_dir)
163163

164164
job_name = 'new_name'
165165
rl.fit(inputs='s3://mybucket/train', job_name='new_name')
166166
model = rl.create_model()
167167
supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP[RLToolkit.COACH.value]
168-
framework_version = supported_versions[rl_coach_version][RLFramework.MXNET.value]
168+
framework_version = supported_versions[rl_coach_mxnet_version][RLFramework.MXNET.value]
169169

170170
assert isinstance(model, MXNetModel)
171171
assert model.sagemaker_session == sagemaker_session
@@ -179,12 +179,12 @@ def test_create_mxnet_model(sagemaker_session, rl_coach_version):
179179
assert model.vpc_config is None
180180

181181

182-
def test_create_model_with_optional_params(sagemaker_session, rl_coach_version):
182+
def test_create_model_with_optional_params(sagemaker_session, rl_coach_mxnet_version):
183183
container_log_level = '"logging.INFO"'
184184
source_dir = 's3://mybucket/source'
185185
rl = RLEstimator(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
186186
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
187-
toolkit=RLToolkit.COACH, toolkit_version=rl_coach_version,
187+
toolkit=RLToolkit.COACH, toolkit_version=rl_coach_mxnet_version,
188188
framework=RLFramework.MXNET, container_log_level=container_log_level,
189189
source_dir=source_dir)
190190

@@ -226,10 +226,10 @@ def test_create_model_with_custom_image(sagemaker_session):
226226

227227
@patch('sagemaker.utils.create_tar_file', MagicMock())
228228
@patch('time.strftime', return_value=TIMESTAMP)
229-
def test_rl(strftime, sagemaker_session, rl_coach_version):
229+
def test_rl(strftime, sagemaker_session, rl_coach_mxnet_version):
230230
rl = RLEstimator(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
231231
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
232-
toolkit=RLToolkit.COACH, toolkit_version=rl_coach_version,
232+
toolkit=RLToolkit.COACH, toolkit_version=rl_coach_mxnet_version,
233233
framework=RLFramework.MXNET)
234234

235235
inputs = 's3://mybucket/train'
@@ -241,7 +241,7 @@ def test_rl(strftime, sagemaker_session, rl_coach_version):
241241
boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
242242
assert boto_call_names == ['resource']
243243

244-
expected_train_args = _create_train_job(RLToolkit.COACH.value, rl_coach_version,
244+
expected_train_args = _create_train_job(RLToolkit.COACH.value, rl_coach_mxnet_version,
245245
RLFramework.MXNET.value)
246246
expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs
247247

@@ -250,7 +250,7 @@ def test_rl(strftime, sagemaker_session, rl_coach_version):
250250

251251
model = rl.create_model()
252252
supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP[RLToolkit.COACH.value]
253-
framework_version = supported_versions[rl_coach_version][RLFramework.MXNET.value]
253+
framework_version = supported_versions[rl_coach_mxnet_version][RLFramework.MXNET.value]
254254

255255
expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-gpu-py3'
256256
submit_dir = 's3://notmybucket/sagemaker-rl-mxnet-{}/source/sourcedir.tar.gz'.format(TIMESTAMP)
@@ -266,17 +266,17 @@ def test_rl(strftime, sagemaker_session, rl_coach_version):
266266

267267

268268
@patch('sagemaker.utils.create_tar_file', MagicMock())
269-
def test_deploy_mxnet(sagemaker_session, rl_coach_version):
270-
rl = _rl_estimator(sagemaker_session, RLToolkit.COACH, rl_coach_version, RLFramework.MXNET,
269+
def test_deploy_mxnet(sagemaker_session, rl_coach_mxnet_version):
270+
rl = _rl_estimator(sagemaker_session, RLToolkit.COACH, rl_coach_mxnet_version, RLFramework.MXNET,
271271
train_instance_type='ml.g2.2xlarge')
272272
rl.fit()
273273
predictor = rl.deploy(1, CPU)
274274
assert isinstance(predictor, MXNetPredictor)
275275

276276

277277
@patch('sagemaker.utils.create_tar_file', MagicMock())
278-
def test_deploy_tfs(sagemaker_session, rl_coach_version):
279-
rl = _rl_estimator(sagemaker_session, RLToolkit.COACH, rl_coach_version, RLFramework.TENSORFLOW,
278+
def test_deploy_tfs(sagemaker_session, rl_coach_tf_version):
279+
rl = _rl_estimator(sagemaker_session, RLToolkit.COACH, rl_coach_tf_version, RLFramework.TENSORFLOW,
280280
train_instance_type='ml.g2.2xlarge')
281281
rl.fit()
282282
predictor = rl.deploy(1, GPU)
@@ -312,25 +312,25 @@ def test_train_image_cpu_instances(sagemaker_session, rl_ray_version):
312312
framework.value)
313313

314314

315-
def test_train_image_gpu_instances(sagemaker_session, rl_coach_version):
315+
def test_train_image_gpu_instances(sagemaker_session, rl_coach_mxnet_version):
316316
toolkit = RLToolkit.COACH
317317
framework = RLFramework.MXNET
318-
rl = _rl_estimator(sagemaker_session, toolkit, rl_coach_version, framework,
318+
rl = _rl_estimator(sagemaker_session, toolkit, rl_coach_mxnet_version, framework,
319319
train_instance_type='ml.g2.2xlarge')
320-
assert rl.train_image() == _get_full_gpu_image_uri(toolkit.value, rl_coach_version,
320+
assert rl.train_image() == _get_full_gpu_image_uri(toolkit.value, rl_coach_mxnet_version,
321321
framework.value)
322322

323-
rl = _rl_estimator(sagemaker_session, toolkit, rl_coach_version, framework,
323+
rl = _rl_estimator(sagemaker_session, toolkit, rl_coach_mxnet_version, framework,
324324
train_instance_type='ml.p2.2xlarge')
325-
assert rl.train_image() == _get_full_gpu_image_uri(toolkit.value, rl_coach_version,
325+
assert rl.train_image() == _get_full_gpu_image_uri(toolkit.value, rl_coach_mxnet_version,
326326
framework.value)
327327

328328

329-
def test_attach(sagemaker_session, rl_coach_version):
329+
def test_attach(sagemaker_session, rl_coach_mxnet_version):
330330
training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-rl-{}:{}{}-cpu-py3'\
331-
.format(RLFramework.MXNET.value, RLToolkit.COACH.value, rl_coach_version)
331+
.format(RLFramework.MXNET.value, RLToolkit.COACH.value, rl_coach_mxnet_version)
332332
supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP[RLToolkit.COACH.value]
333-
framework_version = supported_versions[rl_coach_version][RLFramework.MXNET.value]
333+
framework_version = supported_versions[rl_coach_mxnet_version][RLFramework.MXNET.value]
334334
returned_job_description = {'AlgorithmSpecification': {'TrainingInputMode': 'File',
335335
'TrainingImage': training_image},
336336
'HyperParameters':
@@ -361,7 +361,7 @@ def test_attach(sagemaker_session, rl_coach_version):
361361
assert estimator.framework == RLFramework.MXNET.value
362362
assert estimator.toolkit == RLToolkit.COACH.value
363363
assert estimator.framework_version == framework_version
364-
assert estimator.toolkit_version == rl_coach_version
364+
assert estimator.toolkit_version == rl_coach_mxnet_version
365365
assert estimator.role == 'arn:aws:iam::366:role/SageMakerRole'
366366
assert estimator.train_instance_count == 1
367367
assert estimator.train_max_run == 24 * 60 * 60

0 commit comments

Comments
 (0)