fix: S3Input - add support for instance attributes #2754

Merged: 3 commits, Dec 15, 2021
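The fix replaces class-level attribute defaults with instance attributes assigned in `__init__`. Class attributes live on the class object itself, so every instance reads the same value, and rebinding through the class silently changes all of them. A minimal sketch of the difference, using hypothetical classes rather than SDK code:

```python
class SharedDefaults:
    s3_data_type = "S3Prefix"  # class attribute: one value shared by every instance


class InstanceDefaults:
    def __init__(self, s3_data_type="S3Prefix"):
        self.s3_data_type = s3_data_type  # instance attribute: per-object state


SharedDefaults.s3_data_type = "ManifestFile"  # rebinding on the class leaks everywhere
print(SharedDefaults().s3_data_type)          # ManifestFile

a = InstanceDefaults()
b = InstanceDefaults(s3_data_type="ManifestFile")
print(a.s3_data_type, b.s3_data_type)         # S3Prefix ManifestFile; independent
```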
236 changes: 152 additions & 84 deletions src/sagemaker/dataset_definition/inputs.py
@@ -26,94 +26,147 @@ class RedshiftDatasetDefinition(ApiObject):
"""DatasetDefinition for Redshift.

With this input, SQL queries will be executed using Redshift to generate datasets to S3.

Parameters:
cluster_id (str): The Redshift cluster Identifier.
database (str): The name of the Redshift database used in Redshift query execution.
db_user (str): The database user name used in Redshift query execution.
query_string (str): The SQL query statements to be executed.
cluster_role_arn (str): The IAM role attached to your Redshift cluster that
Amazon SageMaker uses to generate datasets.
output_s3_uri (str): The location in Amazon S3 where the Redshift query
results are stored.
kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon
SageMaker uses to encrypt data from a Redshift execution.
output_format (str): The data storage format for Redshift query results.
Valid options are "PARQUET", "CSV"
output_compression (str): The compression used for Redshift query results.
Valid options are "None", "GZIP", "SNAPPY", "ZSTD", "BZIP2"
"""

cluster_id = None
database = None
db_user = None
query_string = None
cluster_role_arn = None
output_s3_uri = None
kms_key_id = None
output_format = None
output_compression = None
def __init__(
self,
cluster_id=None,
database=None,
db_user=None,
query_string=None,
cluster_role_arn=None,
output_s3_uri=None,
kms_key_id=None,
output_format=None,
output_compression=None,
):
"""Initialize RedshiftDatasetDefinition.

Args:
cluster_id (str, default=None): The Redshift cluster Identifier.
database (str, default=None):
The name of the Redshift database used in Redshift query execution.
db_user (str, default=None): The database user name used in Redshift query execution.
query_string (str, default=None): The SQL query statements to be executed.
cluster_role_arn (str, default=None): The IAM role attached to your Redshift cluster
that Amazon SageMaker uses to generate datasets.
output_s3_uri (str, default=None): The location in Amazon S3 where the Redshift query
results are stored.
kms_key_id (str, default=None): The AWS Key Management Service (AWS KMS) key that Amazon
SageMaker uses to encrypt data from a Redshift execution.
output_format (str, default=None): The data storage format for Redshift query results.
Valid options are "PARQUET", "CSV"
output_compression (str, default=None): The compression used for Redshift query results.
Valid options are "None", "GZIP", "SNAPPY", "ZSTD", "BZIP2"
"""
super(RedshiftDatasetDefinition, self).__init__(
cluster_id=cluster_id,
database=database,
db_user=db_user,
query_string=query_string,
cluster_role_arn=cluster_role_arn,
output_s3_uri=output_s3_uri,
kms_key_id=kms_key_id,
output_format=output_format,
output_compression=output_compression,
)


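With the new constructor, a Redshift dataset definition is configured per instance. A usage sketch follows; the keyword names come from the signature above, but every identifier, ARN, and S3 URI is a placeholder, not a value from this PR:

```python
from sagemaker.dataset_definition.inputs import RedshiftDatasetDefinition

# Placeholder values throughout; only the keyword names are taken from the diff.
redshift_definition = RedshiftDatasetDefinition(
    cluster_id="my-redshift-cluster",
    database="dev",
    db_user="analyst",
    query_string="SELECT * FROM my_table",
    cluster_role_arn="arn:aws:iam::123456789012:role/MyRedshiftRole",
    output_s3_uri="s3://my-bucket/redshift-results/",
    output_format="PARQUET",
    output_compression="GZIP",
)
```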
class AthenaDatasetDefinition(ApiObject):
"""DatasetDefinition for Athena.

With this input, SQL queries will be executed using Athena to generate datasets to S3.

Parameters:
catalog (str): The name of the data catalog used in Athena query execution.
database (str): The name of the database used in the Athena query execution.
query_string (str): The SQL query statements, to be executed.
output_s3_uri (str): The location in Amazon S3 where Athena query results are stored.
work_group (str): The name of the workgroup in which the Athena query is being started.
kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon
SageMaker uses to encrypt data generated from an Athena query execution.
output_format (str): The data storage format for Athena query results.
Valid options are "PARQUET", "ORC", "AVRO", "JSON", "TEXTFILE"
output_compression (str): The compression used for Athena query results.
Valid options are "GZIP", "SNAPPY", "ZLIB"
"""

catalog = None
database = None
query_string = None
output_s3_uri = None
work_group = None
kms_key_id = None
output_format = None
output_compression = None
def __init__(
self,
catalog=None,
database=None,
query_string=None,
output_s3_uri=None,
work_group=None,
kms_key_id=None,
output_format=None,
output_compression=None,
):
"""Initialize AthenaDatasetDefinition.

Args:
catalog (str, default=None): The name of the data catalog used in Athena query
execution.
database (str, default=None): The name of the database used in the Athena query
execution.
query_string (str, default=None): The SQL query statements, to be executed.
output_s3_uri (str, default=None):
The location in Amazon S3 where Athena query results are stored.
work_group (str, default=None):
The name of the workgroup in which the Athena query is being started.
kms_key_id (str, default=None): The AWS Key Management Service (AWS KMS) key that Amazon
SageMaker uses to encrypt data generated from an Athena query execution.
output_format (str, default=None): The data storage format for Athena query results.
Valid options are "PARQUET", "ORC", "AVRO", "JSON", "TEXTFILE"
output_compression (str, default=None): The compression used for Athena query results.
Valid options are "GZIP", "SNAPPY", "ZLIB"
"""
super(AthenaDatasetDefinition, self).__init__(
catalog=catalog,
database=database,
query_string=query_string,
output_s3_uri=output_s3_uri,
work_group=work_group,
kms_key_id=kms_key_id,
output_format=output_format,
output_compression=output_compression,
)


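The Athena variant follows the same pattern. Another sketch, with placeholder catalog, database, workgroup, and bucket names:

```python
from sagemaker.dataset_definition.inputs import AthenaDatasetDefinition

athena_definition = AthenaDatasetDefinition(
    catalog="AwsDataCatalog",            # placeholder catalog name
    database="default",                  # placeholder database name
    query_string="SELECT * FROM my_table",
    output_s3_uri="s3://my-bucket/athena-results/",
    work_group="primary",
    output_format="PARQUET",
)
```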
class DatasetDefinition(ApiObject):
"""DatasetDefinition input.

Parameters:
data_distribution_type (str): Whether the generated dataset is FullyReplicated or
ShardedByS3Key (default).
input_mode (str): Whether to use File or Pipe input mode. In File (default) mode, Amazon
SageMaker copies the data from the input source onto the local Amazon Elastic Block
Store (Amazon EBS) volumes before starting your training algorithm. This is the most
commonly used input mode. In Pipe mode, Amazon SageMaker streams input data from the
source directly to your algorithm without using the EBS volume.
local_path (str): The local path where you want Amazon SageMaker to download the Dataset
Definition inputs to run a processing job. LocalPath is an absolute path to the input
data. This is a required parameter when `AppManaged` is False (default).
redshift_dataset_definition (:class:`~sagemaker.dataset_definition.inputs.RedshiftDatasetDefinition`):
Configuration for Redshift Dataset Definition input.
athena_dataset_definition (:class:`~sagemaker.dataset_definition.inputs.AthenaDatasetDefinition`):
Configuration for Athena Dataset Definition input.
"""
"""DatasetDefinition input."""

_custom_boto_types = {
"redshift_dataset_definition": (RedshiftDatasetDefinition, True),
"athena_dataset_definition": (AthenaDatasetDefinition, True),
}

data_distribution_type = "ShardedByS3Key"
input_mode = "File"
local_path = None
redshift_dataset_definition = None
athena_dataset_definition = None
def __init__(
self,
data_distribution_type="ShardedByS3Key",
input_mode="File",
local_path=None,
redshift_dataset_definition=None,
athena_dataset_definition=None,
):
"""Initialize DatasetDefinition.

Parameters:
data_distribution_type (str, default="ShardedByS3Key"):
Whether the generated dataset is FullyReplicated or ShardedByS3Key (default).
input_mode (str, default="File"):
Whether to use File or Pipe input mode. In File (default) mode, Amazon
SageMaker copies the data from the input source onto the local Amazon Elastic Block
Store (Amazon EBS) volumes before starting your training algorithm. This is the most
commonly used input mode. In Pipe mode, Amazon SageMaker streams input data from the
source directly to your algorithm without using the EBS volume.
local_path (str, default=None):
The local path where you want Amazon SageMaker to download the Dataset
Definition inputs to run a processing job. LocalPath is an absolute path to the
input data. This is a required parameter when `AppManaged` is False (default).
redshift_dataset_definition
(:class:`~sagemaker.dataset_definition.inputs.RedshiftDatasetDefinition`,
default=None):
Configuration for Redshift Dataset Definition input.
athena_dataset_definition
(:class:`~sagemaker.dataset_definition.inputs.AthenaDatasetDefinition`,
default=None):
Configuration for Athena Dataset Definition input.
"""
super(DatasetDefinition, self).__init__(
data_distribution_type=data_distribution_type,
input_mode=input_mode,
local_path=local_path,
redshift_dataset_definition=redshift_dataset_definition,
athena_dataset_definition=athena_dataset_definition,
)


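A DatasetDefinition wraps exactly one of the query-based definitions and controls how the generated dataset is distributed and mounted. A sketch combining it with the Athena definition above; paths and the query are illustrative:

```python
from sagemaker.dataset_definition.inputs import (
    AthenaDatasetDefinition,
    DatasetDefinition,
)

dataset_definition = DatasetDefinition(
    data_distribution_type="FullyReplicated",  # overriding the "ShardedByS3Key" default
    input_mode="File",
    local_path="/opt/ml/processing/input/athena",  # required while AppManaged is False
    athena_dataset_definition=AthenaDatasetDefinition(
        catalog="AwsDataCatalog",
        database="default",
        query_string="SELECT * FROM my_table",
        output_s3_uri="s3://my-bucket/athena-results/",
        work_group="primary",
        output_format="PARQUET",
    ),
)
```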
class S3Input(ApiObject):
@@ -124,20 +177,35 @@ class S3Input(ApiObject):
    Note: Strong consistency is not guaranteed if S3Prefix is provided here.
    S3 list operations are not strongly consistent.
    Use ManifestFile if strong consistency is required.
-
-    Parameters:
-        s3_uri (str): the path to a specific S3 object or a S3 prefix
-        local_path (str): the path to a local directory. If not provided, skips data download
-            by SageMaker platform.
-        s3_data_type (str): Valid options are "ManifestFile" or "S3Prefix".
-        s3_input_mode (str): Valid options are "Pipe" or "File".
-        s3_data_distribution_type (str): Valid options are "FullyReplicated" or "ShardedByS3Key".
-        s3_compression_type (str): Valid options are "None" or "Gzip".
    """

-    s3_uri = None
-    local_path = None
-    s3_data_type = "S3Prefix"
-    s3_input_mode = "File"
-    s3_data_distribution_type = "FullyReplicated"
-    s3_compression_type = None
+    def __init__(
+        self,
+        s3_uri=None,
+        local_path=None,
+        s3_data_type="S3Prefix",
+        s3_input_mode="File",
+        s3_data_distribution_type="FullyReplicated",
+        s3_compression_type=None,
+    ):
+        """Initialize S3Input.
+
+        Parameters:
+            s3_uri (str, default=None): the path to a specific S3 object or a S3 prefix
+            local_path (str, default=None):
+                the path to a local directory. If not provided, skips data download
+                by SageMaker platform.
+            s3_data_type (str, default="S3Prefix"): Valid options are "ManifestFile" or "S3Prefix".
+            s3_input_mode (str, default="File"): Valid options are "Pipe" or "File".
+            s3_data_distribution_type (str, default="FullyReplicated"):
+                Valid options are "FullyReplicated" or "ShardedByS3Key".
+            s3_compression_type (str, default=None): Valid options are "None" or "Gzip".
+        """
+        super(S3Input, self).__init__(
+            s3_uri=s3_uri,
+            local_path=local_path,
+            s3_data_type=s3_data_type,
+            s3_input_mode=s3_input_mode,
+            s3_data_distribution_type=s3_data_distribution_type,
+            s3_compression_type=s3_compression_type,
+        )
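The integration test added below exercises exactly this case: an S3Input built with only a URI, a local path, and the data type, relying on the remaining defaults. A matching usage sketch; the bucket name is a placeholder:

```python
from sagemaker.dataset_definition.inputs import S3Input

s3_input = S3Input(
    s3_uri="s3://my-bucket",  # placeholder bucket
    local_path="/opt/ml/processing/input/s3_input_wo_defaults",
    s3_data_type="S3Prefix",
)
# s3_input_mode and s3_data_distribution_type fall back to the instance-level
# defaults "File" and "FullyReplicated"; s3_compression_type stays None and is
# omitted from the request, matching the expected job description in the test.
```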
19 changes: 19 additions & 0 deletions tests/integ/test_processing.py
@@ -747,6 +747,14 @@ def _get_processing_inputs_with_all_parameters(bucket):
destination="/opt/ml/processing/input/data/",
input_name="my_dataset",
),
ProcessingInput(
input_name="s3_input_wo_defaults",
s3_input=S3Input(
s3_uri=f"s3://{bucket}",
local_path="/opt/ml/processing/input/s3_input_wo_defaults",
s3_data_type="S3Prefix",
),
),
ProcessingInput(
input_name="s3_input",
s3_input=S3Input(
@@ -822,6 +830,17 @@ def _get_processing_job_inputs_and_outputs(bucket, output_kms_key):
"S3CompressionType": "None",
},
},
{
"InputName": "s3_input_wo_defaults",
"AppManaged": False,
"S3Input": {
"S3Uri": f"s3://{bucket}",
"LocalPath": "/opt/ml/processing/input/s3_input_wo_defaults",
"S3DataType": "S3Prefix",
"S3InputMode": "File",
"S3DataDistributionType": "FullyReplicated",
},
},
{
"InputName": "s3_input",
"AppManaged": False,
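Each subclass simply forwards its keyword arguments to `super().__init__()`, which implies the `ApiObject` base class turns constructor kwargs into instance attributes. A sketch of that assumed pattern, not the SDK's actual implementation:

```python
class ApiObject:
    """Assumed base-class behavior: store every constructor kwarg on the instance."""

    def __init__(self, **kwargs):
        # Each keyword becomes an instance attribute, so subclasses get
        # per-instance state by just forwarding their arguments here.
        self.__dict__.update(kwargs)
```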