Skip to content

Commit 7f219eb

Browse files
committed
Change: Accept model_data as dictionary in the model deploy
1 parent 8462f1a commit 7f219eb

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

src/sagemaker/model.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -831,16 +831,10 @@ def _create_sagemaker_model(
831831
# _base_name, model_name are not needed under PipelineSession.
832832
# the model_data may be Pipeline variable
833833
# which may break the _base_name generation
834-
model_uri = None
835-
if isinstance(self.model_data, (str, PipelineVariable)):
836-
model_uri = self.model_data
837-
elif isinstance(self.model_data, dict):
838-
model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None)
839-
840834
self._ensure_base_name_if_needed(
841835
image_uri=container_def["Image"],
842836
script_uri=self.source_dir,
843-
model_uri=model_uri,
837+
model_uri=self._get_model_uri(),
844838
)
845839
self._set_model_name_if_needed()
846840

@@ -877,6 +871,14 @@ def _create_sagemaker_model(
877871
)
878872
self.sagemaker_session.create_model(**create_model_args)
879873

874+
def _get_model_uri(self):
875+
model_uri = None
876+
if isinstance(self.model_data, (str, PipelineVariable)):
877+
model_uri = self.model_data
878+
elif isinstance(self.model_data, dict):
879+
model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None)
880+
return model_uri
881+
880882
def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
881883
"""Create a base name from the image URI if there is no model name provided.
882884
@@ -1434,7 +1436,7 @@ def deploy(
14341436
self._ensure_base_name_if_needed(
14351437
image_uri=self.image_uri,
14361438
script_uri=self.source_dir,
1437-
model_uri=self.model_data,
1439+
model_uri=self._get_model_uri(),
14381440
)
14391441
if self._base_name is not None:
14401442
self._base_name = "-".join((self._base_name, compiled_model_suffix))

src/sagemaker/sklearn/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(
4444
framework_version: Optional[str] = None,
4545
py_version: str = "py3",
4646
source_dir: Optional[Union[str, PipelineVariable]] = None,
47-
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
47+
hyperparameters: Optional[Dict[str, Optional[Union[str, PipelineVariable]]]] = None,
4848
image_uri: Optional[Union[str, PipelineVariable]] = None,
4949
image_uri_region: Optional[str] = None,
5050
**kwargs

0 commit comments

Comments
 (0)