|
23 | 23 | import tarfile
|
24 | 24 | import tempfile
|
25 | 25 | import time
|
| 26 | +import json |
| 27 | +import abc |
26 | 28 | from datetime import datetime
|
27 | 29 |
|
28 | 30 | import botocore
|
@@ -649,4 +651,72 @@ def _module_import_error(py_module, feature, extras):
|
649 | 651 | return error_msg.format(py_module, feature, extras)
|
650 | 652 |
|
651 | 653 |
|
class DataConfig(abc.ABC):
    """Abstract interface for reading a data config that is hosted in an AWS resource.

    Concrete subclasses customize retrieval by overriding ``fetch_data_config``.
    """

    @abc.abstractmethod
    def fetch_data_config(self):
        """Retrieve the data config from the data source the instance was set up with.

        Subclasses must override this method; the base implementation does nothing.

        Returns:
            object: The data configuration object.
        """
| 667 | + |
| 668 | + |
class S3DataConfig(DataConfig):
    """Fetch a JSON data config file hosted on S3 and resolve per-region data buckets."""

    def __init__(
        self,
        sagemaker_session,
        bucket_name,
        prefix,
    ):
        """Initialize a ``S3DataConfig`` instance.

        Args:
            sagemaker_session (Session): SageMaker session instance used to read the
                config file and to look up the default region.
            bucket_name (str): Required. Bucket name from which data config needs to be fetched.
            prefix (str): Required. The object prefix for the hosted data config.

        Raises:
            ValueError: If ``bucket_name`` or ``prefix`` is ``None``.
        """
        if bucket_name is None or prefix is None:
            raise ValueError(
                "Bucket Name and S3 file Prefix are required arguments and must be provided."
            )

        super().__init__()

        self.bucket_name = bucket_name
        self.prefix = prefix
        self.sagemaker_session = sagemaker_session

    def fetch_data_config(self):
        """Fetch the data configuration from the configured S3 bucket and prefix.

        The file content is obtained through ``sagemaker_session.read_s3_file`` and
        parsed as JSON.

        Returns:
            object: The JSON object containing data configuration.
        """
        json_string = self.sagemaker_session.read_s3_file(self.bucket_name, self.prefix)
        return json.loads(json_string)

    def get_data_bucket(self, region_requested=None):
        """Provide the bucket containing the data for the specified region.

        The data config is fetched again on every call.

        Args:
            region_requested (str): The region for which the data is being requested.
                When omitted (or otherwise falsy), the session's ``boto_region_name``
                is used instead.

        Returns:
            str: Name of the S3 bucket containing datasets in the requested region,
            or the config's ``"default"`` entry when the region has no entry of its own.
        """
        # NOTE(review): assumes the fetched config is a JSON object (dict) keyed by
        # region name and carrying a "default" key -- confirm against the hosted file.
        config = self.fetch_data_config()
        region = region_requested or self.sagemaker_session.boto_region_name
        # Membership is tested on the mapping itself; no keys view is needed.
        return config[region] if region in config else config["default"]
| 720 | + |
| 721 | + |
652 | 722 | get_ecr_image_uri_prefix = deprecations.removed_function("get_ecr_image_uri_prefix")
|
0 commit comments