|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""This module configures the default values for SageMaker Python SDK. |
| 14 | +
|
| 15 | +It supports loading Config files from local file system/S3. |
| 16 | +The schema of the Config file is dictated in config_schema.py in the same module. |
| 17 | +
|
| 18 | +""" |
| 19 | +from __future__ import absolute_import |
| 20 | + |
| 21 | +import pathlib |
| 22 | +import logging |
| 23 | +import os |
| 24 | +from typing import List |
| 25 | +import boto3 |
| 26 | +import yaml |
| 27 | +from jsonschema import validate |
| 28 | +from platformdirs import site_config_dir, user_config_dir |
| 29 | +from botocore.utils import merge_dicts |
| 30 | +from six.moves.urllib.parse import urlparse |
| 31 | +from sagemaker.config.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA |
| 32 | + |
| 33 | +logger = logging.getLogger("sagemaker") |
| 34 | + |
| 35 | +_APP_NAME = "sagemaker" |
| 36 | +_DEFAULT_ADMIN_CONFIG_FILE_PATH = os.path.join(site_config_dir(_APP_NAME), "config.yaml") |
| 37 | +_DEFAULT_USER_CONFIG_FILE_PATH = os.path.join(user_config_dir(_APP_NAME), "config.yaml") |
| 38 | + |
| 39 | +ENV_VARIABLE_DEFAULT_CONFIG_OVERRIDE = "SAGEMAKER_DEFAULT_CONFIG_OVERRIDE" |
| 40 | +ENV_VARIABLE_USER_CONFIG_OVERRIDE = "SAGEMAKER_USER_CONFIG_OVERRIDE" |
| 41 | + |
| 42 | +_config_paths = [_DEFAULT_ADMIN_CONFIG_FILE_PATH, _DEFAULT_USER_CONFIG_FILE_PATH] |
| 43 | +_BOTO_SESSION = boto3.DEFAULT_SESSION or boto3.Session() |
| 44 | +_DEFAULT_S3_RESOURCE = _BOTO_SESSION.resource("s3") |
| 45 | + |
| 46 | + |
| 47 | +class SageMakerConfig(object): |
| 48 | + """SageMakerConfig class encapsulates the Config for SageMaker Python SDK. |
| 49 | +
|
| 50 | + Usages: |
| 51 | + This class will be integrated with sagemaker.session.Session. Users of SageMaker Python SDK |
| 52 | + will have the ability to pass a SageMakerConfig object to sagemaker.session.Session. If |
| 53 | + SageMakerConfig object is not provided by the user, then sagemaker.session.Session will |
| 54 | + create its own SageMakerConfig object. |
| 55 | +
|
| 56 | + Note: Once sagemaker.session.Session is initialized, it will operate with the configuration |
| 57 | + values at that instant. If the users wish to alter configuration files/file paths after |
| 58 | + sagemaker.session.Session is initialized, then that will not be reflected in |
| 59 | + sagemaker.session.Session. They would have to re-initialize sagemaker.session.Session to |
| 60 | + pick the latest changes. |
| 61 | +
|
| 62 | + """ |
| 63 | + |
| 64 | + def __init__(self, additional_config_paths: List[str] = None, s3_resource=_DEFAULT_S3_RESOURCE): |
| 65 | + """Constructor for SageMakerConfig. |
| 66 | +
|
| 67 | + By default, it will first look for Config files in paths that are dictated by |
| 68 | + _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH. |
| 69 | +
|
| 70 | + Users can override the _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH |
| 71 | + by using environment variables - SAGEMAKER_DEFAULT_CONFIG_OVERRIDE and |
| 72 | + SAGEMAKER_USER_CONFIG_OVERRIDE |
| 73 | +
|
| 74 | + Additional Configuration file paths can also be provided as a constructor parameter. |
| 75 | +
|
| 76 | + This constructor will then |
| 77 | + * Load each config file. |
| 78 | + * It will validate the schema of the config files. |
| 79 | + * It will perform the merge operation in the same order. |
| 80 | +
|
| 81 | + This constructor will throw exceptions for the following cases: |
| 82 | + * Schema validation fails for one/more config files. |
| 83 | + * When the config file is not a proper YAML file. |
| 84 | + * Any S3 related issues that arises while fetching config file from S3. This includes |
| 85 | + permission issues, S3 Object is not found in the specified S3 URI. |
| 86 | + * File doesn't exist in a path that was specified by the user as part of environment |
| 87 | + variable/ additional_config_paths. This doesn't include |
| 88 | + _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH |
| 89 | +
|
| 90 | +
|
| 91 | + Args: |
| 92 | + additional_config_paths: List of Config file paths. |
| 93 | + These paths can be one of the following: |
| 94 | + * Local file path |
| 95 | + * Local directory path (in this case, we will look for config.yaml in that |
| 96 | + directory) |
| 97 | + * S3 URI of the config file |
| 98 | + * S3 URI of the directory containing the config file (in this case, we will look for |
| 99 | + config.yaml in that directory) |
| 100 | + Note: S3 URI follows the format s3://<bucket>/<Key prefix> |
| 101 | + s3_resource: Corresponds to boto3 S3 resource. This will be used to fetch Config |
| 102 | + files from S3. If it is not provided, we will create a default s3 resource |
| 103 | + See :py:meth:`boto3.session.Session.resource`. This argument is not needed if the |
| 104 | + config files are present in the local file system |
| 105 | +
|
| 106 | + """ |
| 107 | + default_config_path = os.getenv( |
| 108 | + ENV_VARIABLE_DEFAULT_CONFIG_OVERRIDE, _DEFAULT_ADMIN_CONFIG_FILE_PATH |
| 109 | + ) |
| 110 | + user_config_path = os.getenv( |
| 111 | + ENV_VARIABLE_USER_CONFIG_OVERRIDE, _DEFAULT_USER_CONFIG_FILE_PATH |
| 112 | + ) |
| 113 | + self._config_paths = [default_config_path, user_config_path] |
| 114 | + if additional_config_paths: |
| 115 | + self._config_paths += additional_config_paths |
| 116 | + self._s3_resource = s3_resource |
| 117 | + config = {} |
| 118 | + for file_path in self._config_paths: |
| 119 | + if file_path.startswith("s3://"): |
| 120 | + config_from_file = _load_config_from_s3(file_path, self._s3_resource) |
| 121 | + else: |
| 122 | + config_from_file = _load_config_from_file(file_path) |
| 123 | + if config_from_file: |
| 124 | + validate(config_from_file, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) |
| 125 | + merge_dicts(config, config_from_file) |
| 126 | + self._config = config |
| 127 | + |
| 128 | + @property |
| 129 | + def config_paths(self) -> List[str]: |
| 130 | + """Getter for Config paths. |
| 131 | +
|
| 132 | + Returns: |
| 133 | + List[str]: This corresponds to the list of config file paths. |
| 134 | + """ |
| 135 | + return self._config_paths |
| 136 | + |
| 137 | + @property |
| 138 | + def config(self) -> dict: |
| 139 | + """Getter for the configuration object. |
| 140 | +
|
| 141 | + Returns: |
| 142 | + dict: A dictionary representing the configurations that were loaded from the config |
| 143 | + file(s). |
| 144 | + """ |
| 145 | + return self._config |
| 146 | + |
| 147 | + |
| 148 | +def _load_config_from_file(file_path: str) -> dict: |
| 149 | + """This method loads the config file from the path that was specified as parameter. |
| 150 | +
|
| 151 | + If the path that was provided, corresponds to a directory then this method will try to search |
| 152 | + for 'config.yaml' in that directory. Note: We will not be doing any recursive search. |
| 153 | +
|
| 154 | + Args: |
| 155 | + file_path(str): The file path from which the Config file needs to be loaded. |
| 156 | +
|
| 157 | + Returns: |
| 158 | + dict: A dictionary representing the configurations that were loaded from the config file. |
| 159 | +
|
| 160 | + This method will throw Exceptions for the following cases: |
| 161 | + * When the config file is not a proper YAML file. |
| 162 | + * File doesn't exist in a path that was specified by the consumer. This doesn't include |
| 163 | + _DEFAULT_ADMIN_CONFIG_FILE_PATH and _DEFAULT_USER_CONFIG_FILE_PATH |
| 164 | + """ |
| 165 | + config = {} |
| 166 | + if file_path: |
| 167 | + inferred_file_path = file_path |
| 168 | + if os.path.isdir(file_path): |
| 169 | + inferred_file_path = os.path.join(file_path, "config.yaml") |
| 170 | + if not os.path.exists(inferred_file_path): |
| 171 | + if inferred_file_path not in ( |
| 172 | + _DEFAULT_ADMIN_CONFIG_FILE_PATH, |
| 173 | + _DEFAULT_USER_CONFIG_FILE_PATH, |
| 174 | + ): |
| 175 | + # Customer provided file path is invalid. |
| 176 | + raise ValueError( |
| 177 | + f"Unable to load config file from the location: {file_path} Please" |
| 178 | + f" provide a valid file path" |
| 179 | + ) |
| 180 | + else: |
| 181 | + logger.debug("Fetching configuration file from the path: %s", file_path) |
| 182 | + config = yaml.safe_load(open(inferred_file_path, "r")) |
| 183 | + return config |
| 184 | + |
| 185 | + |
| 186 | +def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict: |
| 187 | + """This method loads the config file from the S3 URI that was specified as parameter. |
| 188 | +
|
| 189 | + If the S3 URI that was provided, corresponds to a directory then this method will try to |
| 190 | + search for 'config.yaml' in that directory. Note: We will not be doing any recursive search. |
| 191 | +
|
| 192 | + Args: |
| 193 | + s3_uri(str): The S3 URI of the config file. |
| 194 | + Note: S3 URI follows the format s3://<bucket>/<Key prefix> |
| 195 | + s3_resource_for_config: Corresponds to boto3 S3 resource. This will be used to fetch Config |
| 196 | + files from S3. See :py:meth:`boto3.session.Session.resource`. |
| 197 | +
|
| 198 | + Returns: |
| 199 | + dict: A dictionary representing the configurations that were loaded from the config file. |
| 200 | +
|
| 201 | + This method will throw Exceptions for the following cases: |
| 202 | + * If Boto3 S3 resource is not provided. |
| 203 | + * When the config file is not a proper YAML file. |
| 204 | + * If the method is unable to retrieve the list of all the S3 files with the same prefix |
| 205 | + * If there are no S3 files with that prefix. |
| 206 | + * If a folder in S3 bucket is provided as s3_uri, and if it doesn't have config.yaml, |
| 207 | + then we will throw an Exception. |
| 208 | + """ |
| 209 | + if not s3_resource_for_config: |
| 210 | + raise RuntimeError("Please provide a S3 client for loading the config") |
| 211 | + logger.debug("Fetching configuration file from the S3 URI: %s", s3_uri) |
| 212 | + inferred_s3_uri = _get_inferred_s3_uri(s3_uri, s3_resource_for_config) |
| 213 | + parsed_url = urlparse(inferred_s3_uri) |
| 214 | + bucket, key_prefix = parsed_url.netloc, parsed_url.path.lstrip("/") |
| 215 | + s3_object = s3_resource_for_config.Object(bucket, key_prefix) |
| 216 | + s3_file_content = s3_object.get()["Body"].read() |
| 217 | + return yaml.safe_load(s3_file_content.decode("utf-8")) |
| 218 | + |
| 219 | + |
| 220 | +def _get_inferred_s3_uri(s3_uri, s3_resource_for_config): |
| 221 | + """Verifies whether the given S3 URI exists and returns the URI. |
| 222 | +
|
| 223 | + If there are multiple S3 objects with the same key prefix, |
| 224 | + then this method will verify whether S3 URI + /config.yaml exists. |
| 225 | + s3://example-bucket/somekeyprefix/config.yaml |
| 226 | +
|
| 227 | + Args: |
| 228 | + s3_uri (str) : An S3 uri that refers to a location in which config file is present. |
| 229 | + s3_uri must start with 's3://'. |
| 230 | + An example s3_uri: 's3://example-bucket/config.yaml'. |
| 231 | + s3_resource_for_config: Corresponds to boto3 S3 resource. This will be used to fetch Config |
| 232 | + files from S3. |
| 233 | + See :py:meth:`boto3.session.Session.resource` |
| 234 | +
|
| 235 | + Returns: |
| 236 | + str: Valid S3 URI of the Config file. None if it doesn't exist. |
| 237 | +
|
| 238 | + This method will throw Exceptions for the following cases: |
| 239 | + * If the method is unable to retrieve the list of all the S3 files with the same prefix |
| 240 | + * If there are no S3 files with that prefix. |
| 241 | + * If a folder in S3 bucket is provided as s3_uri, and if it doesn't have config.yaml, |
| 242 | + then we will throw an Exception. |
| 243 | + """ |
| 244 | + parsed_url = urlparse(s3_uri) |
| 245 | + bucket, key_prefix = parsed_url.netloc, parsed_url.path.lstrip("/") |
| 246 | + try: |
| 247 | + s3_bucket = s3_resource_for_config.Bucket(name=bucket) |
| 248 | + s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all() |
| 249 | + s3_files_with_same_prefix = [ |
| 250 | + "s3://{}/{}".format(bucket, s3_object.key) for s3_object in s3_objects |
| 251 | + ] |
| 252 | + except Exception as e: # pylint: disable=W0703 |
| 253 | + # if customers didn't provide us with a valid S3 File/insufficient read permission, |
| 254 | + # We will fail hard. |
| 255 | + raise RuntimeError(f"Unable to read from S3 with URI: {s3_uri} due to {e}") |
| 256 | + if len(s3_files_with_same_prefix) == 0: |
| 257 | + # Customer provided us with an incorrect s3 path. |
| 258 | + raise ValueError("Please provide a valid s3 path instead of {}".format(s3_uri)) |
| 259 | + if len(s3_files_with_same_prefix) > 1: |
| 260 | + # Customer has provided us with a S3 URI which points to a directory |
| 261 | + # search for s3://<bucket>/directory-key-prefix/config.yaml |
| 262 | + inferred_s3_uri = str(pathlib.PurePosixPath(s3_uri, "config.yaml")).replace("s3:/", "s3://") |
| 263 | + if inferred_s3_uri not in s3_files_with_same_prefix: |
| 264 | + # We don't know which file we should be operating with. |
| 265 | + raise ValueError("Please provide a S3 URI which has config.yaml in the directory") |
| 266 | + # Customer has a config.yaml present in the directory that was provided as the S3 URI |
| 267 | + return inferred_s3_uri |
| 268 | + return s3_uri |
0 commit comments