Skip to content

Commit aa65a44

Browse files
balajisankar15Balaji Sankar
authored andcommitted
feature: Added Config parser for SageMaker Python SDK (aws#817)
Co-authored-by: Balaji Sankar <[email protected]>
1 parent 71eb85c commit aa65a44

File tree

2 files changed

+558
-0
lines changed

2 files changed

+558
-0
lines changed
Lines changed: 369 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,369 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying athis file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module configures the default values for SageMaker Python SDK.
14+
15+
It supports loading Config files from local file system/S3.
16+
The schema of the Config file is dictated in config_schema.py in the same module.
17+
18+
"""
19+
from __future__ import absolute_import
20+
21+
import abc
22+
import logging
23+
import os
24+
import yaml
25+
26+
from botocore.utils import merge_dicts
27+
from sagemaker import s3
28+
from sagemaker.config.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA_V1_0
29+
from sagemaker.utils import get_config_value
30+
31+
logger = logging.getLogger("sagemaker")
32+
33+
34+
class SageMakerConfig(abc.ABC):
35+
"""SageMakerConfig class encapsulates the Config for SageMaker Python SDK.
36+
37+
This class also exposes methods to retrieve the Config.
38+
Note: This class shouldn't be directly instantiated.
39+
40+
.. tip::
41+
Use SageMakerConfigFactory to initialize the SageMakerConfig class.
42+
Subclasses which override ``__init__`` should invoke ``super()``.
43+
"""
44+
45+
def __init__(self):
46+
"""Initializes a SageMakerConfig object.
47+
48+
Note: This constructor invokes _load_config() method which should be implemented by the
49+
subclasses.
50+
51+
Once the Config is loaded, this constructor will validate the schema of the config file.
52+
53+
.. tip::
54+
Subclasses which override ``__init__`` should invoke ``super()``.
55+
56+
"""
57+
self.config_values = self._load_config()
58+
self.validate()
59+
60+
@abc.abstractmethod
61+
def _load_config(self) -> dict:
62+
"""Should be implemented by the Subclass.
63+
64+
Subclass is responsible for fetching the config file.
65+
66+
Returns:
67+
dict: Config as a python dictionary which obeys the schema mentioned in config_schema.py
68+
"""
69+
70+
def validate(self):
71+
"""Validates the schema of the Config.
72+
73+
SchemaError exception is thrown if the schema is not obeyed.
74+
"""
75+
if len(self.config_values) != 0:
76+
# If Config file is not found, then Config object will be empty.
77+
# We will NOT be doing schema validations for an empty config
78+
SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA_V1_0.validate(self.config_values)
79+
80+
def merge(self, config_to_be_merged):
81+
"""Merges the given SageMakerConfig object with another SageMakerConfig object.
82+
83+
This method will merge the configuration values. If a same key is present in both the
84+
config files, then the value present in the 'config_to_be_merged' will overwrite the
85+
current value.
86+
87+
Note: This method re-uses the merge_dicts utility method implemented in botocore.
88+
89+
Args:
90+
config_to_be_merged (sagemaker.config_management.config_factory.SageMakerConfig):
91+
The additional SageMakerConfig object that needs to be merged with the current
92+
object.
93+
94+
"""
95+
merge_dicts(self.config_values, config_to_be_merged.config_values)
96+
97+
def get_config_value(self, key_path):
98+
"""Given a key path, retrieve the corresponding value in the Config.
99+
100+
Args:
101+
key_path (str): Key path of the config entry.
102+
Nested entries are represented using '.' (Dot).
103+
104+
Returns:
105+
object: Represents the object that is present in that key path of the Config.
106+
Returns None if the key path doesn't exist.
107+
"""
108+
return get_config_value(key_path, self.config_values)
109+
110+
def get_config(self):
111+
"""Returns the entire configuration as a dictionary.
112+
113+
Returns: The entire configuration as a dictionary.
114+
"""
115+
return self.config_values
116+
117+
118+
class _SageMakerConfigFromFile(SageMakerConfig):
119+
"""An implementation of the SageMakerConfig which loads the Config from a local file system."""
120+
121+
def __init__(self, file_path):
122+
"""Constructor for _SageMakerConfigFromFile.
123+
124+
Note: This internally invokes the super class constructor.
125+
126+
Args:
127+
file_path (str): The local file system path of the YAML file from which the
128+
config needs to be loaded. Note: This needs to be an absolute path.
129+
"""
130+
self.file_path = file_path
131+
super(_SageMakerConfigFromFile, self).__init__()
132+
133+
def _load_config(self) -> dict:
134+
"""Overridden implementation of the _load_config method.
135+
136+
This method loads the config file from the path that was specified as
137+
part of the constructor.
138+
If the path that was provided, corresponds to a directory then this method
139+
will try to search for 'config.yaml' in that directory.
140+
Note: We will not be doing any recursive search.
141+
142+
Returns:
143+
dict: A dictionary representing the configurations.
144+
Note: This dictionary will be empty if an invalid file path is provided
145+
"""
146+
config = {}
147+
if self.file_path:
148+
inferred_file_path = self.file_path
149+
if os.path.isdir(self.file_path):
150+
inferred_file_path = os.path.join(self.file_path, "config.yaml")
151+
if _validate_file_exists(inferred_file_path):
152+
config = yaml.safe_load(open(inferred_file_path, "r"))
153+
return config
154+
155+
156+
class _SageMakerConfigFromS3(SageMakerConfig):
157+
"""An implementation of the SageMakerConfig which loads the Config from a S3 location."""
158+
159+
def __init__(self, s3_uri, sagemaker_session=None):
160+
"""Constructor for _SageMakerConfigFromFile.
161+
162+
This internally invokes the super class constructor.
163+
164+
Args:
165+
sagemaker_session (sagemaker.session.Session): Session object which
166+
manages interactions with Amazon SageMaker APIs and any other
167+
AWS services needed. If not specified, a default session object will be created
168+
using the default AWS configuration chain.
169+
s3_uri (str): An S3 uri that refers to a location in which config file is present.
170+
s3_uri must start with 's3://'.
171+
An example s3_uri: 's3://example-bucket/config.yaml'.
172+
If there are multiple S3 objects with the same key prefix,
173+
then we will infer the path to be s3://example-bucket/somekeyprefix/config.yaml
174+
175+
"""
176+
self.sagemaker_session = sagemaker_session
177+
self.s3_uri = _get_inferred_s3_path(s3_uri, self.sagemaker_session)
178+
super(_SageMakerConfigFromS3, self).__init__()
179+
180+
def _load_config(self) -> dict:
181+
"""Overridden implementation of the _load_config method.
182+
183+
This method loads the config file from the S3 URI that was
184+
specified as part of the constructor.
185+
186+
Note: Currently we don't support Client side KMS. This capability might be added in the
187+
future
188+
189+
Returns:
190+
dict: A dictionary representing the configurations.
191+
Note: This dictionary will be empty if an invalid S3 URI is provided
192+
193+
"""
194+
config = {}
195+
s3_file_content = None
196+
if self.s3_uri:
197+
try:
198+
s3_file_content = s3.S3Downloader.read_file(self.s3_uri, self.sagemaker_session)
199+
except Exception as e: # pylint: disable=W0703
200+
# if customers didn't provide us with a valid S3 File/insufficient read permission,
201+
# we DO want to silently ignore that and operate with an empty config.
202+
logger.warning(
203+
"Unable to fetch Config file from S3 with URI: %s due to %s", self.s3_uri, e
204+
)
205+
if s3_file_content:
206+
config = yaml.safe_load(s3_file_content)
207+
return config
208+
209+
210+
class SageMakerConfigFactory(object):
211+
"""Factory Class to create SageMakerConfig object.
212+
213+
Also supports specifying an additional config for override.
214+
"""
215+
216+
@staticmethod
217+
def build_sagemaker_config(
218+
default_config_location: str = None,
219+
additional_override_config_location: str = None,
220+
sagemaker_session=None,
221+
) -> SageMakerConfig:
222+
"""Factory method to build a SageMakerConfig object.
223+
224+
Note: This method also supports building an additional config override
225+
(in case if there are multiple configs).
226+
This method will then merge the additional config override to the base (default) config.
227+
228+
Args:
229+
default_config_location (str): File path of the Config.
230+
This can even be a S3 URI/Local File path.
231+
additional_override_config_location (str):
232+
File path of the Config override. This can even be a S3 URI/Local File path.
233+
sagemaker_session (sagemaker.session.Session): Session object which
234+
manages interactions with Amazon SageMaker APIs and any other
235+
AWS services needed. If not specified, a default session object will be created
236+
using the default AWS configuration chain.
237+
238+
Returns:
239+
SageMakerConfig: A SageMakerConfig object which contains the merged Config values.
240+
"""
241+
inferred_default_config_path = _get_default_config_path(default_config_location)
242+
inferred_additional_config_override_path = _get_additional_config_override_path(
243+
additional_override_config_location
244+
)
245+
default_config = _get_sagemaker_config(inferred_default_config_path, sagemaker_session)
246+
additional_override_config = _get_sagemaker_config(
247+
inferred_additional_config_override_path, sagemaker_session
248+
)
249+
default_config.merge(additional_override_config)
250+
return default_config
251+
252+
253+
# Default path in which the SageMaker Python SDK looks for Config objects.
254+
DEFAULT_CONFIG_FILE_PATH = os.path.join(
255+
os.path.expanduser("~"), ".sagemaker", "defaults", "sdk-default-config.yaml"
256+
)
257+
ENV_VARIABLE_DEFAULT_CONFIG_OVERRIDE = "SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"
258+
ENV_VARIABLE_ADDITIONAL_CONFIG_OVERRIDE = "SAGEMAKER_ADDITIONAL_CONFIG_OVERRIDE"
259+
260+
261+
def _get_sagemaker_config(file_path: str, sagemaker_session) -> SageMakerConfig:
262+
"""Constructs SageMakerConfig object based on the file path that was provided.
263+
264+
Args:
265+
file_path (str): File path of the Config. This can even be a S3 URI/Local File path.
266+
sagemaker_session (sagemaker.session.Session): Session object which
267+
manages interactions with Amazon SageMaker APIs and any other
268+
AWS services needed. If not specified, a default session object will be created
269+
using the default AWS configuration chain.
270+
271+
Returns:
272+
SageMakerConfig: The SageMakerConfig object based on the file path that was provided.
273+
"""
274+
if file_path and file_path.startswith("s3://"):
275+
return _SageMakerConfigFromS3(file_path, sagemaker_session)
276+
return _SageMakerConfigFromFile(file_path)
277+
278+
279+
def _validate_file_exists(file_path) -> bool:
280+
"""Validates whether a file/directory corresponding to that path exists
281+
282+
Args:
283+
file_path (str): The file path for which we need to verify its existence,
284+
285+
Returns:
286+
bool: Boolean indicating whether a file/directory corresponding to that path exists.
287+
"""
288+
if not os.path.exists(file_path):
289+
logger.warning("The specified file path: %s for the Config file doesn't exist.", file_path)
290+
return False
291+
return True
292+
293+
294+
def _get_default_config_path(default_config_directory_method_parameter: str = None) -> str:
295+
"""Returns the default config path.
296+
297+
Args:
298+
default_config_directory_method_parameter (str):
299+
Corresponds to the default_config_location parameter of SageMakerConfigFactory
300+
301+
Returns:
302+
str: If default_config_directory_method_parameter is passed, then we will use that value.
303+
Else, this method will try to fetch the value from SAGEMAKER_CONFIG_FILE environment
304+
variable.If SAGEMAKER_CONFIG_FILE environment variable is not set,
305+
we will infer the path to be DEFAULT_CONFIG_FILE_PATH
306+
"""
307+
if default_config_directory_method_parameter:
308+
return default_config_directory_method_parameter
309+
return os.getenv(ENV_VARIABLE_DEFAULT_CONFIG_OVERRIDE, DEFAULT_CONFIG_FILE_PATH)
310+
311+
312+
def _get_additional_config_override_path(
313+
additional_config_override_method_parameter: str = None,
314+
) -> str:
315+
"""Returns the additional config override path.
316+
317+
Args:
318+
additional_config_override_method_parameter (str): Corresponds to the
319+
additional_override_config_location parameter of SageMakerConfigFactory
320+
321+
Returns:
322+
str: If additional_config_override_method_parameter is passed, then we will use that value.
323+
Else, this method will try to fetch the value from SAGEMAKER_ADDITIONAL_CONFIG_OVERRIDE
324+
environment variable
325+
"""
326+
if additional_config_override_method_parameter:
327+
return additional_config_override_method_parameter
328+
return os.getenv(ENV_VARIABLE_ADDITIONAL_CONFIG_OVERRIDE, None)
329+
330+
331+
def _get_inferred_s3_path(s3_uri, sagemaker_session):
332+
"""Verifies whether the given S3 URI exists and returns the URI.
333+
334+
If there are multiple S3 objects with the same key prefix,
335+
then this method will verify whether S3 URI + /config.yaml exists.
336+
s3://example-bucket/somekeyprefix/config.yaml
337+
338+
Args:
339+
s3_uri (str) : An S3 uri that refers to a location in which config file is present.
340+
s3_uri must start with 's3://'.
341+
An example s3_uri: 's3://example-bucket/config.yaml'.
342+
sagemaker_session (sagemaker.session.Session): Session object which
343+
manages interactions with Amazon SageMaker APIs and any other
344+
AWS services needed. If not specified, a default session object will be created
345+
using the default AWS configuration chain.
346+
347+
Returns:
348+
str: Valid S3 URI of the Config file. None if it doesn't exist.
349+
"""
350+
try:
351+
s3_files_with_same_prefix = s3.S3Downloader.list(s3_uri, sagemaker_session)
352+
except Exception as e: # pylint: disable=W0703
353+
# if customers didn't provide us with a valid S3 File/insufficient read permission,
354+
# we DO want to silently ignore that and operate with an empty config.
355+
logger.warning("Unable to read from S3 with URI: %s due to %s", s3_uri, e)
356+
return None
357+
if len(s3_files_with_same_prefix) == 0:
358+
# Provided S3 URI is invalid.
359+
# we DO want to silently ignore that and operate with an empty config.
360+
return None
361+
if len(s3_files_with_same_prefix) > 1:
362+
# Customer has provided us with a S3 URI which points to a directory
363+
inferred_s3_path = s3.s3_path_join(s3_uri, "config.yaml")
364+
if inferred_s3_path not in s3_files_with_same_prefix:
365+
# We don't know which file we should be operating with.
366+
return None
367+
# Customer has a config.yaml present in the directory that was provided as the S3 URI
368+
return inferred_s3_path
369+
return s3_uri

0 commit comments

Comments
 (0)