Skip to content

Commit cc090d4

Browse files
balajisankar15Balaji Sankar
authored and committed
feature: Added Config parser for SageMaker Python SDK (aws#840)
Co-authored-by: Balaji Sankar <[email protected]>
1 parent 26fd27b commit cc090d4

13 files changed

+1530
-0
lines changed

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def read_requirements(filename):
5959
"pandas",
6060
"pathos",
6161
"schema",
62+
"PyYAML==5.4.1",
63+
"jsonschema",
64+
"platformdirs",
6265
]
6366

6467
# Specific use case dependencies

src/sagemaker/config/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module configures the default values for SageMaker Python SDK."""
14+
15+
from __future__ import absolute_import

src/sagemaker/config/config.py

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module configures the default values for SageMaker Python SDK.
14+
15+
It supports loading Config files from local file system/S3.
16+
The schema of the Config file is dictated in config_schema.py in the same module.
17+
18+
"""
19+
from __future__ import absolute_import
20+
21+
import pathlib
22+
import logging
23+
import os
24+
from typing import List
25+
import boto3
26+
import yaml
27+
from jsonschema import validate
28+
from platformdirs import site_config_dir, user_config_dir
29+
from botocore.utils import merge_dicts
30+
from six.moves.urllib.parse import urlparse
31+
from sagemaker.config.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA
32+
33+
# Logger used by the config loaders in this module (registered under the "sagemaker" name).
logger = logging.getLogger("sagemaker")
34+
35+
# Application name handed to platformdirs to resolve the configuration directories below.
_APP_NAME = "sagemaker"
36+
# Default admin-level config file: <site config dir for "sagemaker">/config.yaml.
_DEFAULT_ADMIN_CONFIG_FILE_PATH = os.path.join(site_config_dir(_APP_NAME), "config.yaml")
37+
# Default user-level config file: <user config dir for "sagemaker">/config.yaml.
_DEFAULT_USER_CONFIG_FILE_PATH = os.path.join(user_config_dir(_APP_NAME), "config.yaml")
38+
39+
# Environment variable read by SageMakerConfig.__init__ to replace the default admin config path.
ENV_VARIABLE_DEFAULT_CONFIG_OVERRIDE = "SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"
40+
# Environment variable read by SageMakerConfig.__init__ to replace the default user config path.
ENV_VARIABLE_USER_CONFIG_OVERRIDE = "SAGEMAKER_USER_CONFIG_OVERRIDE"
41+
42+
# NOTE(review): not referenced by any code visible in this module -- confirm whether other
# modules/tests rely on it before removing.
_config_paths = [_DEFAULT_ADMIN_CONFIG_FILE_PATH, _DEFAULT_USER_CONFIG_FILE_PATH]
43+
# Evaluated at import time: reuse boto3's default session when one is already set, otherwise
# create a new Session.
_BOTO_SESSION = boto3.DEFAULT_SESSION or boto3.Session()
44+
# S3 resource used as the default value of SageMakerConfig's ``s3_resource`` parameter.
_DEFAULT_S3_RESOURCE = _BOTO_SESSION.resource("s3")
45+
46+
47+
class SageMakerConfig(object):
    """Container for the merged default configuration of the SageMaker Python SDK.

    Usages:
    The class is meant to be integrated with sagemaker.session.Session. A caller may hand a
    SageMakerConfig instance to sagemaker.session.Session; when none is supplied,
    sagemaker.session.Session is expected to build one on its own.

    Note: The configuration is read once, while the object is being constructed. Changing the
    config files (or the paths pointing to them) afterwards has no effect on an already
    initialized sagemaker.session.Session; it has to be re-initialized to pick up the changes.

    """

    def __init__(self, additional_config_paths: List[str] = None, s3_resource=_DEFAULT_S3_RESOURCE):
        """Load, validate and merge every applicable config file.

        The lookup order is:
            1. The admin config file (_DEFAULT_ADMIN_CONFIG_FILE_PATH, or the value of the
               SAGEMAKER_DEFAULT_CONFIG_OVERRIDE environment variable when it is set).
            2. The user config file (_DEFAULT_USER_CONFIG_FILE_PATH, or the value of the
               SAGEMAKER_USER_CONFIG_OVERRIDE environment variable when it is set).
            3. Each entry of ``additional_config_paths``, in the order given.

        Every file that yields content is validated against
        SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA and then merged, in that same order, into a single
        dictionary.

        Exceptions are raised when:
            * Schema validation fails for one or more config files.
            * A config file is not a proper YAML file.
            * Fetching a config file from S3 fails (for example permission issues, or no
              object exists at the given S3 URI).
            * A path supplied through an environment variable or ``additional_config_paths``
              does not exist. A missing file at _DEFAULT_ADMIN_CONFIG_FILE_PATH or
              _DEFAULT_USER_CONFIG_FILE_PATH is not an error.

        Args:
            additional_config_paths: List of Config file paths. Each entry can be one of:
                * Local file path
                * Local directory path (config.yaml is looked up in that directory)
                * S3 URI of the config file
                * S3 URI of the directory containing the config file (config.yaml is looked
                  up in that directory)
                Note: S3 URI follows the format s3://<bucket>/<Key prefix>
            s3_resource: boto3 S3 resource used to fetch config files from S3. Defaults to
                the module-level resource. See :py:meth:`boto3.session.Session.resource`.
                It is not used when all config files live on the local file system.

        """
        admin_path = os.getenv(
            ENV_VARIABLE_DEFAULT_CONFIG_OVERRIDE, _DEFAULT_ADMIN_CONFIG_FILE_PATH
        )
        user_path = os.getenv(ENV_VARIABLE_USER_CONFIG_OVERRIDE, _DEFAULT_USER_CONFIG_FILE_PATH)
        self._config_paths = [admin_path, user_path]
        if additional_config_paths:
            self._config_paths.extend(additional_config_paths)
        self._s3_resource = s3_resource

        merged = {}
        for path in self._config_paths:
            loaded = (
                _load_config_from_s3(path, self._s3_resource)
                if path.startswith("s3://")
                else _load_config_from_file(path)
            )
            if not loaded:
                # Nothing to validate or merge (e.g. an absent default config file).
                continue
            validate(loaded, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA)
            merge_dicts(merged, loaded)
        self._config = merged

    @property
    def config_paths(self) -> List[str]:
        """Config file paths that were considered, in load order.

        Returns:
            List[str]: The list of config file paths.
        """
        return self._config_paths

    @property
    def config(self) -> dict:
        """The merged configuration.

        Returns:
            dict: A dictionary holding the configuration loaded from the config file(s).
        """
        return self._config
146+
147+
148+
def _load_config_from_file(file_path: str) -> dict:
    """Load a YAML config file from the local file system.

    When ``file_path`` points to a directory, 'config.yaml' inside that directory is used
    instead. The lookup is not recursive.

    Args:
        file_path(str): The file (or directory) path from which the Config file needs to be
            loaded.

    Returns:
        dict: The parsed content of the config file. An empty dictionary is returned when
            ``file_path`` is empty, or when the file is missing at one of the default
            locations (_DEFAULT_ADMIN_CONFIG_FILE_PATH / _DEFAULT_USER_CONFIG_FILE_PATH).

    Raises:
        ValueError: If the file does not exist and the path is not one of the default
            locations, i.e. the path was explicitly provided by the consumer.
        yaml.YAMLError: If the config file is not a proper YAML file.
    """
    config = {}
    if not file_path:
        return config

    inferred_file_path = file_path
    if os.path.isdir(file_path):
        inferred_file_path = os.path.join(file_path, "config.yaml")

    if not os.path.exists(inferred_file_path):
        if inferred_file_path not in (
            _DEFAULT_ADMIN_CONFIG_FILE_PATH,
            _DEFAULT_USER_CONFIG_FILE_PATH,
        ):
            # Customer provided file path is invalid.
            raise ValueError(
                f"Unable to load config file from the location: {file_path} Please"
                f" provide a valid file path"
            )
        # A missing file at a default location is not an error; there is nothing to load.
        return config

    logger.debug("Fetching configuration file from the path: %s", file_path)
    # Use a context manager so the file handle is closed even when YAML parsing fails.
    with open(inferred_file_path, "r") as config_file:
        return yaml.safe_load(config_file)
184+
185+
186+
def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict:
    """Fetch a YAML config file from S3 and parse it.

    The URI is first resolved through ``_get_inferred_s3_uri``, so a URI that refers to a
    directory is turned into the URI of the 'config.yaml' inside it. The lookup is not
    recursive.

    Args:
        s3_uri(str): The S3 URI of the config file.
            Note: S3 URI follows the format s3://<bucket>/<Key prefix>
        s3_resource_for_config: boto3 S3 resource used to fetch the config file.
            See :py:meth:`boto3.session.Session.resource`.

    Returns:
        dict: The parsed content of the config file.

    This method will throw Exceptions for the following cases:
        * If the boto3 S3 resource is not provided.
        * When the config file is not a proper YAML file.
        * Any failure raised while resolving the URI (see ``_get_inferred_s3_uri``) or while
          downloading the object.
    """
    if not s3_resource_for_config:
        raise RuntimeError("Please provide a S3 client for loading the config")

    logger.debug("Fetching configuration file from the S3 URI: %s", s3_uri)
    resolved = urlparse(_get_inferred_s3_uri(s3_uri, s3_resource_for_config))
    bucket_name = resolved.netloc
    object_key = resolved.path.lstrip("/")
    # Download the whole object body and decode it before handing it to the YAML parser.
    raw_body = s3_resource_for_config.Object(bucket_name, object_key).get()["Body"].read()
    yaml_text = raw_body.decode("utf-8")
    return yaml.safe_load(yaml_text)
218+
219+
220+
def _get_inferred_s3_uri(s3_uri, s3_resource_for_config):
    """Resolve ``s3_uri`` to the S3 URI of an existing config file.

    All keys sharing the key prefix of ``s3_uri`` are listed, then:
        * If ``s3_uri`` names an existing object (and does not end with '/'), it is returned
          unchanged, even when other keys share the same prefix (e.g. 'config.yaml' next to
          'config.yaml.bak').
        * Otherwise ``s3_uri`` is treated as a directory and
          s3://<bucket>/<key prefix>/config.yaml is returned when that object exists.

    Args:
        s3_uri (str) : An S3 uri that refers to a location in which config file is present.
            s3_uri must start with 's3://'.
            An example s3_uri: 's3://example-bucket/config.yaml'.
        s3_resource_for_config: boto3 S3 resource used to list the objects.
            See :py:meth:`boto3.session.Session.resource`

    Returns:
        str: Valid S3 URI of the Config file.

    Raises:
        RuntimeError: If the objects sharing the prefix cannot be listed (for example an
            invalid bucket or insufficient read permission). The original error is chained
            as the cause.
        ValueError: If no S3 object has that prefix, or if ``s3_uri`` refers to a directory
            that does not contain config.yaml.
    """
    parsed_url = urlparse(s3_uri)
    bucket, key_prefix = parsed_url.netloc, parsed_url.path.lstrip("/")
    try:
        s3_bucket = s3_resource_for_config.Bucket(name=bucket)
        s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all()
        keys_with_same_prefix = {s3_object.key for s3_object in s3_objects}
    except Exception as e:  # pylint: disable=W0703
        # if customers didn't provide us with a valid S3 File/insufficient read permission,
        # We will fail hard.
        raise RuntimeError(f"Unable to read from S3 with URI: {s3_uri} due to {e}") from e

    if not keys_with_same_prefix:
        # Customer provided us with an incorrect s3 path.
        raise ValueError("Please provide a valid s3 path instead of {}".format(s3_uri))

    # An exact key match means the URI already points at a file. Keys ending in '/' are
    # folder placeholders, so they are handled as directories below.
    if key_prefix and not key_prefix.endswith("/") and key_prefix in keys_with_same_prefix:
        return s3_uri

    # Customer has provided us with a S3 URI which points to a directory
    # search for s3://<bucket>/directory-key-prefix/config.yaml
    directory_key = key_prefix.rstrip("/")
    config_key = "{}/config.yaml".format(directory_key) if directory_key else "config.yaml"
    if config_key not in keys_with_same_prefix:
        # We don't know which file we should be operating with.
        raise ValueError("Please provide a S3 URI which has config.yaml in the directory")
    # Customer has a config.yaml present in the directory that was provided as the S3 URI
    return "s3://{}/{}".format(bucket, config_key)

0 commit comments

Comments
 (0)