Skip to content

Commit 4683967

Browse files
committed
Add unit tests
1 parent 28f41b0 commit 4683967

File tree

3 files changed

+116
-0
lines changed

3 files changed

+116
-0
lines changed

tests/unit/sagemaker/serverless/__init__.py

Whitespace-only changes.
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from mock import Mock
16+
import pytest
17+
18+
from sagemaker.serverless import LambdaModel
19+
20+
IMAGE_URI = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-lambda-image:latest"
21+
ROLE = "arn:aws:iam::123456789012:role/MyLambdaExecutionRole"
22+
23+
24+
@pytest.fixture
25+
def mock_client():
26+
return Mock()
27+
28+
29+
@pytest.mark.parametrize("wait", [False, True])
30+
def test_deploy(mock_client, wait):
31+
model = LambdaModel(IMAGE_URI, ROLE, client=mock_client)
32+
mock_client.create_function = Mock(return_value={"State": "Pending"})
33+
mock_client.get_function_configuration = Mock(return_value={"State": "Active"})
34+
35+
function_name, timeout, memory_size = "my-function", 3, 128
36+
predictor = model.deploy(function_name, timeout=timeout, memory_size=memory_size, wait=wait)
37+
38+
mock_client.create_function.assert_called_once()
39+
_, kwargs = mock_client.create_function.call_args
40+
assert kwargs["FunctionName"] == function_name
41+
assert kwargs["PackageType"] == "Image"
42+
assert kwargs["Timeout"] == timeout
43+
assert kwargs["MemorySize"] == memory_size
44+
assert kwargs["Role"] == ROLE
45+
assert kwargs["Code"] == {"ImageUri": IMAGE_URI}
46+
47+
assert predictor.function_name == function_name
48+
49+
50+
def test_destroy():
51+
model = LambdaModel(IMAGE_URI, ROLE, client=mock_client)
52+
model.destroy() # NOTE: This method is a no-op.
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from mock import Mock
16+
import pytest
17+
18+
from sagemaker.serverless import LambdaPredictor
19+
20+
FUNCTION_NAME = "my-function"
21+
22+
23+
@pytest.fixture
24+
def mock_client():
25+
return Mock()
26+
27+
28+
def test_predict(mock_client):
29+
# TODO
30+
# mock_client.create_function = Mock(return_value={"class": "cat"})
31+
# predictor = LambdaPredictor("my_function", mock_client)
32+
#
33+
# prediction = predictor.predict({"url": "https://images.com/cat.jpg"})
34+
#
35+
# assert prediction = {"class": "cat"}
36+
# mock_client.invoke.assert_called_once
37+
# _, kwargs = mock_client.delete_function.call_args
38+
# assert kwargs["FunctionName"] == "my-function"
39+
pass
40+
41+
42+
def test_destroy(mock_client):
43+
predictor = LambdaPredictor(FUNCTION_NAME, client=mock_client)
44+
45+
predictor.destroy()
46+
47+
mock_client.delete_function.assert_called_once()
48+
_, kwargs = mock_client.delete_function.call_args
49+
assert kwargs["FunctionName"] == FUNCTION_NAME
50+
51+
52+
def test_content_type(mock_client):
53+
predictor = LambdaPredictor(FUNCTION_NAME, client=mock_client)
54+
assert predictor.content_type == "application/json"
55+
56+
57+
def test_accept(mock_client):
58+
predictor = LambdaPredictor(FUNCTION_NAME, client=mock_client)
59+
assert predictor.accept == ("application/json",)
60+
61+
62+
def test_function_name(mock_client):
63+
predictor = LambdaPredictor(FUNCTION_NAME, client=mock_client)
64+
assert predictor.function_name == FUNCTION_NAME

0 commit comments

Comments
 (0)