42
42
"0.5" : {"tensorflow" : "1.11" },
43
43
"0.6.5" : {"tensorflow" : "1.12" },
44
44
"0.6" : {"tensorflow" : "1.12" },
45
- "0.8.2" :{"tensorflow" : "2.1" },
46
- "0.8.5" :{"tensorflow" : "2.1" , "pytorch" : "1.5" }
45
+ "0.8.2" : {"tensorflow" : "2.1" },
46
+ "0.8.5" : {"tensorflow" : "2.1" , "pytorch" : "1.5" },
47
47
},
48
48
}
49
49
@@ -146,7 +146,9 @@ def __init__(
146
146
self ._validate_images_args (toolkit , toolkit_version , framework , image_name )
147
147
148
148
if not image_name :
149
- self ._validate_toolkit_support (toolkit .value , toolkit_version , framework .value )
149
+ self ._validate_toolkit_support (
150
+ toolkit .value , toolkit_version , framework .value
151
+ )
150
152
self .toolkit = toolkit .value
151
153
self .toolkit_version = toolkit_version
152
154
self .framework = framework .value
@@ -263,10 +265,14 @@ def create_model(
263
265
return tfsModel (framework_version = self .framework_version , ** base_args )
264
266
if self .framework == RLFramework .MXNET .value :
265
267
return MXNetModel (
266
- framework_version = self .framework_version , py_version = PYTHON_VERSION , ** extended_args
268
+ framework_version = self .framework_version ,
269
+ py_version = PYTHON_VERSION ,
270
+ ** extended_args
267
271
)
268
272
raise ValueError (
269
- "An unknown RLFramework enum was passed in. framework: {}" .format (self .framework )
273
+ "An unknown RLFramework enum was passed in. framework: {}" .format (
274
+ self .framework
275
+ )
270
276
)
271
277
272
278
def train_image (self ):
@@ -290,8 +296,8 @@ def train_image(self):
290
296
self .train_instance_type ,
291
297
self ._image_version (),
292
298
py_version = "py36" ,
293
- account = DEFAULT_RL_ACCOUNT
294
- )
299
+ account = DEFAULT_RL_ACCOUNT ,
300
+ )
295
301
296
302
return fw_utils .create_image_uri (
297
303
self .sagemaker_session .boto_region_name ,
@@ -302,7 +308,9 @@ def train_image(self):
302
308
)
303
309
304
310
@classmethod
305
- def _prepare_init_params_from_job_description (cls , job_details , model_channel_name = None ):
311
+ def _prepare_init_params_from_job_description (
312
+ cls , job_details , model_channel_name = None
313
+ ):
306
314
"""Convert the job description to init params that can be handled by the
307
315
class constructor
308
316
@@ -356,7 +364,9 @@ def hyperparameters(self):
356
364
SAGEMAKER_ESTIMATOR : SAGEMAKER_ESTIMATOR_VALUE ,
357
365
}
358
366
359
- hyperparameters .update (Framework ._json_encode_hyperparameters (additional_hyperparameters ))
367
+ hyperparameters .update (
368
+ Framework ._json_encode_hyperparameters (additional_hyperparameters )
369
+ )
360
370
return hyperparameters
361
371
362
372
@classmethod
@@ -394,7 +404,9 @@ def _validate_toolkit_format(cls, toolkit):
394
404
"""
395
405
if toolkit and toolkit not in list (RLToolkit ):
396
406
raise ValueError (
397
- "Invalid type: {}, valid RL toolkits types are: {}" .format (toolkit , list (RLToolkit ))
407
+ "Invalid type: {}, valid RL toolkits types are: {}" .format (
408
+ toolkit , list (RLToolkit )
409
+ )
398
410
)
399
411
400
412
@classmethod
@@ -506,7 +518,15 @@ def default_metric_definitions(cls, toolkit):
506
518
float_regex = "[-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?" # noqa: W605, E501
507
519
508
520
return [
509
- {"Name" : "episode_reward_mean" , "Regex" : "episode_reward_mean: (%s)" % float_regex },
510
- {"Name" : "episode_reward_max" , "Regex" : "episode_reward_max: (%s)" % float_regex },
521
+ {
522
+ "Name" : "episode_reward_mean" ,
523
+ "Regex" : "episode_reward_mean: (%s)" % float_regex ,
524
+ },
525
+ {
526
+ "Name" : "episode_reward_max" ,
527
+ "Regex" : "episode_reward_max: (%s)" % float_regex ,
528
+ },
511
529
]
512
- raise ValueError ("An unknown RLToolkit enum was passed in. toolkit: {}" .format (toolkit ))
530
+ raise ValueError (
531
+ "An unknown RLToolkit enum was passed in. toolkit: {}" .format (toolkit )
532
+ )
0 commit comments