Skip to content

Commit 46684a7

Browse files
authored
feature: Retrieve data configuration (aws#3016)
1 parent db39cc2 commit 46684a7

File tree

2 files changed

+104
-0
lines changed

2 files changed

+104
-0
lines changed

src/sagemaker/utils.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import tarfile
2424
import tempfile
2525
import time
26+
import json
27+
import abc
2628
from datetime import datetime
2729

2830
import botocore
@@ -649,4 +651,72 @@ def _module_import_error(py_module, feature, extras):
649651
return error_msg.format(py_module, feature, extras)
650652

651653

654+
class DataConfig(abc.ABC):
    """Abstract base class for accessing data config hosted in AWS resources.

    Subclasses customize retrieval by overriding ``fetch_data_config``.
    """

    @abc.abstractmethod
    def fetch_data_config(self):
        """Retrieve the data config from a pre-configured data source.

        Returns:
            object: The data configuration object.
        """
667+
668+
669+
class S3DataConfig(DataConfig):
    """Extends ``DataConfig`` to fetch a JSON data config file hosted on S3."""

    def __init__(
        self,
        sagemaker_session,
        bucket_name,
        prefix,
    ):
        """Initialize a ``S3DataConfig`` instance.

        Args:
            sagemaker_session (Session): SageMaker session instance to use for boto configuration.
            bucket_name (str): Required. Bucket name from which data config needs to be fetched.
            prefix (str): Required. The object prefix for the hosted data config.

        Raises:
            ValueError: If ``bucket_name`` or ``prefix`` is ``None``.
        """
        if bucket_name is None or prefix is None:
            raise ValueError(
                "Bucket Name and S3 file Prefix are required arguments and must be provided."
            )

        super().__init__()

        self.bucket_name = bucket_name
        self.prefix = prefix
        self.sagemaker_session = sagemaker_session

    def fetch_data_config(self):
        """Fetch the data configuration from the configured S3 bucket and prefix.

        The object is read through ``sagemaker_session.read_s3_file`` and parsed as JSON.

        Returns:
            object: The JSON object containing data configuration.
        """
        json_string = self.sagemaker_session.read_s3_file(self.bucket_name, self.prefix)
        return json.loads(json_string)

    def get_data_bucket(self, region_requested=None):
        """Provide the bucket containing the data for the specified region.

        Args:
            region_requested (str): The region for which the data is being requested.
                When not given (or empty), the session's ``boto_region_name`` is used.

        Returns:
            str: Name of the S3 bucket containing datasets in the requested region, or
                the ``"default"`` entry when the region has no entry of its own.

        Raises:
            KeyError: If the region is not in the config and there is no ``"default"`` entry.
        """
        config = self.fetch_data_config()
        region = region_requested or self.sagemaker_session.boto_region_name
        # Look up "default" lazily so a config without it still works for listed regions.
        return config[region] if region in config else config["default"]
720+
721+
652722
get_ecr_image_uri_prefix = deprecations.removed_function("get_ecr_image_uri_prefix")

tests/unit/test_utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import os
2121
import re
2222
import time
23+
import json
2324

2425
from boto3 import exceptions
2526
import botocore
@@ -206,6 +207,39 @@ def test_secondary_training_status_message_prev_missing():
206207
)
207208

208209

210+
# Mocked data config: maps a region name to its bucket; "default" is the fallback entry.
SAMPLE_DATA_CONFIG = {
    "us-west-2": "sagemaker-hosted-datasets",
    "default": "sagemaker-sample-files",
}
211+
212+
213+
def test_notebooks_data_config_if_region_not_present():
    """A session region missing from the config resolves to the ``default`` bucket."""
    # "ap-northeast-1" is deliberately not a key of SAMPLE_DATA_CONFIG.
    boto_mock = MagicMock(name="boto_session", region_name="ap-northeast-1")
    session = sagemaker.Session(boto_session=boto_mock, sagemaker_client=MagicMock())
    session.read_s3_file = Mock(return_value=json.dumps(SAMPLE_DATA_CONFIG))

    data_config = sagemaker.utils.S3DataConfig(
        session, "example-notebooks-data-config", "config/data_config.json"
    )

    assert data_config.get_data_bucket() == "sagemaker-sample-files"
    # The mock returns the same payload for any arguments, so check them explicitly.
    session.read_s3_file.assert_called_once_with(
        "example-notebooks-data-config", "config/data_config.json"
    )
226+
227+
228+
def test_notebooks_data_config_if_region_present():
    """A session region listed in the config resolves to that region's bucket."""
    # "us-west-2" has its own entry in SAMPLE_DATA_CONFIG.
    boto_mock = MagicMock(name="boto_session", region_name="us-west-2")
    session = sagemaker.Session(boto_session=boto_mock, sagemaker_client=MagicMock())
    session.read_s3_file = Mock(return_value=json.dumps(SAMPLE_DATA_CONFIG))

    data_config = sagemaker.utils.S3DataConfig(
        session, "example-notebooks-data-config", "config/data_config.json"
    )

    assert data_config.get_data_bucket() == "sagemaker-hosted-datasets"
    # The mock returns the same payload for any arguments, so check them explicitly.
    session.read_s3_file.assert_called_once_with(
        "example-notebooks-data-config", "config/data_config.json"
    )
241+
242+
209243
@patch("os.makedirs")
210244
def test_download_folder(makedirs):
211245
boto_mock = MagicMock(name="boto_session")

0 commit comments

Comments
 (0)