Skip to content

Commit 9676312

Browse files
committed
fix: CreateModelPackage API error for Scikit-learn and XGBoost frameworks
1 parent 555e0b7 commit 9676312

File tree

4 files changed

+2
-99
lines changed

4 files changed

+2
-99
lines changed

src/sagemaker/sklearn/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def register(
248248
sample_payload_url=sample_payload_url,
249249
task=task,
250250
framework=framework,
251-
framework_version=framework_version or self.framework_version,
251+
framework_version=framework_version,
252252
nearest_model_name=nearest_model_name,
253253
data_input_configuration=data_input_configuration,
254254
)

src/sagemaker/xgboost/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def register(
236236
sample_payload_url=sample_payload_url,
237237
task=task,
238238
framework=framework,
239-
framework_version=framework_version or self.framework_version,
239+
framework_version=framework_version,
240240
nearest_model_name=nearest_model_name,
241241
data_input_configuration=data_input_configuration,
242242
)

tests/unit/test_sklearn.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -587,49 +587,3 @@ def test_model_py2_raises(sagemaker_session, sklearn_version):
587587
framework_version=sklearn_version,
588588
py_version="py2",
589589
)
590-
591-
592-
def test_register_sklearn_model_auto_infer_framework(sagemaker_session, sklearn_version):
593-
source_dir = "s3://mybucket/source"
594-
595-
model_package_group_name = "test-sklearn-register-model"
596-
content_types = ["application/json"]
597-
response_types = ["application/json"]
598-
image_uri = "fakeimage"
599-
600-
sklearn_model = SKLearnModel(
601-
model_data=source_dir,
602-
role=ROLE,
603-
sagemaker_session=sagemaker_session,
604-
entry_point=SCRIPT_PATH,
605-
framework_version=sklearn_version,
606-
)
607-
608-
sklearn_model.register(
609-
content_types,
610-
response_types,
611-
model_package_group_name=model_package_group_name,
612-
marketplace_cert=True,
613-
image_uri=image_uri,
614-
)
615-
616-
expected_create_model_package_request = {
617-
"containers": [
618-
{
619-
"Image": image_uri,
620-
"Environment": ANY,
621-
"ModelDataUrl": source_dir,
622-
"Framework": "SKLEARN",
623-
"FrameworkVersion": sklearn_version,
624-
},
625-
],
626-
"content_types": content_types,
627-
"response_types": response_types,
628-
"inference_instances": None,
629-
"transform_instances": None,
630-
"model_package_group_name": model_package_group_name,
631-
"marketplace_cert": True,
632-
}
633-
sagemaker_session.create_model_package_from_containers.assert_called_with(
634-
**expected_create_model_package_request
635-
)

tests/unit/test_xgboost.py

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -673,54 +673,3 @@ def test_unsupported_xgboost_version_error(sagemaker_session):
673673
error_message = "XGBoost 1.1 is not supported"
674674
assert error_message in str(error1)
675675
assert error_message in str(error2)
676-
677-
678-
def test_register_xgboost_model_auto_infer_framework(sagemaker_session, xgboost_framework_version):
679-
source_dir = "s3://mybucket/source"
680-
681-
model_package_group_name = "test-pytorch-register-model"
682-
content_types = ["application/json"]
683-
response_types = ["application/json"]
684-
inference_instances = ["ml.m4.xlarge"]
685-
transform_instances = ["ml.m4.xlarge"]
686-
image_uri = "fakeimage"
687-
688-
xgboost_model = XGBoostModel(
689-
model_data=source_dir,
690-
role=ROLE,
691-
sagemaker_session=sagemaker_session,
692-
entry_point=SCRIPT_PATH,
693-
framework_version=xgboost_framework_version,
694-
)
695-
696-
xgboost_model.register(
697-
content_types,
698-
response_types,
699-
inference_instances,
700-
transform_instances,
701-
model_package_group_name=model_package_group_name,
702-
marketplace_cert=True,
703-
image_uri=image_uri,
704-
)
705-
706-
expected_create_model_package_request = {
707-
"containers": [
708-
{
709-
"Image": image_uri,
710-
"Environment": ANY,
711-
"ModelDataUrl": ANY,
712-
"Framework": "XGBOOST",
713-
"FrameworkVersion": xgboost_framework_version,
714-
},
715-
],
716-
"content_types": content_types,
717-
"response_types": response_types,
718-
"inference_instances": inference_instances,
719-
"transform_instances": transform_instances,
720-
"model_package_group_name": model_package_group_name,
721-
"marketplace_cert": True,
722-
}
723-
724-
sagemaker_session.create_model_package_from_containers.assert_called_with(
725-
**expected_create_model_package_request
726-
)

0 commit comments

Comments
 (0)