Skip to content

Commit 891e8eb

Browse files
committed
Add tests
1 parent 1f86451 commit 891e8eb

File tree

5 files changed

+112
-2
lines changed

5 files changed

+112
-2
lines changed

tests/unit/sagemaker/config/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,15 @@ def s3_resource_mock():
292292
@pytest.fixture()
293293
def get_data_dir():
294294
return os.path.join(os.path.dirname(__file__), "..", "..", "..", "data", "config")
295+
296+
297+
@pytest.fixture()
298+
def base_local_mode_config():
299+
return {
300+
"local": {
301+
"local_code": True,
302+
"region_name": "",
303+
"serving_port": 8080,
304+
"container_config": {"shm_size": "128M"},
305+
}
306+
}

tests/unit/sagemaker/config/test_config.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
import pytest
1717
import yaml
1818
import logging
19-
from mock import Mock, MagicMock
19+
from mock import Mock, MagicMock, patch
2020

2121
from sagemaker.config.config import (
22+
load_local_mode_config,
2223
load_sagemaker_config,
2324
logger,
2425
_DEFAULT_ADMIN_CONFIG_FILE_PATH,
2526
_DEFAULT_USER_CONFIG_FILE_PATH,
27+
_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH,
2628
)
2729
from jsonschema import exceptions
2830
from yaml.constructor import ConstructorError
@@ -402,3 +404,13 @@ def test_logging_with_additional_configs_and_none_are_found(caplog):
402404
in caplog.text
403405
)
404406
logger.propagate = False
407+
408+
409+
@patch("sagemaker.config.config._load_config_from_file")
410+
def test_load_local_mode_config(mock_load_config):
411+
load_local_mode_config()
412+
mock_load_config.assert_called_with(_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH)
413+
414+
415+
def test_load_local_mode_config_when_config_file_is_not_found():
416+
assert load_local_mode_config() is None

tests/unit/sagemaker/config/test_config_schema.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
from __future__ import absolute_import
1414
from jsonschema import validate, exceptions
1515
import pytest
16-
from sagemaker.config.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA
16+
from sagemaker.config.config_schema import (
17+
SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA,
18+
SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA,
19+
)
1720

1821

1922
def _validate_config(base_config_with_schema, sagemaker_config):
@@ -291,3 +294,19 @@ def test_session_s3_object_key_prefix_schema(base_config_with_schema, prefix_nam
291294
def test_invalid_session_s3_object_key_prefix_schema(base_config_with_schema, invalid_prefix_name):
292295
with pytest.raises(exceptions.ValidationError):
293296
test_session_s3_object_key_prefix_schema(base_config_with_schema, invalid_prefix_name)
297+
298+
299+
def test_validate_local_mode_schema(base_local_mode_config):
300+
validate(base_local_mode_config, SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA)
301+
302+
303+
def test_validate_local_mode_schema_with_additional_key(base_local_mode_config):
304+
config = dict(**base_local_mode_config)
305+
config["foo"] = "bar"
306+
with pytest.raises(exceptions.ValidationError):
307+
validate(config, SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA)
308+
309+
config2 = dict(**base_local_mode_config)
310+
config2["local"]["foo"] = "bar"
311+
with pytest.raises(exceptions.ValidationError):
312+
validate(config2, SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA)

tests/unit/sagemaker/local/test_local_image.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,23 @@ def test_container_does_not_enable_nvidia_docker_for_cpu_containers(sagemaker_se
657657
assert "runtime" not in docker_host
658658

659659

660+
def test_container_with_custom_config(sagemaker_session):
661+
custom_config = {
662+
"local": {
663+
"container_config": {"shm_size": "128M"},
664+
}
665+
}
666+
sagemaker_session.config = custom_config
667+
instance_count = 1
668+
image = "my-image"
669+
sagemaker_container = _SageMakerContainer(
670+
"local", instance_count, image, sagemaker_session=sagemaker_session
671+
)
672+
673+
docker_host = sagemaker_container._create_docker_host("host-1", {}, set(), "train", [])
674+
assert "shm_size" in docker_host
675+
676+
660677
@patch("sagemaker.local.image._HostingContainer.run", Mock())
661678
@patch("sagemaker.local.image._SageMakerContainer._prepare_serving_volumes", Mock(return_value=[]))
662679
@patch("shutil.copy", Mock())

tests/unit/sagemaker/local/test_local_session.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
import jsonschema
1415

1516
import pytest
1617
import urllib3
@@ -1031,3 +1032,52 @@ def test_default_bucket_prefix_with_sagemaker_config(boto_session, client):
10311032
**session_kwargs,
10321033
)
10331034
assert session_with_no_prefix.default_bucket_prefix is None
1035+
1036+
1037+
VALID_LOCAL_MODE_CONFIG = {
1038+
"local": {
1039+
"local_code": True,
1040+
"serving_port": 8888,
1041+
"container_config": {"shm_size": "128M"},
1042+
}
1043+
}
1044+
1045+
INVALID_LOCAL_MODE_CONFIG = {
1046+
"locals": {
1047+
"local_code": True,
1048+
"serving_port": 8888,
1049+
"container_config": {"shm_size": "128M"},
1050+
}
1051+
}
1052+
1053+
1054+
@patch("sagemaker.local.local_session.load_local_mode_config", return_value=VALID_LOCAL_MODE_CONFIG)
1055+
def test_config_getter(load_config_mock):
1056+
boto_session = Mock(region_name="us-west-2")
1057+
session = LocalSession(boto_session=boto_session)
1058+
load_config_mock.assert_called()
1059+
assert session.config == VALID_LOCAL_MODE_CONFIG
1060+
1061+
1062+
@patch(
1063+
"sagemaker.local.local_session.load_local_mode_config", return_value=INVALID_LOCAL_MODE_CONFIG
1064+
)
1065+
def test_config_validation(load_config_mock):
1066+
boto_session = Mock(region_name="us-west-2")
1067+
1068+
with pytest.raises(jsonschema.ValidationError):
1069+
LocalSession(boto_session=boto_session)
1070+
1071+
1072+
def test_config_setter():
1073+
boto_session = Mock(region_name="us-west-2")
1074+
1075+
session = LocalSession(boto_session=boto_session)
1076+
session.config = VALID_LOCAL_MODE_CONFIG
1077+
assert (
1078+
session.sagemaker_runtime_client.serving_port
1079+
== VALID_LOCAL_MODE_CONFIG["local"]["serving_port"]
1080+
)
1081+
1082+
with pytest.raises(jsonschema.ValidationError):
1083+
session.config = INVALID_LOCAL_MODE_CONFIG

0 commit comments

Comments
 (0)