Skip to content

Commit cc00169

Browse files
author
Namrata Madan
committed
feature: support intelligent defaults config for pathways
1 parent aa65a44 commit cc00169

File tree

8 files changed

+176
-38
lines changed

8 files changed

+176
-38
lines changed

.flake8

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ application_import_names = sagemaker, tests
33
import-order-style = google
44
per-file-ignores =
55
tests/unit/test_tuner.py: F405
6+
src/sagemaker/config/config_schema.py: E501

src/sagemaker/config/config_schema.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@
4444
SAGEMAKER = "SageMaker"
4545
PYTHON_SDK = "PythonSDK"
4646
MODULES = "Modules"
47+
REMOTE_FUNCTION = "RemoteFunction"
48+
DEPENDENCIES = "Dependencies"
49+
ENVIRONMENT_VARIABLES = "EnvironmentVariables"
50+
IMAGE_URI = "ImageUri"
51+
INSTANCE_TYPE = "InstanceType"
52+
S3_KMS_KEY_ID = "S3KmsKeyId"
53+
S3_ROOT_URI = "S3RootUri"
54+
SOURCE_DIR = "SourceDir"
55+
JOB_CONDA_ENV = "JobCondaEnvironment"
4756
OFFLINE_STORE_CONFIG = "OfflineStoreConfig"
4857
ONLINE_STORE_CONFIG = "OnlineStoreConfig"
4958
S3_STORAGE_CONFIG = "S3StorageConfig"
@@ -246,6 +255,7 @@ def _simple_path(*args: str):
246255
)
247256

248257

258+
<<<<<<< HEAD
249259
SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = {
250260
"$schema": "https://json-schema.org/draft/2020-12/schema",
251261
TYPE: OBJECT,

src/sagemaker/remote_function/job.py

Lines changed: 63 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
import re
1818
from typing import Dict, List, Tuple
1919

20+
from sagemaker.config import config_schema
21+
from sagemaker.config.config_factory import SageMakerConfigFactory
2022
from sagemaker.session import get_execution_role, _logs_for_job, Session
21-
from sagemaker.utils import name_from_base, base_name_from_image
23+
from sagemaker.utils import name_from_base
2224
from sagemaker.s3 import s3_path_join
2325
from sagemaker import vpc_utils
2426
from sagemaker.remote_function.core.stored_function import StoredFunction
@@ -32,7 +34,6 @@
3234
logger = logging_config.get_logger()
3335

3436

35-
# TODO: extend this class to load job settings from the configuration files.
3637
class _JobSettings:
3738
"""Helper class that processes the job settings.
3839
@@ -62,46 +63,84 @@ def __init__(
6263
volume_kms_key: str = None,
6364
volume_size: int = 30,
6465
):
66+
67+
self.sagemaker_config = SageMakerConfigFactory.build_sagemaker_config(
68+
additional_override_config_location=os.getcwd()
69+
)
6570
self.sagemaker_session = sagemaker_session or Session()
6671

67-
self.environment_variables = environment_variables
72+
self.environment_variables = self._get_from_config(
73+
environment_variables, config_schema.ENVIRONMENT_VARIABLES
74+
)
6875

69-
self.image_uri = image_uri or _JobSettings._get_default_image_uri()
70-
self.dependencies = dependencies
76+
# TODO: provide default image uri if not set
77+
self.image_uri = self._get_from_config(image_uri, config_schema.IMAGE_URI, required=True)
78+
self.dependencies = self._get_from_config(dependencies, config_schema.DEPENDENCIES)
7179

72-
self.instance_type = instance_type
80+
self.instance_type = self._get_from_config(
81+
instance_type, config_schema.INSTANCE_TYPE, required=True
82+
)
7383
self.instance_count = instance_count
7484
self.volume_size = volume_size
7585
self.max_runtime_in_seconds = max_runtime_in_seconds
7686
self.max_retry_attempts = max_retry_attempts
7787
self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
78-
self.source_dir = source_dir
79-
self.job_conda_env = job_conda_env
88+
self.source_dir = self._get_from_config(source_dir, config_schema.SOURCE_DIR)
89+
self.job_conda_env = self._get_from_config(job_conda_env, config_schema.JOB_CONDA_ENV)
8090

81-
if role is not None:
82-
self.role = self.sagemaker_session.expand_role(role)
91+
_role = self._get_from_config(role, config_schema.ROLE_ARN)
92+
if _role:
93+
self.role = self.sagemaker_session.expand_role(_role)
8394
else:
8495
self.role = get_execution_role(self.sagemaker_session)
8596

86-
self.s3_root_uri = s3_root_uri or os.path.join(
87-
"s3://",
88-
self.sagemaker_session.default_bucket(),
89-
base_name_from_image(self.image_uri),
97+
self.s3_root_uri = self._get_from_config(
98+
s3_root_uri,
99+
config_schema.S3_ROOT_URI,
100+
default=os.path.join("s3://", self.sagemaker_session.default_bucket()),
90101
)
91-
self.s3_kms_key = s3_kms_key
92-
self.volume_kms_key = volume_kms_key
102+
self.s3_kms_key = self._get_from_config(s3_kms_key, config_schema.S3_KMS_KEY_ID)
103+
self.volume_kms_key = self._get_from_config(volume_kms_key, config_schema.VOLUME_KMS_KEY_ID)
93104

94-
vpc_config = vpc_utils.to_dict(subnets=subnets, security_group_ids=security_group_ids)
105+
_subnets = self._get_from_config(subnets, config_schema.SUBNETS)
106+
_security_group_ids = self._get_from_config(
107+
security_group_ids, config_schema.SECURITY_GROUP_IDS
108+
)
109+
vpc_config = vpc_utils.to_dict(subnets=_subnets, security_group_ids=_security_group_ids)
95110
self.vpc_config = vpc_utils.sanitize(vpc_config)
96111

97-
self.tags = [] if tags is None else [{"Key": k, "Value": v} for k, v in tags]
98-
99-
@staticmethod
100-
def _get_default_image_uri():
101-
"""Get the default image uri"""
112+
self.tags = self._get_from_config(
113+
tags,
114+
config_schema.TAGS,
115+
transform=lambda x: [next(iter(tuple(e.items()))) for e in x],
116+
default=[],
117+
)
102118

103-
# TODO: provide default image uri if not set
104-
raise ValueError("image_uri must be set")
119+
def _get_from_config(
120+
self,
121+
override_value,
122+
sagemaker_config_key,
123+
transform=lambda x: x,
124+
default=None,
125+
required=False,
126+
):
127+
"""Get default value from sagemaker config."""
128+
if override_value:
129+
return override_value
130+
config_value = self.sagemaker_config.get_config_value(
131+
"{}.{}.{}.{}.{}".format(
132+
config_schema.SAGEMAKER,
133+
config_schema.PYTHON_SDK,
134+
config_schema.MODULES,
135+
config_schema.REMOTE_FUNCTION,
136+
sagemaker_config_key,
137+
)
138+
)
139+
if config_value:
140+
return transform(config_value)
141+
if not default and required:
142+
raise ValueError(f"{sagemaker_config_key} is a required parameter!")
143+
return default
105144

106145

107146
class _Job:

tests/data/config/config.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
11
SchemaVersion: '1.0'
22
SageMaker:
3+
PythonSDK:
4+
Modules:
5+
RemoteFunction:
6+
Dependencies: "path/to/requirements.txt"
7+
EnvironmentVariables: {"EnvVarKey": "EnvVarValue"}
8+
ImageUri: "366666666666.dkr.ecr.us-west-2.amazonaws.com/my-image:latest"
9+
InstanceType: "ml.m5.large"
10+
JobCondaEnvironment: "my_conda_env"
11+
RoleArn: "arn:aws:iam::366666666666:role/IMRole"
12+
S3KmsKeyId: "somekmskeyid"
13+
S3RootUri: "s3://bucket/key"
14+
SecurityGroupIds: ["sg123"]
15+
SourceDir: "../mymodule"
16+
Subnets: ["subnet-1234"]
17+
Tags: [{"someTagKey": "someTagValue"}]
18+
VolumeKmsKeyId: "somekmskeyid"
319
FeatureGroup:
420
OnlineStoreConfig:
521
SecurityConfig:
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
SchemaVersion: '1.0'
2+
SageMaker:
3+
PythonSDK:
4+
Modules:
5+
RemoteFunction:
6+
Dependencies: "path/to/requirements.txt"
7+
EnvironmentVariables: {"EnvVarKey": "EnvVarValue"}
8+
InstanceType: "ml.m5.large"
9+
JobCondaEnvironment: "my_conda_env"
10+
S3KmsKeyId: "someS3KmsKey"
11+
SecurityGroupIds: ["sg123"]
12+
SourceDir: "../mymodule"
13+
Subnets: ["subnet-1234"]
14+
Tags: [{"someTagKey": "someTagValue"}, {"someTagKey2": "someTagValue2"}]
15+
VolumeKmsKeyId: "someVolumeKmsKey"

tests/unit/sagemaker/config/conftest.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,32 @@ def base_config_with_schema():
2222
return {"SchemaVersion": "1.0"}
2323

2424

25+
<<<<<<< HEAD
2526
@pytest.fixture()
2627
def valid_vpc_config():
2728
return {"SecurityGroupIds": ["sg123"], "Subnets": ["subnet-1234"]}
29+
=======
30+
@pytest.fixture(scope="module")
31+
def valid_vpc_subnet():
32+
return "subnet-1234"
33+
34+
35+
@pytest.fixture(scope="module")
36+
def valid_vpc_security_group():
37+
return "sg123"
38+
39+
40+
@pytest.fixture(scope="module")
41+
def valid_vpc_config(valid_vpc_security_group, valid_vpc_subnet):
42+
return {"SecurityGroupIds": [valid_vpc_security_group], "Subnets": [valid_vpc_subnet]}
43+
>>>>>>> 740cd77d (feature: support intelligent defaults config for pathways)
2844

2945

3046
@pytest.fixture()
3147
def valid_iam_role_arn():
3248
return "arn:aws:iam::555555555555:role/IMRole"
3349

3450

35-
@pytest.fixture()
3651
def valid_feature_group_config(valid_iam_role_arn):
3752
security_storage_config = {"KmsKeyId": "kmskeyid1"}
3853
s3_storage_config = {"KmsKeyId": "kmskeyid2"}

tests/unit/sagemaker/config/test_config_factory.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,14 @@ def test_default_config_file_when_directory_is_provided_as_the_path(
7070

7171

7272
def test_default_config_file_when_path_is_provided_as_environment_variable(
73-
get_data_dir, valid_config_with_all_the_scopes, base_config_with_schema
73+
get_data_dir, valid_config_with_all_the_scopes, base_config_with_schema, monkeypatch
7474
):
75-
os.environ["SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"] = get_data_dir
75+
monkeypatch.setenv("SAGEMAKER_DEFAULT_CONFIG_OVERRIDE", get_data_dir)
7676
# This will try to load config.yaml file from that directory if present.
7777
expected_config = base_config_with_schema
7878
expected_config["SageMaker"] = valid_config_with_all_the_scopes
7979
assert expected_config == SageMakerConfigFactory.build_sagemaker_config().get_config()
80+
monkeypatch.delenv("SAGEMAKER_DEFAULT_CONFIG_OVERRIDE")
8081

8182

8283
def test_merge_behavior_when_additional_config_file_is_not_found(

tests/unit/sagemaker/remote_function/test_job.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import os
16+
1517
import pytest
1618
from mock import patch, Mock, ANY
19+
from tests.unit import DATA_DIR
1720
from sagemaker.remote_function.job import _JobSettings, _Job
1821

1922

@@ -80,19 +83,38 @@ def job_function(a, b=1, *, c, d=3):
8083

8184
@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
8285
@patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN)
83-
def test_default_settings(*args):
86+
def test_sagemaker_config_job_settings(get_execution_role, session, monkeypatch):
87+
monkeypatch.setenv(
88+
"SAGEMAKER_DEFAULT_CONFIG_OVERRIDE", os.path.join(DATA_DIR, "remote_function")
89+
)
90+
8491
job_settings = _JobSettings(image_uri="image_uri")
8592
assert job_settings.image_uri == "image_uri"
86-
assert job_settings.s3_root_uri == f"s3://{BUCKET}/image_uri"
93+
assert job_settings.s3_root_uri == f"s3://{BUCKET}"
8794
assert job_settings.role == DEFAULT_ROLE_ARN
88-
assert job_settings.tags == []
95+
assert job_settings.tags == [("someTagKey", "someTagValue"), ("someTagKey2", "someTagValue2")]
96+
assert job_settings.vpc_config == {"Subnets": ["subnet-1234"], "SecurityGroupIds": ["sg123"]}
97+
assert job_settings.dependencies == "path/to/requirements.txt"
98+
assert job_settings.environment_variables == {"EnvVarKey": "EnvVarValue"}
99+
assert job_settings.job_conda_env == "my_conda_env"
100+
assert job_settings.source_dir == "../mymodule"
101+
assert job_settings.volume_kms_key == "someVolumeKmsKey"
102+
assert job_settings.s3_kms_key == "someS3KmsKey"
103+
assert job_settings.instance_type == "ml.m5.large"
104+
105+
monkeypatch.delenv("SAGEMAKER_DEFAULT_CONFIG_OVERRIDE")
89106

90107

91108
@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
92109
@patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN)
93-
def test_fails_on_missing_image_uri(*args):
94-
with pytest.raises(ValueError):
95-
_JobSettings(image_uri=None)
110+
def test_sagemaker_config_job_settings_missing_image_uri(get_execution_role, session, monkeypatch):
111+
monkeypatch.setenv(
112+
"SAGEMAKER_DEFAULT_CONFIG_OVERRIDE", os.path.join(DATA_DIR, "remote_function")
113+
)
114+
115+
with pytest.raises(ValueError, match="ImageUri is a required parameter!"):
116+
_JobSettings()
117+
monkeypatch.delenv("SAGEMAKER_DEFAULT_CONFIG_OVERRIDE")
96118

97119

98120
@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager")
@@ -101,7 +123,11 @@ def test_fails_on_missing_image_uri(*args):
101123
def test_start(session, mock_stored_function, mock_runtime_manager):
102124

103125
job_settings = _JobSettings(
104-
image_uri=IMAGE, s3_root_uri=S3_URI, role=ROLE_ARN, source_dir=PATH_TO_SRC_DIR
126+
image_uri=IMAGE,
127+
s3_root_uri=S3_URI,
128+
role=ROLE_ARN,
129+
source_dir=PATH_TO_SRC_DIR,
130+
instance_type="ml.m5.large",
105131
)
106132

107133
job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4})
@@ -139,7 +165,10 @@ def test_start(session, mock_stored_function, mock_runtime_manager):
139165
],
140166
),
141167
ResourceConfig=dict(
142-
VolumeSizeInGB=30, InstanceCount=1, InstanceType=None, KeepAlivePeriodInSeconds=0
168+
VolumeSizeInGB=30,
169+
InstanceCount=1,
170+
InstanceType="ml.m5.large",
171+
KeepAlivePeriodInSeconds=0,
143172
),
144173
)
145174

@@ -221,7 +250,11 @@ def test_start_with_complete_job_settings(session, mock_stored_function, mock_ru
221250
def test_describe(session, *args):
222251

223252
job_settings = _JobSettings(
224-
image_uri=IMAGE, s3_root_uri=S3_URI, role=ROLE_ARN, source_dir=PATH_TO_SRC_DIR
253+
image_uri=IMAGE,
254+
s3_root_uri=S3_URI,
255+
role=ROLE_ARN,
256+
source_dir=PATH_TO_SRC_DIR,
257+
instance_type="ml.m5.large",
225258
)
226259
job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4})
227260

@@ -236,7 +269,11 @@ def test_describe(session, *args):
236269
def test_stop(session, *args):
237270

238271
job_settings = _JobSettings(
239-
image_uri=IMAGE, s3_root_uri=S3_URI, role=ROLE_ARN, source_dir=PATH_TO_SRC_DIR
272+
image_uri=IMAGE,
273+
s3_root_uri=S3_URI,
274+
role=ROLE_ARN,
275+
source_dir=PATH_TO_SRC_DIR,
276+
instance_type="ml.m5.large",
240277
)
241278
job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4})
242279

@@ -253,7 +290,11 @@ def test_stop(session, *args):
253290
def test_wait(session, mock_stored_function, mock_logs_for_job):
254291

255292
job_settings = _JobSettings(
256-
image_uri=IMAGE, s3_root_uri=S3_URI, role=ROLE_ARN, source_dir=PATH_TO_SRC_DIR
293+
image_uri=IMAGE,
294+
s3_root_uri=S3_URI,
295+
role=ROLE_ARN,
296+
source_dir=PATH_TO_SRC_DIR,
297+
instance_type="ml.m5.large",
257298
)
258299
job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4})
259300

0 commit comments

Comments
 (0)