|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""This module configures the default values for SageMaker Python SDK. |
| 14 | +
|
| 15 | +It supports loading Config files from local file system/S3. |
| 16 | +The schema of the Config file is dictated in config_schema.py in the same module. |
| 17 | +
|
| 18 | +""" |
| 19 | +from __future__ import absolute_import |
| 20 | + |
| 21 | +import abc |
| 22 | +import logging |
| 23 | +import os |
| 24 | +import yaml |
| 25 | + |
| 26 | +from botocore.utils import merge_dicts |
| 27 | +from sagemaker import s3 |
| 28 | +from sagemaker.config.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA_V1_0 |
| 29 | +from sagemaker.utils import get_config_value |
| 30 | + |
# Module-level logger; all warnings in this module are emitted under the "sagemaker" logger name.
logger = logging.getLogger("sagemaker")
| 32 | + |
| 33 | + |
class SageMakerConfig(abc.ABC):
    """SageMakerConfig class encapsulates the Config for SageMaker Python SDK.

    The Config is fetched through the ``_load_config`` hook implemented by subclasses,
    validated against the schema, and then exposed through the accessor methods below.
    Note: This class shouldn't be directly instantiated.

    .. tip::
        Use SageMakerConfigFactory to initialize the SageMakerConfig class.
        Subclasses which override ``__init__`` should invoke ``super()``.
    """

    def __init__(self):
        """Initializes a SageMakerConfig object.

        Note: This constructor invokes _load_config() method which should be implemented by the
        subclasses.

        Once the Config is loaded, this constructor will validate the schema of the config file.
        A loader result of ``None`` is stored as an empty dictionary.

        .. tip::
            Subclasses which override ``__init__`` should invoke ``super()``.

        """
        loaded_config = self._load_config()
        # A loader may hand back None (for example when the source document is empty).
        # Store an empty dict instead so that validate() and merge() never operate on None.
        self.config_values = {} if loaded_config is None else loaded_config
        self.validate()

    @abc.abstractmethod
    def _load_config(self) -> dict:
        """Should be implemented by the Subclass.

        Subclass is responsible for fetching the config file.

        Returns:
            dict: Config as a python dictionary which obeys the schema mentioned in config_schema.py
        """

    def validate(self):
        """Validates the schema of the Config.

        An empty Config is never validated; that is the state used when no Config
        file could be found.

        SchemaError exception is thrown if the schema is not obeyed.
        """
        if not self.config_values:
            # If Config file is not found, then Config object will be empty.
            # We will NOT be doing schema validations for an empty config
            return
        SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA_V1_0.validate(self.config_values)

    def merge(self, config_to_be_merged):
        """Merges the given SageMakerConfig object into this SageMakerConfig object.

        The configuration values of this object are updated in place. If a same key is
        present in both the configs, then the value present in the 'config_to_be_merged'
        will overwrite the current value.

        Note: This method re-uses the merge_dicts utility method implemented in botocore.

        Args:
            config_to_be_merged (SageMakerConfig): The additional SageMakerConfig object
                that needs to be merged with the current object.

        """
        merge_dicts(self.config_values, config_to_be_merged.config_values)

    def get_config_value(self, key_path):
        """Given a key path, retrieve the corresponding value in the Config.

        The lookup is delegated to ``sagemaker.utils.get_config_value``.

        Args:
            key_path (str): Key path of the config entry.
                Nested entries are represented using '.' (Dot).

        Returns:
            object: Represents the object that is present in that key path of the Config.
            Returns None if the key path doesn't exist.
        """
        return get_config_value(key_path, self.config_values)

    def get_config(self):
        """Returns the entire configuration as a dictionary.

        Returns:
            dict: The entire configuration. This is the object held by the instance,
            not a copy.
        """
        return self.config_values
| 116 | + |
| 117 | + |
class _SageMakerConfigFromFile(SageMakerConfig):
    """An implementation of the SageMakerConfig which loads the Config from a local file system."""

    def __init__(self, file_path):
        """Constructor for _SageMakerConfigFromFile.

        Note: This internally invokes the super class constructor.

        Args:
            file_path (str): The local file system path of the YAML file from which the
                config needs to be loaded. Note: This needs to be an absolute path.
        """
        # The path must be stored before the super constructor runs, because that
        # constructor immediately calls _load_config(), which reads self.file_path.
        self.file_path = file_path
        super(_SageMakerConfigFromFile, self).__init__()

    def _load_config(self) -> dict:
        """Overridden implementation of the _load_config method.

        This method loads the config file from the path that was specified as
        part of the constructor.
        If the path that was provided, corresponds to a directory then this method
        will try to search for 'config.yaml' in that directory.
        Note: We will not be doing any recursive search.

        Returns:
            dict: A dictionary representing the configurations.
            Note: This dictionary will be empty if no path was provided, if the path
            doesn't exist, or if the YAML document is empty.
        """
        if not self.file_path:
            return {}
        inferred_file_path = self.file_path
        if os.path.isdir(inferred_file_path):
            inferred_file_path = os.path.join(inferred_file_path, "config.yaml")
        if not _validate_file_exists(inferred_file_path):
            return {}
        # Use a context manager so the file handle is closed even if parsing fails.
        with open(inferred_file_path, "r") as config_file:
            loaded_config = yaml.safe_load(config_file)
        # safe_load yields None for an empty document; callers expect a dictionary.
        return {} if loaded_config is None else loaded_config
| 154 | + |
| 155 | + |
class _SageMakerConfigFromS3(SageMakerConfig):
    """An implementation of the SageMakerConfig which loads the Config from a S3 location."""

    def __init__(self, s3_uri, sagemaker_session=None):
        """Constructor for _SageMakerConfigFromS3.

        This internally invokes the super class constructor.

        Args:
            s3_uri (str): An S3 uri that refers to a location in which config file is present.
                s3_uri must start with 's3://'.
                An example s3_uri: 's3://example-bucket/config.yaml'.
                If there are multiple S3 objects with the same key prefix,
                then we will infer the path to be s3://example-bucket/somekeyprefix/config.yaml
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, a default session object will be created
                using the default AWS configuration chain.

        """
        # Both attributes must be set before the super constructor runs, because that
        # constructor immediately calls _load_config(), which reads them.
        self.sagemaker_session = sagemaker_session
        self.s3_uri = _get_inferred_s3_path(s3_uri, self.sagemaker_session)
        super(_SageMakerConfigFromS3, self).__init__()

    def _load_config(self) -> dict:
        """Overridden implementation of the _load_config method.

        This method loads the config file from the S3 URI that was
        specified as part of the constructor.

        Note: Currently we don't support Client side KMS. This capability might be added in the
        future

        Returns:
            dict: A dictionary representing the configurations.
            Note: This dictionary will be empty if an invalid S3 URI is provided, if the
            object can't be read, or if the YAML document is empty.

        """
        if not self.s3_uri:
            return {}
        try:
            s3_file_content = s3.S3Downloader.read_file(self.s3_uri, self.sagemaker_session)
        except Exception as e:  # pylint: disable=W0703
            # if customers didn't provide us with a valid S3 File/insufficient read permission,
            # we DO want to silently ignore that and operate with an empty config.
            logger.warning(
                "Unable to fetch Config file from S3 with URI: %s due to %s", self.s3_uri, e
            )
            return {}
        if not s3_file_content:
            return {}
        loaded_config = yaml.safe_load(s3_file_content)
        # Content made only of whitespace or comments parses to None; callers expect a dict.
        return {} if loaded_config is None else loaded_config
| 208 | + |
| 209 | + |
class SageMakerConfigFactory(object):
    """Factory Class to create SageMakerConfig object.

    A base (default) Config is always built; an additional Config can be layered
    on top of it as an override.
    """

    @staticmethod
    def build_sagemaker_config(
        default_config_location: str = None,
        additional_override_config_location: str = None,
        sagemaker_session=None,
    ) -> SageMakerConfig:
        """Factory method to build a SageMakerConfig object.

        Both locations are first resolved (explicit argument, falling back to the
        helpers' defaults), then one Config is built per location and the override
        Config is merged into the base Config.

        Args:
            default_config_location (str): File path of the Config.
                This can even be a S3 URI/Local File path.
            additional_override_config_location (str):
                File path of the Config override. This can even be a S3 URI/Local File path.
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, a default session object will be created
                using the default AWS configuration chain.

        Returns:
            SageMakerConfig: A SageMakerConfig object which contains the merged Config values.
        """
        base_path = _get_default_config_path(default_config_location)
        override_path = _get_additional_config_override_path(additional_override_config_location)

        merged_config = _get_sagemaker_config(base_path, sagemaker_session)
        # The override is merged last so that its values win over the base values.
        merged_config.merge(_get_sagemaker_config(override_path, sagemaker_session))
        return merged_config
| 251 | + |
| 252 | + |
# Names of the environment variables that can point the SDK at a different
# default Config and at an additional override Config, respectively.
ENV_VARIABLE_DEFAULT_CONFIG_OVERRIDE = "SAGEMAKER_DEFAULT_CONFIG_OVERRIDE"
ENV_VARIABLE_ADDITIONAL_CONFIG_OVERRIDE = "SAGEMAKER_ADDITIONAL_CONFIG_OVERRIDE"

# Default path in which the SageMaker Python SDK looks for Config objects:
# <user home>/.sagemaker/defaults/sdk-default-config.yaml
DEFAULT_CONFIG_FILE_PATH = os.path.join(
    os.path.expanduser("~"),
    os.path.join(".sagemaker", "defaults", "sdk-default-config.yaml"),
)
| 259 | + |
| 260 | + |
def _get_sagemaker_config(file_path: str, sagemaker_session) -> SageMakerConfig:
    """Builds the SageMakerConfig implementation that matches the given location.

    Args:
        file_path (str): File path of the Config. This can even be a S3 URI/Local File path.
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. It is only passed on when the location is a S3 URI.

    Returns:
        SageMakerConfig: An S3-backed Config when the path starts with 's3://',
        otherwise a local-file-backed Config (this includes an empty or missing path).
    """
    is_s3_location = bool(file_path) and file_path.startswith("s3://")
    if not is_s3_location:
        return _SageMakerConfigFromFile(file_path)
    return _SageMakerConfigFromS3(file_path, sagemaker_session)
| 277 | + |
| 278 | + |
| 279 | +def _validate_file_exists(file_path) -> bool: |
| 280 | + """Validates whether a file/directory corresponding to that path exists |
| 281 | +
|
| 282 | + Args: |
| 283 | + file_path (str): The file path for which we need to verify its existence, |
| 284 | +
|
| 285 | + Returns: |
| 286 | + bool: Boolean indicating whether a file/directory corresponding to that path exists. |
| 287 | + """ |
| 288 | + if not os.path.exists(file_path): |
| 289 | + logger.warning("The specified file path: %s for the Config file doesn't exist.", file_path) |
| 290 | + return False |
| 291 | + return True |
| 292 | + |
| 293 | + |
| 294 | +def _get_default_config_path(default_config_directory_method_parameter: str = None) -> str: |
| 295 | + """Returns the default config path. |
| 296 | +
|
| 297 | + Args: |
| 298 | + default_config_directory_method_parameter (str): |
| 299 | + Corresponds to the default_config_location parameter of SageMakerConfigFactory |
| 300 | +
|
| 301 | + Returns: |
| 302 | + str: If default_config_directory_method_parameter is passed, then we will use that value. |
| 303 | + Else, this method will try to fetch the value from SAGEMAKER_CONFIG_FILE environment |
| 304 | + variable.If SAGEMAKER_CONFIG_FILE environment variable is not set, |
| 305 | + we will infer the path to be DEFAULT_CONFIG_FILE_PATH |
| 306 | + """ |
| 307 | + if default_config_directory_method_parameter: |
| 308 | + return default_config_directory_method_parameter |
| 309 | + return os.getenv(ENV_VARIABLE_DEFAULT_CONFIG_OVERRIDE, DEFAULT_CONFIG_FILE_PATH) |
| 310 | + |
| 311 | + |
| 312 | +def _get_additional_config_override_path( |
| 313 | + additional_config_override_method_parameter: str = None, |
| 314 | +) -> str: |
| 315 | + """Returns the additional config override path. |
| 316 | +
|
| 317 | + Args: |
| 318 | + additional_config_override_method_parameter (str): Corresponds to the |
| 319 | + additional_override_config_location parameter of SageMakerConfigFactory |
| 320 | +
|
| 321 | + Returns: |
| 322 | + str: If additional_config_override_method_parameter is passed, then we will use that value. |
| 323 | + Else, this method will try to fetch the value from SAGEMAKER_ADDITIONAL_CONFIG_OVERRIDE |
| 324 | + environment variable |
| 325 | + """ |
| 326 | + if additional_config_override_method_parameter: |
| 327 | + return additional_config_override_method_parameter |
| 328 | + return os.getenv(ENV_VARIABLE_ADDITIONAL_CONFIG_OVERRIDE, None) |
| 329 | + |
| 330 | + |
def _get_inferred_s3_path(s3_uri, sagemaker_session):
    """Verifies whether the given S3 URI exists and returns the URI of the Config file.

    The objects sharing the given URI as key prefix are listed, and the result is
    resolved as follows:

    * No object found: None.
    * Several objects found: 's3_uri + /config.yaml' if it is among them
      (e.g. s3://example-bucket/somekeyprefix/config.yaml); otherwise s3_uri itself
      if it is one of the listed objects; otherwise None.
    * Exactly one object found: 's3_uri + /config.yaml' if that is the listed object;
      otherwise s3_uri.

    Args:
        s3_uri (str) : An S3 uri that refers to a location in which config file is present.
            s3_uri must start with 's3://'.
            An example s3_uri: 's3://example-bucket/config.yaml'.
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. If not specified, a default session object will be created
            using the default AWS configuration chain.

    Returns:
        str: Valid S3 URI of the Config file. None if it doesn't exist.
    """
    try:
        s3_files_with_same_prefix = s3.S3Downloader.list(s3_uri, sagemaker_session)
    except Exception as e:  # pylint: disable=W0703
        # if customers didn't provide us with a valid S3 File/insufficient read permission,
        # we DO want to silently ignore that and operate with an empty config.
        logger.warning("Unable to read from S3 with URI: %s due to %s", s3_uri, e)
        return None
    if not s3_files_with_same_prefix:
        # Provided S3 URI is invalid.
        # we DO want to silently ignore that and operate with an empty config.
        return None

    # NOTE(review): the membership checks below assume the listing holds full
    # 's3://...' URIs in the same form that s3_path_join produces - confirm.
    inferred_s3_path = s3.s3_path_join(s3_uri, "config.yaml")
    if len(s3_files_with_same_prefix) > 1:
        if inferred_s3_path in s3_files_with_same_prefix:
            # Customer has a config.yaml present in the directory that was provided
            # as the S3 URI.
            return inferred_s3_path
        if s3_uri in s3_files_with_same_prefix:
            # The URI names an object exactly; the other entries merely share its
            # key prefix (for example 'config.yaml' next to 'config.yaml.bak').
            return s3_uri
        # We don't know which file we should be operating with.
        return None
    if inferred_s3_path in s3_files_with_same_prefix:
        # The URI points to a directory whose only object is config.yaml.
        return inferred_s3_path
    return s3_uri
0 commit comments