24
24
from typing import List
25
25
import boto3
26
26
import yaml
27
- from jsonschema import validate
27
+ import jsonschema
28
28
from platformdirs import site_config_dir , user_config_dir
29
29
from botocore .utils import merge_dicts
30
30
from six .moves .urllib .parse import urlparse
51
51
S3_PREFIX = "s3://"
52
52
53
53
54
- class SageMakerConfig (object ):
55
- """A class that encapsulates the configuration for the SageMaker Python SDK.
54
+ def fetch_sagemaker_config (
55
+ additional_config_paths : List [str ] = None , s3_resource = _DEFAULT_S3_RESOURCE
56
+ ) -> dict :
57
+ """Helper method that loads config files and merges them.
58
+
59
+ By default, this method first searches for config files in the default locations
60
+ defined by the SDK.
61
+
62
+ Users can override the default admin and user config file paths using the
63
+ SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE environment variables,
64
+ respectively.
65
+
66
+ Additional config file paths can also be provided as a parameter.
67
+
68
+ This method then:
69
+ * Loads each config file, whether it is in Amazon S3 or on the local file system.
70
+ * Validates the schema of the config files.
71
+ * Merges the files in the order in which they were loaded.
72
+
73
+ This method throws exceptions in the following cases:
74
+ * jsonschema.exceptions.ValidationError: Schema validation fails for one or more config
75
+ files.
76
+ * RuntimeError: The method is unable to retrieve the list of all S3 files with the
77
+ same prefix or is unable to retrieve the file.
78
+ * ValueError: There are no S3 files with the prefix when an S3 URI is provided.
79
+ * ValueError: There is no config.yaml file in the S3 bucket when an S3 URI is provided.
80
+ * ValueError: A file doesn't exist in a path that was specified by the user as part of an
81
+ environment variable or additional configuration file path. This doesn't include the default
82
+ config file locations.
83
+
84
+ Args:
85
+ additional_config_paths: List of config file paths.
86
+ These paths can be one of the following. In the case of a directory, this method
87
+ searches for a config.yaml file in that directory. This method does not perform a
88
+ recursive search of folders in that directory.
89
+ * Local file path
90
+ * Local directory path
91
+ * S3 URI of the config file
92
+ * S3 URI of the directory containing the config file
93
+ Note: S3 URI follows the format s3://<bucket>/<Key prefix>
94
+ s3_resource (boto3.resource("s3")): The Boto3 S3 resource. This is used to fetch
95
+ config files from S3. If it is not provided, this method creates a default S3 resource.
96
+ See :py:meth:boto3.session.Session.resource. This argument is not needed if the
97
+ config files are present in the local file system.
56
98
57
- This class is used to define default values provided by the user.
58
-
59
- This class is integrated with sagemaker.session.Session. Users of the SageMaker Python SDK
60
- have the ability to pass a SageMakerConfig object to sagemaker.session.Session. If a
61
- SageMakerConfig object is not provided by the user, then sagemaker.session.Session
62
- creates its own SageMakerConfig object.
63
-
64
- Note: After sagemaker.session.Session is initialized, it operates with the configuration
65
- values defined at that instant. If you modify the configuration files or file paths after
66
- sagemaker.session.Session is initialized, those changes are not reflected in
67
- sagemaker.session.Session. To incorporate the changes in the configuration files,
68
- initialize sagemaker.session.Session again.
69
99
"""
70
-
71
- def __init__ (self , additional_config_paths : List [str ] = None , s3_resource = _DEFAULT_S3_RESOURCE ):
72
- """Initializes the SageMakerConfig object.
73
-
74
- By default, this method first searches for config files in the default locations
75
- defined by the SDK.
76
-
77
- Users can override the default admin and user config file paths using the
78
- SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE environment variables,
79
- respectively.
80
-
81
- Additional config file paths can also be provided as a constructor parameter.
82
-
83
- This method then:
84
- * Loads each config file, whether it is Amazon S3 or the local file system.
85
- * Validates the schema of the config files.
86
- * Merges the files in the same order.
87
-
88
- This method throws exceptions in the following cases:
89
- * jsonschema.exceptions.ValidationError: Schema validation fails for one or more config
90
- files.
91
- * RuntimeError: The method is unable to retrieve the list of all S3 files with the
92
- same prefix or is unable to retrieve the file.
93
- * ValueError: There are no S3 files with the prefix when an S3 URI is provided.
94
- * ValueError: There is no config.yaml file in the S3 bucket when an S3 URI is provided.
95
- * ValueError: A file doesn't exist in a path that was specified by the user as part of an
96
- environment variable or additional configuration file path. This doesn't include the default
97
- config file locations.
98
-
99
- Args:
100
- additional_config_paths: List of config file paths.
101
- These paths can be one of the following. In the case of a directory, this method
102
- searches for a config.yaml file in that directory. This method does not perform a
103
- recursive search of folders in that directory.
104
- * Local file path
105
- * Local directory path
106
- * S3 URI of the config file
107
- * S3 URI of the directory containing the config file
108
- Note: S3 URI follows the format s3://<bucket>/<Key prefix>
109
- s3_resource (boto3.resource("s3")): The Boto3 S3 resource. This is used to fetch
110
- config files from S3. If it is not provided, this method creates a default S3 resource
111
- See :py:meth:boto3.session.Session.resource. This argument is not needed if the
112
- config files are present in the local file system.
113
-
114
- """
115
- default_config_path = os .getenv (
116
- ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE , _DEFAULT_ADMIN_CONFIG_FILE_PATH
117
- )
118
- user_config_path = os .getenv (
119
- ENV_VARIABLE_USER_CONFIG_OVERRIDE , _DEFAULT_USER_CONFIG_FILE_PATH
120
- )
121
- self .config_paths = [default_config_path , user_config_path ]
122
- if additional_config_paths :
123
- self .config_paths += additional_config_paths
124
- self .config_paths = list (filter (lambda item : item is not None , self .config_paths ))
125
- self .config = _load_config_files (self .config_paths , s3_resource )
126
-
127
-
128
- def _load_config_files (file_paths : List [str ], s3_resource_for_config ) -> dict :
129
- """Placeholder docstring"""
100
+ default_config_path = os .getenv (
101
+ ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE , _DEFAULT_ADMIN_CONFIG_FILE_PATH
102
+ )
103
+ user_config_path = os .getenv (ENV_VARIABLE_USER_CONFIG_OVERRIDE , _DEFAULT_USER_CONFIG_FILE_PATH )
104
+ config_paths = [default_config_path , user_config_path ]
105
+ if additional_config_paths :
106
+ config_paths += additional_config_paths
107
+ config_paths = list (filter (lambda item : item is not None , config_paths ))
130
108
merged_config = {}
131
- for file_path in file_paths :
109
+ for file_path in config_paths :
132
110
config_from_file = {}
133
111
if file_path .startswith (S3_PREFIX ):
134
- config_from_file = _load_config_from_s3 (file_path , s3_resource_for_config )
112
+ config_from_file = _load_config_from_s3 (file_path , s3_resource )
135
113
else :
136
114
try :
137
115
config_from_file = _load_config_from_file (file_path )
@@ -145,11 +123,24 @@ def _load_config_files(file_paths: List[str], s3_resource_for_config) -> dict:
145
123
# Exceptions.
146
124
raise
147
125
if config_from_file :
148
- validate (config_from_file , SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA )
126
+ validate_sagemaker_config (config_from_file )
149
127
merge_dicts (merged_config , config_from_file )
150
128
return merged_config
151
129
152
130
131
+ def validate_sagemaker_config (sagemaker_config : dict = None ):
132
+ """Helper method that validates the schema of a given dictionary.
133
+
134
+ This method will validate whether the dictionary adheres to the schema
135
+ defined at `~sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`.
136
+
137
+ Args:
138
+ sagemaker_config: A dictionary containing default values for the
139
+ SageMaker Python SDK. (default: None).
140
+ """
141
+ jsonschema .validate (sagemaker_config , SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA )
142
+
143
+
153
144
def _load_config_from_file (file_path : str ) -> dict :
154
145
"""Placeholder docstring"""
155
146
inferred_file_path = file_path
0 commit comments