Skip to content

Commit cfeca4e

Browse files
committed
feature: add support for spark ml serving container version 2.4
1 parent 69c2107 commit cfeca4e

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

src/sagemaker/sparkml/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class SparkMLModel(Model):
5959
model .
6060
"""
6161

62-
def __init__(self, model_data, role=None, spark_version=2.2, sagemaker_session=None, **kwargs):
62+
def __init__(self, model_data, role=None, spark_version=2.4, sagemaker_session=None, **kwargs):
6363
"""Initialize a SparkMLModel.
6464
6565
Args:
@@ -73,7 +73,7 @@ def __init__(self, model_data, role=None, spark_version=2.2, sagemaker_session=N
7373
artifacts. After the endpoint is created, the inference code
7474
might use the IAM role, if it needs to access an AWS resource.
7575
spark_version (str): Spark version you want to use for executing the
76-
inference (default: '2.2').
76+
inference (default: '2.4').
7777
sagemaker_session (sagemaker.session.Session): Session object which
7878
manages interactions with Amazon SageMaker APIs and any other
7979
AWS services needed. If not specified, the estimator creates one

tests/integ/test_sparkml_serving.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,18 @@
1717

1818
import pytest
1919

20+
from botocore.errorfactory import ClientError
21+
2022
from sagemaker.sparkml.model import SparkMLModel
2123
from sagemaker.utils import sagemaker_timestamp
2224
from tests.integ import DATA_DIR
2325
from tests.integ.timeout import timeout_and_delete_endpoint_by_name
2426

2527

2628
@pytest.mark.canary_quick
27-
@pytest.mark.skip(
28-
reason="This test has always failed, but the failure was masked by a bug. "
29-
"This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
30-
)
3129
def test_sparkml_model_deploy(sagemaker_session, cpu_instance_type):
32-
# Uploads an MLeap serialized MLeap model to S3 and use that to deploy a SparkML model to perform inference
30+
# Uploads an MLeap serialized MLeap model to S3 and use that to deploy
31+
# a SparkML model to perform inference
3332
data_path = os.path.join(DATA_DIR, "sparkml_model")
3433
endpoint_name = "test-sparkml-deploy-{}".format(sagemaker_timestamp())
3534
model_data = sagemaker_session.upload_data(
@@ -59,7 +58,8 @@ def test_sparkml_model_deploy(sagemaker_session, cpu_instance_type):
5958
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
6059

6160
valid_data = "1.0,C,38.0,71.5,1.0,female"
62-
assert predictor.predict(valid_data) == "1.0,0.0,38.0,1.0,71.5,0.0,1.0"
61+
assert predictor.predict(valid_data) == b"1.0,0.0,38.0,1.0,71.5,0.0,1.0"
6362

6463
invalid_data = "1.0,28.0,C,38.0,71.5,1.0"
65-
assert predictor.predict(invalid_data) is None
64+
with pytest.raises(ClientError):
65+
predictor.predict(invalid_data)

tests/unit/test_sparkml_serving.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def sagemaker_session():
4949

5050
def test_sparkml_model(sagemaker_session):
5151
sparkml = SparkMLModel(sagemaker_session=sagemaker_session, model_data=MODEL_DATA, role=ROLE)
52-
assert sparkml.image_uri == image_uris.retrieve("sparkml-serving", REGION, version="2.2")
52+
assert sparkml.image_uri == image_uris.retrieve("sparkml-serving", REGION, version="2.4")
5353

5454

5555
def test_predictor_type(sagemaker_session):

0 commit comments

Comments
 (0)