Skip to content

Commit 6086451

Browse files
Balaji SankarRuban Hussain
authored andcommitted
fix: Replace SageMakerConfig class with just method invocations
1 parent cd2181b commit 6086451

File tree

73 files changed

+414
-491
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+414
-491
lines changed

src/sagemaker/config/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""This module configures the default values for SageMaker Python SDK."""
1414

1515
from __future__ import absolute_import
16-
from sagemaker.config.config import SageMakerConfig # noqa: F401
16+
from sagemaker.config.config import fetch_sagemaker_config, validate_sagemaker_config # noqa: F401
1717
from sagemaker.config.config_schema import ( # noqa: F401
1818
KEY,
1919
TRAINING_JOB,
@@ -130,4 +130,5 @@
130130
RESOURCE_CONFIG,
131131
EXECUTION_ROLE_ARN,
132132
ASYNC_INFERENCE_CONFIG,
133+
SCHEMA_VERSION,
133134
)

src/sagemaker/config/config.py

Lines changed: 69 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from typing import List
2525
import boto3
2626
import yaml
27-
from jsonschema import validate
27+
import jsonschema
2828
from platformdirs import site_config_dir, user_config_dir
2929
from botocore.utils import merge_dicts
3030
from six.moves.urllib.parse import urlparse
@@ -51,87 +51,65 @@
5151
S3_PREFIX = "s3://"
5252

5353

54-
class SageMakerConfig(object):
55-
"""A class that encapsulates the configuration for the SageMaker Python SDK.
54+
def fetch_sagemaker_config(
55+
additional_config_paths: List[str] = None, s3_resource=_DEFAULT_S3_RESOURCE
56+
) -> dict:
57+
"""Helper method that loads config files and merges them.
58+
59+
By default, this method first searches for config files in the default locations
60+
defined by the SDK.
61+
62+
Users can override the default admin and user config file paths using the
63+
SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE environment variables,
64+
respectively.
65+
66+
Additional config file paths can also be provided as a parameter.
67+
68+
This method then:
69+
* Loads each config file, whether it is Amazon S3 or the local file system.
70+
* Validates the schema of the config files.
71+
* Merges the files in the same order.
72+
73+
This method throws exceptions in the following cases:
74+
* jsonschema.exceptions.ValidationError: Schema validation fails for one or more config
75+
files.
76+
* RuntimeError: The method is unable to retrieve the list of all S3 files with the
77+
same prefix or is unable to retrieve the file.
78+
* ValueError: There are no S3 files with the prefix when an S3 URI is provided.
79+
* ValueError: There is no config.yaml file in the S3 bucket when an S3 URI is provided.
80+
* ValueError: A file doesn't exist in a path that was specified by the user as part of an
81+
environment variable or additional configuration file path. This doesn't include the default
82+
config file locations.
83+
84+
Args:
85+
additional_config_paths: List of config file paths.
86+
These paths can be one of the following. In the case of a directory, this method
87+
searches for a config.yaml file in that directory. This method does not perform a
88+
recursive search of folders in that directory.
89+
* Local file path
90+
* Local directory path
91+
* S3 URI of the config file
92+
* S3 URI of the directory containing the config file
93+
Note: S3 URI follows the format s3://<bucket>/<Key prefix>
94+
s3_resource (boto3.resource("s3")): The Boto3 S3 resource. This is used to fetch
95+
config files from S3. If it is not provided, this method creates a default S3 resource
96+
See :py:meth:boto3.session.Session.resource. This argument is not needed if the
97+
config files are present in the local file system.
5698
57-
This class is used to define default values provided by the user.
58-
59-
This class is integrated with sagemaker.session.Session. Users of the SageMaker Python SDK
60-
have the ability to pass a SageMakerConfig object to sagemaker.session.Session. If a
61-
SageMakerConfig object is not provided by the user, then sagemaker.session.Session
62-
creates its own SageMakerConfig object.
63-
64-
Note: After sagemaker.session.Session is initialized, it operates with the configuration
65-
values defined at that instant. If you modify the configuration files or file paths after
66-
sagemaker.session.Session is initialized, those changes are not reflected in
67-
sagemaker.session.Session. To incorporate the changes in the configuration files,
68-
initialize sagemaker.session.Session again.
6999
"""
70-
71-
def __init__(self, additional_config_paths: List[str] = None, s3_resource=_DEFAULT_S3_RESOURCE):
72-
"""Initializes the SageMakerConfig object.
73-
74-
By default, this method first searches for config files in the default locations
75-
defined by the SDK.
76-
77-
Users can override the default admin and user config file paths using the
78-
SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE environment variables,
79-
respectively.
80-
81-
Additional config file paths can also be provided as a constructor parameter.
82-
83-
This method then:
84-
* Loads each config file, whether it is Amazon S3 or the local file system.
85-
* Validates the schema of the config files.
86-
* Merges the files in the same order.
87-
88-
This method throws exceptions in the following cases:
89-
* jsonschema.exceptions.ValidationError: Schema validation fails for one or more config
90-
files.
91-
* RuntimeError: The method is unable to retrieve the list of all S3 files with the
92-
same prefix or is unable to retrieve the file.
93-
* ValueError: There are no S3 files with the prefix when an S3 URI is provided.
94-
* ValueError: There is no config.yaml file in the S3 bucket when an S3 URI is provided.
95-
* ValueError: A file doesn't exist in a path that was specified by the user as part of an
96-
environment variable or additional configuration file path. This doesn't include the default
97-
config file locations.
98-
99-
Args:
100-
additional_config_paths: List of config file paths.
101-
These paths can be one of the following. In the case of a directory, this method
102-
searches for a config.yaml file in that directory. This method does not perform a
103-
recursive search of folders in that directory.
104-
* Local file path
105-
* Local directory path
106-
* S3 URI of the config file
107-
* S3 URI of the directory containing the config file
108-
Note: S3 URI follows the format s3://<bucket>/<Key prefix>
109-
s3_resource (boto3.resource("s3")): The Boto3 S3 resource. This is used to fetch
110-
config files from S3. If it is not provided, this method creates a default S3 resource
111-
See :py:meth:boto3.session.Session.resource. This argument is not needed if the
112-
config files are present in the local file system.
113-
114-
"""
115-
default_config_path = os.getenv(
116-
ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE, _DEFAULT_ADMIN_CONFIG_FILE_PATH
117-
)
118-
user_config_path = os.getenv(
119-
ENV_VARIABLE_USER_CONFIG_OVERRIDE, _DEFAULT_USER_CONFIG_FILE_PATH
120-
)
121-
self.config_paths = [default_config_path, user_config_path]
122-
if additional_config_paths:
123-
self.config_paths += additional_config_paths
124-
self.config_paths = list(filter(lambda item: item is not None, self.config_paths))
125-
self.config = _load_config_files(self.config_paths, s3_resource)
126-
127-
128-
def _load_config_files(file_paths: List[str], s3_resource_for_config) -> dict:
129-
"""Placeholder docstring"""
100+
default_config_path = os.getenv(
101+
ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE, _DEFAULT_ADMIN_CONFIG_FILE_PATH
102+
)
103+
user_config_path = os.getenv(ENV_VARIABLE_USER_CONFIG_OVERRIDE, _DEFAULT_USER_CONFIG_FILE_PATH)
104+
config_paths = [default_config_path, user_config_path]
105+
if additional_config_paths:
106+
config_paths += additional_config_paths
107+
config_paths = list(filter(lambda item: item is not None, config_paths))
130108
merged_config = {}
131-
for file_path in file_paths:
109+
for file_path in config_paths:
132110
config_from_file = {}
133111
if file_path.startswith(S3_PREFIX):
134-
config_from_file = _load_config_from_s3(file_path, s3_resource_for_config)
112+
config_from_file = _load_config_from_s3(file_path, s3_resource)
135113
else:
136114
try:
137115
config_from_file = _load_config_from_file(file_path)
@@ -145,11 +123,24 @@ def _load_config_files(file_paths: List[str], s3_resource_for_config) -> dict:
145123
# Exceptions.
146124
raise
147125
if config_from_file:
148-
validate(config_from_file, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA)
126+
validate_sagemaker_config(config_from_file)
149127
merge_dicts(merged_config, config_from_file)
150128
return merged_config
151129

152130

131+
def validate_sagemaker_config(sagemaker_config: dict = None):
132+
"""Helper method that validates whether the schema of a given dictionary.
133+
134+
This method will validate whether the dictionary adheres to the schema
135+
defined at `~sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`
136+
137+
Args:
138+
sagemaker_config: A dictionary containing default values for the
139+
SageMaker Python SDK. (default: None).
140+
"""
141+
jsonschema.validate(sagemaker_config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA)
142+
143+
153144
def _load_config_from_file(file_path: str) -> dict:
154145
"""Placeholder docstring"""
155146
inferred_file_path = file_path

src/sagemaker/local/local_session.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import boto3
2222
from botocore.exceptions import ClientError
2323

24-
from sagemaker.config import SageMakerConfig
24+
from sagemaker.config import fetch_sagemaker_config, validate_sagemaker_config
2525
from sagemaker.local.image import _SageMakerContainer
2626
from sagemaker.local.utils import get_docker_host
2727
from sagemaker.local.entities import (
@@ -605,7 +605,7 @@ def __init__(
605605
default_bucket=None,
606606
s3_endpoint_url=None,
607607
disable_local_code=False,
608-
sagemaker_config: SageMakerConfig = None,
608+
sagemaker_config: dict = None,
609609
):
610610
"""Create a Local SageMaker Session.
611611
@@ -618,6 +618,16 @@ def __init__(
618618
disable_local_code (bool): Set ``True`` to override the default AWS configuration
619619
chain to disable the ``local.local_code`` setting, which may not be supported for
620620
some SDK features (default: False).
621+
sagemaker_config: A dictionary containing default values for the
622+
SageMaker Python SDK. (default: None). The dictionary must adhere to the schema
623+
defined at `~sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`.
624+
If sagemaker_config is not provided and configuration files exist (at the default
625+
paths for admins and users, or paths set through the environment variables
626+
SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE),
627+
a new dictionary will be generated from those configuration files. Alternatively,
628+
this dictionary can be generated by calling
629+
:func:`~sagemaker.config.fetch_sagemaker_config` and then be provided to the
630+
Session.
621631
"""
622632
self.s3_endpoint_url = s3_endpoint_url
623633
# We use this local variable to avoid disrupting the __init__->_initialize API of the
@@ -635,12 +645,7 @@ def __init__(
635645
logger.warning("Windows Support for Local Mode is Experimental")
636646

637647
def _initialize(
638-
self,
639-
boto_session,
640-
sagemaker_client,
641-
sagemaker_runtime_client,
642-
sagemaker_config: SageMakerConfig = None,
643-
**kwargs
648+
self, boto_session, sagemaker_client, sagemaker_runtime_client, **kwargs
644649
): # pylint: disable=unused-argument
645650
"""Initialize this Local SageMaker Session.
646651
@@ -669,20 +674,20 @@ def _initialize(
669674
self.sagemaker_client = LocalSagemakerClient(self)
670675
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
671676
self.local_mode = True
677+
sagemaker_config = kwargs.get("sagemaker_config", None)
678+
if sagemaker_config:
679+
validate_sagemaker_config(sagemaker_config)
672680

673681
if self.s3_endpoint_url is not None:
674682
self.s3_resource = boto_session.resource("s3", endpoint_url=self.s3_endpoint_url)
675683
self.s3_client = boto_session.client("s3", endpoint_url=self.s3_endpoint_url)
676-
self.sagemaker_config = sagemaker_config or (
677-
SageMakerConfig(s3_resource=self.s3_resource)
678-
if "sagemaker_config" not in kwargs
679-
else kwargs.get("sagemaker_config")
684+
self.sagemaker_config = (
685+
sagemaker_config if sagemaker_config else fetch_sagemaker_config(
686+
s3_resource=self.s3_resource)
680687
)
681688
else:
682-
self.sagemaker_config = sagemaker_config or (
683-
SageMakerConfig()
684-
if "sagemaker_config" not in kwargs
685-
else kwargs.get("sagemaker_config")
689+
self.sagemaker_config = (
690+
sagemaker_config if sagemaker_config else fetch_sagemaker_config()
686691
)
687692

688693
sagemaker_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml")

src/sagemaker/processing.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,20 +1246,10 @@ def __init__(
12461246
self.s3_data_distribution_type = s3_data_distribution_type
12471247
self.s3_compression_type = s3_compression_type
12481248
self.s3_input = s3_input
1249-
self._dataset_definition = dataset_definition
1249+
self.dataset_definition = dataset_definition
12501250
self.app_managed = app_managed
12511251
self._create_s3_input()
12521252

1253-
@property
1254-
def dataset_definition(self):
1255-
"""Getter for DataSetDefinition
1256-
1257-
Returns:
1258-
DatasetDefinition: The DatasetDefinition Object.
1259-
1260-
"""
1261-
return self._dataset_definition
1262-
12631253
def _to_request_dict(self):
12641254
"""Generates a request dictionary using the parameters provided to the class."""
12651255

src/sagemaker/session.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import sagemaker.logs
3434
from sagemaker import vpc_utils
3535
from sagemaker._studio import _append_project_tags
36-
from sagemaker.config import SageMakerConfig # noqa: F401
36+
from sagemaker.config import fetch_sagemaker_config, validate_sagemaker_config # noqa: F401
3737
from sagemaker.config import (
3838
KEY,
3939
TRAINING_JOB,
@@ -157,7 +157,7 @@ def __init__(
157157
default_bucket=None,
158158
settings=SessionSettings(),
159159
sagemaker_metrics_client=None,
160-
sagemaker_config: SageMakerConfig = None,
160+
sagemaker_config: dict = None,
161161
):
162162
"""Initialize a SageMaker ``Session``.
163163
@@ -189,9 +189,16 @@ def __init__(
189189
Client which makes SageMaker Metrics related calls to Amazon SageMaker
190190
(default: None). If not provided, one will be created using
191191
this instance's ``boto_session``.
192-
sagemaker_config (sagemaker.config.SageMakerConfig): The SageMakerConfig object which
193-
holds the default values for the SageMaker Python SDK. (default: None). If not
194-
provided, one will be created.
192+
sagemaker_config (dict): A dictionary containing default values for the
193+
SageMaker Python SDK. (default: None). The dictionary must adhere to the schema
194+
defined at `~sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`.
195+
If sagemaker_config is not provided and configuration files exist (at the default
196+
paths for admins and users, or paths set through the environment variables
197+
SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE),
198+
a new dictionary will be generated from those configuration files. Alternatively,
199+
this dictionary can be generated by calling
200+
:func:`~sagemaker.config.fetch_sagemaker_config` and then be provided to the
201+
Session.
195202
"""
196203
self._default_bucket = None
197204
self._default_bucket_name_override = default_bucket
@@ -217,7 +224,7 @@ def _initialize(
217224
sagemaker_runtime_client,
218225
sagemaker_featurestore_runtime_client,
219226
sagemaker_metrics_client,
220-
sagemaker_config: SageMakerConfig = None,
227+
sagemaker_config: dict = None,
221228
):
222229
"""Initialize this SageMaker Session.
223230
@@ -260,13 +267,13 @@ def _initialize(
260267

261268
self.local_mode = False
262269
if sagemaker_config:
263-
self.sagemaker_config = sagemaker_config
270+
validate_sagemaker_config(sagemaker_config)
264271
else:
265272
if self.s3_resource is None:
266273
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
267274
else:
268275
s3 = self.s3_resource
269-
self.sagemaker_config = SageMakerConfig(s3_resource=s3)
276+
self.sagemaker_config = fetch_sagemaker_config(s3_resource=s3)
270277

271278
@property
272279
def boto_region_name(self):

src/sagemaker/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from six.moves.urllib import parse
3838

3939
from sagemaker import deprecations
40+
from sagemaker.config import validate_sagemaker_config
4041
from sagemaker.session_settings import SessionSettings
4142
from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string
4243

@@ -1090,7 +1091,9 @@ def get_sagemaker_config_value(sagemaker_session, key):
10901091
"""
10911092
if not sagemaker_session:
10921093
return None
1093-
config_value = get_config_value(key, sagemaker_session.sagemaker_config.config)
1094+
if sagemaker_session.sagemaker_config:
1095+
validate_sagemaker_config(sagemaker_session.sagemaker_config)
1096+
config_value = get_config_value(key, sagemaker_session.sagemaker_config)
10941097
# Copy the value so any modifications to the output will not modify the source config
10951098
return copy.deepcopy(config_value)
10961099

0 commit comments

Comments
 (0)