Skip to content

Commit d246c52

Browse files
committed
infer framework and its version for mxnet, pytorch, sklearn and tensorflow models
1 parent a5464a2 commit d246c52

File tree

5 files changed

+25
-8
lines changed

5 files changed

+25
-8
lines changed

src/sagemaker/mxnet/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ def register(
238238
domain=domain,
239239
sample_payload_url=sample_payload_url,
240240
task=task,
241-
framework=framework,
242-
framework_version=framework_version,
241+
framework=framework or self._framework_name,
242+
framework_version=framework_version or self.framework_version,
243243
nearest_model_name=nearest_model_name,
244244
data_input_configuration=data_input_configuration,
245245
)

src/sagemaker/pytorch/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ def register(
239239
domain=domain,
240240
sample_payload_url=sample_payload_url,
241241
task=task,
242-
framework=framework,
243-
framework_version=framework_version,
242+
framework=framework or self._framework_name,
243+
framework_version=framework_version or self.framework_version,
244244
nearest_model_name=nearest_model_name,
245245
data_input_configuration=data_input_configuration,
246246
)

src/sagemaker/sklearn/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,8 @@ def register(
233233
domain=domain,
234234
sample_payload_url=sample_payload_url,
235235
task=task,
236-
framework=framework,
237-
framework_version=framework_version,
236+
framework=framework or self._framework_name,
237+
framework_version=framework_version or self.framework_version,
238238
nearest_model_name=nearest_model_name,
239239
data_input_configuration=data_input_configuration,
240240
)

src/sagemaker/tensorflow/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,8 +285,8 @@ def register(
285285
domain=domain,
286286
sample_payload_url=sample_payload_url,
287287
task=task,
288-
framework=framework,
289-
framework_version=framework_version,
288+
framework=framework or self._framework_name,
289+
framework_version=framework_version or self.framework_version,
290290
nearest_model_name=nearest_model_name,
291291
data_input_configuration=data_input_configuration,
292292
)

tests/unit/test_sklearn.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,23 @@ def test_create_model(sagemaker_session, sklearn_version):
185185
assert model_values["Image"] == image_uri
186186

187187

188+
def test_register_model(sagemaker_session, sklearn_version):
189+
source_dir = "s3://mybucket/source"
190+
191+
sklearn_model = SKLearnModel(
192+
model_data=source_dir,
193+
role=ROLE,
194+
sagemaker_session=sagemaker_session,
195+
entry_point=SCRIPT_PATH,
196+
framework_version=sklearn_version,
197+
)
198+
199+
model = sklearn_model.register(
200+
content_types=["application/json"],
201+
response_types=["application/json"],
202+
)
203+
204+
188205
@patch("sagemaker.model.FrameworkModel._upload_code")
189206
def test_create_model_with_network_isolation(upload, sagemaker_session, sklearn_version):
190207
source_dir = "s3://mybucket/source"

0 commit comments

Comments
 (0)