Skip to content

Commit 7320056

Browse files
committed
fix pylint format errors
1 parent 1bf8c66 commit 7320056

File tree

1 file changed

+33
-13
lines changed

1 file changed

+33
-13
lines changed

src/sagemaker/rl/estimator.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@
4242
"0.5": {"tensorflow": "1.11"},
4343
"0.6.5": {"tensorflow": "1.12"},
4444
"0.6": {"tensorflow": "1.12"},
45-
"0.8.2":{"tensorflow": "2.1"},
46-
"0.8.5":{"tensorflow": "2.1", "pytorch": "1.5"}
45+
"0.8.2": {"tensorflow": "2.1"},
46+
"0.8.5": {"tensorflow": "2.1", "pytorch": "1.5"},
4747
},
4848
}
4949

@@ -146,7 +146,9 @@ def __init__(
146146
self._validate_images_args(toolkit, toolkit_version, framework, image_name)
147147

148148
if not image_name:
149-
self._validate_toolkit_support(toolkit.value, toolkit_version, framework.value)
149+
self._validate_toolkit_support(
150+
toolkit.value, toolkit_version, framework.value
151+
)
150152
self.toolkit = toolkit.value
151153
self.toolkit_version = toolkit_version
152154
self.framework = framework.value
@@ -263,10 +265,14 @@ def create_model(
263265
return tfsModel(framework_version=self.framework_version, **base_args)
264266
if self.framework == RLFramework.MXNET.value:
265267
return MXNetModel(
266-
framework_version=self.framework_version, py_version=PYTHON_VERSION, **extended_args
268+
framework_version=self.framework_version,
269+
py_version=PYTHON_VERSION,
270+
**extended_args
267271
)
268272
raise ValueError(
269-
"An unknown RLFramework enum was passed in. framework: {}".format(self.framework)
273+
"An unknown RLFramework enum was passed in. framework: {}".format(
274+
self.framework
275+
)
270276
)
271277

272278
def train_image(self):
@@ -290,8 +296,8 @@ def train_image(self):
290296
self.train_instance_type,
291297
self._image_version(),
292298
py_version="py36",
293-
account=DEFAULT_RL_ACCOUNT
294-
)
299+
account=DEFAULT_RL_ACCOUNT,
300+
)
295301

296302
return fw_utils.create_image_uri(
297303
self.sagemaker_session.boto_region_name,
@@ -302,7 +308,9 @@ def train_image(self):
302308
)
303309

304310
@classmethod
305-
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
311+
def _prepare_init_params_from_job_description(
312+
cls, job_details, model_channel_name=None
313+
):
306314
"""Convert the job description to init params that can be handled by the
307315
class constructor
308316
@@ -356,7 +364,9 @@ def hyperparameters(self):
356364
SAGEMAKER_ESTIMATOR: SAGEMAKER_ESTIMATOR_VALUE,
357365
}
358366

359-
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
367+
hyperparameters.update(
368+
Framework._json_encode_hyperparameters(additional_hyperparameters)
369+
)
360370
return hyperparameters
361371

362372
@classmethod
@@ -394,7 +404,9 @@ def _validate_toolkit_format(cls, toolkit):
394404
"""
395405
if toolkit and toolkit not in list(RLToolkit):
396406
raise ValueError(
397-
"Invalid type: {}, valid RL toolkits types are: {}".format(toolkit, list(RLToolkit))
407+
"Invalid type: {}, valid RL toolkits types are: {}".format(
408+
toolkit, list(RLToolkit)
409+
)
398410
)
399411

400412
@classmethod
@@ -506,7 +518,15 @@ def default_metric_definitions(cls, toolkit):
506518
float_regex = "[-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?" # noqa: W605, E501
507519

508520
return [
509-
{"Name": "episode_reward_mean", "Regex": "episode_reward_mean: (%s)" % float_regex},
510-
{"Name": "episode_reward_max", "Regex": "episode_reward_max: (%s)" % float_regex},
521+
{
522+
"Name": "episode_reward_mean",
523+
"Regex": "episode_reward_mean: (%s)" % float_regex,
524+
},
525+
{
526+
"Name": "episode_reward_max",
527+
"Regex": "episode_reward_max: (%s)" % float_regex,
528+
},
511529
]
512-
raise ValueError("An unknown RLToolkit enum was passed in. toolkit: {}".format(toolkit))
530+
raise ValueError(
531+
"An unknown RLToolkit enum was passed in. toolkit: {}".format(toolkit)
532+
)

0 commit comments

Comments
 (0)