
Commit 9c60705

add datasetdefinition
1 parent cbf9b58 commit 9c60705

File tree

8 files changed: +696 -29 lines changed
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import

from sagemaker.dataset_definition.inputs import (  # noqa: F401
    DatasetDefinition,
    S3Input,
    RedshiftDatasetDefinition,
    AthenaDatasetDefinition,
)
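
Since this module simply re-exports the input classes, downstream code can import them from the package root rather than the inputs submodule — a minimal sketch, assuming the package layout shown in this commit:

# Equivalent to importing from sagemaker.dataset_definition.inputs;
# the names resolve to the same classes thanks to the re-export above.
from sagemaker.dataset_definition import DatasetDefinition, S3Input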
Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The input configs for DatasetDefinition.

DatasetDefinition supports data sources such as S3 data queried via Athena
and Redshift. It provides a mechanism for customers to generate datasets
from Athena/Redshift queries and to retrieve the data via Processing jobs,
making it available to downstream processes.
"""
from __future__ import absolute_import

from typing import Dict, Any

import attr

_service_attribute_name = "service_attribute_name"

@attr.s
class _BaseConfig:
    """Base config object for DatasetDefinition.

    This class implements common to_dict() and from_dict() methods to
    serialize/deserialize the class when constructing a service request.
    """

    def to_dict(self) -> Dict[str, Any]:
        """Construct the dictionary from the class.

        Returns:
            a dict representing the class.
        """

        dictionary = {}

        for attribute in self.__class__.get_attributes():
            if self.__dict__[attribute.name] is not None:
                attribute_value = self.__dict__[attribute.name]
                if isinstance(self.__dict__[attribute.name], _BaseConfig):
                    dictionary[
                        attribute.metadata[_service_attribute_name]
                    ] = attribute_value.to_dict()
                else:
                    dictionary[attribute.metadata[_service_attribute_name]] = attribute_value

        return dictionary

    @classmethod
    def from_dict(cls, service_response) -> "_BaseConfig":
        """Construct the _BaseConfig object from the dictionary.

        Args:
            service_response: the JSON response returned from the service.

        Returns:
            a _BaseConfig object.
        """

        if service_response is None:
            return None

        dd_dict = {}

        for attribute in attr.fields(cls):
            service_attr_name = attribute.metadata[_service_attribute_name]
            if service_attr_name in service_response:
                if isinstance(service_response[service_attr_name], dict):
                    dd_dict[attribute.name] = attribute.type.from_dict(
                        service_response[service_attr_name]
                    )
                else:
                    dd_dict[attribute.name] = service_response[
                        attribute.metadata[_service_attribute_name]
                    ]

        return cls(**dd_dict)

    @classmethod
    def get_attributes(cls):
        """Get all class attributes.

        Returns:
            a tuple of the class's attr attributes.
        """
        return attr.fields(cls)

@attr.s
class RedshiftDatasetDefinition(_BaseConfig):
    """DatasetDefinition for Redshift.

    With this input, SQL queries will be executed using Redshift to generate datasets to S3.

    Attributes:
        cluster_id (str): The Redshift cluster identifier.
        database (str): The Redshift database created for your cluster.
        db_user (str): The user name of a user account that has permission to connect
            to the database.
        query_string (str): The SQL query statements to be executed.
        cluster_role_arn (str): The Redshift cluster role ARN.
        output_s3_uri (str): The path to a specific S3 object or an S3 prefix for output.
        kms_key_id (str): The KMS key id.
        output_format (str): The data storage format for Redshift query results.
            Valid options are "PARQUET" and "CSV".
        output_compression (str): The compression used for Redshift query results.
            Valid options are "None", "GZIP", "SNAPPY", "ZSTD", and "BZIP2".
    """

    cluster_id: str = attr.ib(
        validator=attr.validators.instance_of(str), metadata={_service_attribute_name: "ClusterId"}
    )
    database: str = attr.ib(
        validator=attr.validators.instance_of(str), metadata={_service_attribute_name: "Database"}
    )
    db_user: str = attr.ib(
        validator=attr.validators.instance_of(str), metadata={_service_attribute_name: "DbUser"}
    )
    query_string: str = attr.ib(
        validator=attr.validators.instance_of(str),
        metadata={_service_attribute_name: "QueryString"},
    )
    cluster_role_arn: str = attr.ib(
        validator=attr.validators.instance_of(str),
        metadata={_service_attribute_name: "ClusterRoleArn"},
    )
    output_s3_uri: str = attr.ib(
        validator=attr.validators.instance_of(str),
        metadata={_service_attribute_name: "OutputS3Uri"},
    )

    kms_key_id: str = attr.ib(default=None, metadata={_service_attribute_name: "KmsKeyId"})
    output_format: str = attr.ib(
        default="PARQUET", metadata={_service_attribute_name: "OutputFormat"}
    )
    output_compression: str = attr.ib(
        default="GZIP", metadata={_service_attribute_name: "OutputCompression"}
    )

@attr.s
class AthenaDatasetDefinition(_BaseConfig):
    """DatasetDefinition for Athena.

    With this input, SQL queries will be executed using Athena to generate datasets to S3.

    Attributes:
        catalog (str): The name of the data catalog used in the query execution.
        database (str): The name of the database used in the query execution.
        query_string (str): The SQL query statements to be executed.
        output_s3_uri (str): The path to a specific S3 object or an S3 prefix for output.
        work_group (str): The name of the workgroup in which the query is being started.
        kms_key_id (str): The KMS key id.
        output_format (str): The data storage format for Athena query results.
            Valid options are "PARQUET", "ORC", "AVRO", "JSON", and "TEXTFILE".
        output_compression (str): The compression used for Athena query results.
            Valid options are "GZIP", "SNAPPY", and "ZLIB".
    """

    catalog: str = attr.ib(
        validator=attr.validators.instance_of(str), metadata={_service_attribute_name: "Catalog"}
    )
    database: str = attr.ib(
        validator=attr.validators.instance_of(str), metadata={_service_attribute_name: "Database"}
    )
    query_string: str = attr.ib(
        validator=attr.validators.instance_of(str),
        metadata={_service_attribute_name: "QueryString"},
    )
    output_s3_uri: str = attr.ib(
        validator=attr.validators.instance_of(str),
        metadata={_service_attribute_name: "OutputS3Uri"},
    )
    work_group: str = attr.ib(default=None, metadata={_service_attribute_name: "WorkGroup"})
    kms_key_id: str = attr.ib(default=None, metadata={_service_attribute_name: "KmsKeyId"})
    output_format: str = attr.ib(
        default="PARQUET", metadata={_service_attribute_name: "OutputFormat"}
    )
    output_compression: str = attr.ib(
        default="GZIP", metadata={_service_attribute_name: "OutputCompression"}
    )

@attr.s
class DatasetDefinition(_BaseConfig):
    """DatasetDefinition input.

    Attributes:
        data_distribution_type (str): Valid options are "FullyReplicated" or "ShardedByS3Key".
        input_mode (str): Valid options are "Pipe" or "File".
        local_path (str): The path to a local directory. If not provided, the SageMaker
            platform skips downloading the data.
        redshift_dataset_definition
            (:class:`~sagemaker.dataset_definition.RedshiftDatasetDefinition`): Redshift
            dataset definition.
        athena_dataset_definition (:class:`~sagemaker.dataset_definition.AthenaDatasetDefinition`):
            Athena dataset definition.
    """

    data_distribution_type: str = attr.ib(
        default="ShardedByS3Key", metadata={_service_attribute_name: "DataDistributionType"}
    )
    input_mode: str = attr.ib(default="Pipe", metadata={_service_attribute_name: "InputMode"})
    local_path: str = attr.ib(default=None, metadata={_service_attribute_name: "LocalPath"})
    redshift_dataset_definition: RedshiftDatasetDefinition = attr.ib(
        default=None, metadata={_service_attribute_name: "RedshiftDatasetDefinition"}
    )
    athena_dataset_definition: AthenaDatasetDefinition = attr.ib(
        default=None, metadata={_service_attribute_name: "AthenaDatasetDefinition"}
    )

@attr.s
class S3Input(_BaseConfig):
    """Metadata of data objects stored in S3.

    Two options are provided: specifying an S3 prefix, or explicitly listing the files
    in a manifest file and referencing the manifest file's S3 path.
    Note: Strong consistency is not guaranteed if S3Prefix is provided here,
    because S3 list operations are not strongly consistent.
    Use ManifestFile if strong consistency is required.

    Attributes:
        s3_uri (str): The path to a specific S3 object or an S3 prefix.
        local_path (str): The path to a local directory. If not provided, the SageMaker
            platform skips downloading the data.
        s3_data_type (str): Valid options are "ManifestFile" or "S3Prefix".
        s3_input_mode (str): Valid options are "Pipe" or "File".
        s3_data_distribution_type (str): Valid options are "FullyReplicated" or "ShardedByS3Key".
        s3_compression_type (str): Valid options are "None" or "Gzip".
    """

    s3_uri: str = attr.ib(
        validator=attr.validators.instance_of(str), metadata={_service_attribute_name: "S3Uri"}
    )
    local_path: str = attr.ib(
        validator=attr.validators.instance_of(str), metadata={_service_attribute_name: "LocalPath"}
    )
    s3_data_type: str = attr.ib(
        default="S3Prefix", metadata={_service_attribute_name: "S3DataType"}
    )
    s3_input_mode: str = attr.ib(default="File", metadata={_service_attribute_name: "S3InputMode"})
    s3_data_distribution_type: str = attr.ib(
        default="FullyReplicated", metadata={_service_attribute_name: "S3DataDistributionType"}
    )
    s3_compression_type: str = attr.ib(
        default=None, metadata={_service_attribute_name: "S3CompressionType"}
    )
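
A minimal usage sketch of the classes added in this commit, assuming only the code shown here — the catalog, database, query, S3 URI, and local path below are hypothetical placeholders:

from sagemaker.dataset_definition.inputs import (
    AthenaDatasetDefinition,
    DatasetDefinition,
)

# Describe an Athena query whose results will form the dataset.
athena = AthenaDatasetDefinition(
    catalog="AwsDataCatalog",
    database="my_database",
    query_string="SELECT * FROM my_table",
    output_s3_uri="s3://my-bucket/athena-results/",
)

# Wrap the query in a DatasetDefinition, e.g. for use as a Processing input.
dataset_def = DatasetDefinition(
    local_path="/opt/ml/processing/input/data",
    athena_dataset_definition=athena,
)

# to_dict() recursively builds the service request payload, mapping each
# attribute to its service name (e.g. output_s3_uri -> "OutputS3Uri").
request = dataset_def.to_dict()

# from_dict() reverses the mapping, rebuilding nested config objects, so a
# round trip reproduces an equal object (attrs generates __eq__).
assert DatasetDefinition.from_dict(request) == dataset_def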
