Skip to content

Commit 9c8d2d4

Browse files
author
Chuyang Deng
committed
avoid calling hyperparaneters twince
1 parent 1453a12 commit 9c8d2d4

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

src/sagemaker/estimator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,8 +1113,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
11131113

11141114
config = _Job._load_config(inputs, estimator)
11151115

1116-
if estimator.hyperparameters() is not None:
1117-
hyperparameters = {str(k): str(v) for (k, v) in estimator.hyperparameters().items()}
1116+
current_hyperparameters = estimator.hyperparameters()
1117+
if current_hyperparameters is not None:
1118+
hyperparameters = {str(k): str(v) for (k, v) in current_hyperparameters.items()}
11181119

11191120
train_args = config.copy()
11201121
train_args["input_mode"] = estimator.input_mode

src/sagemaker/tensorflow/estimator.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,9 @@ def __init__(
135135
kwargs["enable_sagemaker_metrics"] = True
136136

137137
super(TensorFlow, self).__init__(image_uri=image_uri, **kwargs)
138-
self.disable_model_dir = False
139138
self.model_dir = model_dir
140139
self.distribution = distribution or {}
141140

142-
if self.model_dir is False:
143-
self.disable_model_dir = True
144-
145141
self._validate_args(py_version=py_version)
146142

147143
def _validate_args(self, py_version):
@@ -319,7 +315,7 @@ def hyperparameters(self):
319315
"custom_mpi_options", ""
320316
)
321317

322-
if not self.disable_model_dir:
318+
if self.model_dir is not False:
323319
self.model_dir = self.model_dir or self._default_s3_path("model", mpi=mpi_enabled)
324320
additional_hyperparameters["model_dir"] = self.model_dir
325321

0 commit comments

Comments
 (0)