|
17 | 17 |
|
18 | 18 | import pytest
|
19 | 19 |
|
| 20 | +from botocore.errorfactory import ClientError |
| 21 | + |
20 | 22 | from sagemaker.sparkml.model import SparkMLModel
|
21 | 23 | from sagemaker.utils import sagemaker_timestamp
|
22 | 24 | from tests.integ import DATA_DIR
|
23 | 25 | from tests.integ.timeout import timeout_and_delete_endpoint_by_name
|
24 | 26 |
|
25 | 27 |
|
26 | 28 | @pytest.mark.canary_quick
|
27 |
| -@pytest.mark.skip( |
28 |
| - reason="This test has always failed, but the failure was masked by a bug. " |
29 |
| - "This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968" |
30 |
| -) |
31 | 29 | def test_sparkml_model_deploy(sagemaker_session, cpu_instance_type):
|
32 |
| - # Uploads an MLeap serialized MLeap model to S3 and use that to deploy a SparkML model to perform inference |
| 30 | + # Uploads an MLeap serialized MLeap model to S3 and use that to deploy |
| 31 | + # a SparkML model to perform inference |
33 | 32 | data_path = os.path.join(DATA_DIR, "sparkml_model")
|
34 | 33 | endpoint_name = "test-sparkml-deploy-{}".format(sagemaker_timestamp())
|
35 | 34 | model_data = sagemaker_session.upload_data(
|
@@ -59,7 +58,8 @@ def test_sparkml_model_deploy(sagemaker_session, cpu_instance_type):
|
59 | 58 | predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
|
60 | 59 |
|
61 | 60 | valid_data = "1.0,C,38.0,71.5,1.0,female"
|
62 |
| - assert predictor.predict(valid_data) == "1.0,0.0,38.0,1.0,71.5,0.0,1.0" |
| 61 | + assert predictor.predict(valid_data) == b"1.0,0.0,38.0,1.0,71.5,0.0,1.0" |
63 | 62 |
|
64 | 63 | invalid_data = "1.0,28.0,C,38.0,71.5,1.0"
|
65 |
| - assert predictor.predict(invalid_data) is None |
| 64 | + with pytest.raises(ClientError): |
| 65 | + predictor.predict(invalid_data) |
0 commit comments