Skip to content

Commit 3933900

Browse files
committed
address local mode and MXNet + MMS
1 parent 36ec1a7 commit 3933900

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

src/sagemaker/estimator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1772,7 +1772,10 @@ def _model_entry_point(self):
17721772
str: The path to the entry point script. This can be either an absolute path or
17731773
a path relative to ``self._model_source_dir()``.
17741774
"""
1775-
return self.uploaded_code.script_name if self._model_source_dir() else self.entry_point
1775+
if self.sagemaker_session.local_mode or (self._model_source_dir() is None):
1776+
return self.entry_point
1777+
1778+
return self.uploaded_code.script_name
17761779

17771780
def hyperparameters(self):
17781781
"""Return the hyperparameters as a dictionary to use for training.

src/sagemaker/mxnet/estimator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,10 @@ def create_model(
217217
if "name" not in kwargs:
218218
kwargs["name"] = self._current_job_name
219219

220-
return MXNetModel(
220+
model = MXNetModel(
221221
self.model_data,
222222
role or self.role,
223-
entry_point or self._model_entry_point(),
223+
entry_point,
224224
source_dir=(source_dir or self._model_source_dir()),
225225
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
226226
container_log_level=self.container_log_level,
@@ -234,6 +234,11 @@ def create_model(
234234
**kwargs
235235
)
236236

237+
if entry_point is None:
238+
model.entry_point = (
239+
self.entry_point if model._is_mms_version() else self._model_entry_point()
240+
)
241+
237242
@classmethod
238243
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
239244
"""Convert the job description to init params that can be handled by the

0 commit comments

Comments
 (0)