Skip to content

Commit 01254b2

Browse files
committed
add DatasetDefinition
1 parent cbf9b58 commit 01254b2

File tree

9 files changed

+651
-160
lines changed

9 files changed

+651
-160
lines changed

src/sagemaker/apiutils/_base_types.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,21 @@ def _boto_ignore(cls):
4242
return ["ResponseMetadata"]
4343

4444
@classmethod
def from_boto(cls, boto_dict, **kwargs):
    """Construct an instance of this ApiObject from a boto response.

    Args:
        boto_dict (dict): A dictionary of a boto response. May be None or
            empty (e.g. an absent optional sub-structure), in which case
            no instance is constructed.
        **kwargs: Arbitrary keyword arguments; these override or extend
            the values parsed from ``boto_dict``.

    Returns:
        An instance of ``cls`` populated from ``boto_dict`` and ``kwargs``,
        or None when ``boto_dict`` is None/empty.
    """
    # Guard clause with an explicit None return — replaces the implicit
    # fall-through return that required `# pylint: disable=R1710`.
    if not boto_dict:
        return None
    # Drop keys this API object deliberately ignores (e.g. ResponseMetadata).
    boto_dict = {k: v for k, v in boto_dict.items() if k not in cls._boto_ignore()}
    # Invert the member-name -> boto-name map so boto keys resolve to members.
    custom_boto_names_to_member_names = {a: b for b, a in cls._custom_boto_names.items()}
    cls_kwargs = _boto_functions.from_boto(
        boto_dict, custom_boto_names_to_member_names, cls._custom_boto_types
    )
    # Explicit keyword arguments take precedence over parsed boto values.
    cls_kwargs.update(kwargs)
    return cls(**cls_kwargs)
5960

6061
@classmethod
6162
def to_boto(cls, obj):

src/sagemaker/apiutils/_boto_functions.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,7 @@ def from_boto(boto_dict, boto_name_to_member_name, member_name_to_type):
6868
api_type, is_collection = member_name_to_type[member_name]
6969
if is_collection:
7070
if isinstance(boto_value, dict):
71-
member_value = {
72-
key: api_type.from_boto(value) for key, value in boto_value.items()
73-
}
71+
member_value = api_type.from_boto(boto_value)
7472
else:
7573
member_value = [api_type.from_boto(item) for item in boto_value]
7674
else:
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes for using DatasetDefinition in Processing job with Amazon SageMaker."""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.dataset_definition.inputs import ( # noqa: F401
17+
DatasetDefinition,
18+
S3Input,
19+
RedshiftDatasetDefinition,
20+
AthenaDatasetDefinition,
21+
)
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""The input configs for DatasetDefinition.
14+
15+
DatasetDefinition supports the data sources like S3 which can be queried via Athena
16+
and Redshift. A mechanism has to be created for customers to generate datasets
17+
from Athena/Redshift queries and to retrieve the data, using Processing jobs
18+
so as to make it available for other downstream processes.
19+
"""
20+
from __future__ import absolute_import
21+
22+
from sagemaker.apiutils._base_types import ApiObject
23+
24+
25+
class RedshiftDatasetDefinition(ApiObject):
    """DatasetDefinition for Redshift.

    With this input, SQL queries will be executed using Redshift to generate
    datasets to S3.

    Attributes:
        cluster_id (str): The Redshift cluster identifier.
        database (str): The Redshift database created for your cluster.
        db_user (str): The user name of a user account that has permission to connect
            to the database.
        query_string (str): The SQL query statements to be executed.
        cluster_role_arn (str): The Redshift cluster role ARN.
        output_s3_uri (str): The path to a specific S3 object or an S3 prefix for output.
        kms_key_id (str): The KMS key id.
        output_format (str): The data storage format for Redshift query results.
            Valid options are "PARQUET", "CSV". Defaults to "PARQUET".
        output_compression (str): The compression used for Redshift query results.
            Valid options are "None", "GZIP", "SNAPPY", "ZSTD", "BZIP2".
            Defaults to "GZIP".
    """

    # NOTE(review): these snake_case class attributes presumably mirror the
    # CamelCase fields of the boto RedshiftDatasetDefinition structure, with
    # conversion handled by ApiObject/_boto_functions — confirm there.
    cluster_id = None
    database = None
    db_user = None
    query_string = None
    cluster_role_arn = None
    output_s3_uri = None
    kms_key_id = None
    # Defaults chosen by the SDK; callers may override via keyword arguments.
    output_format = "PARQUET"
    output_compression = "GZIP"
54+
55+
56+
class AthenaDatasetDefinition(ApiObject):
    """DatasetDefinition for Athena.

    With this input, SQL queries will be executed using Athena to generate
    datasets to S3.

    Attributes:
        catalog (str): The name of the data catalog used in query execution.
        database (str): The name of the database used in the query execution.
        query_string (str): The SQL query statements to be executed.
        output_s3_uri (str): The path to a specific S3 object or an S3 prefix for output.
        work_group (str): The name of the workgroup in which the query is being started.
        kms_key_id (str): The KMS key id.
        output_format (str): The data storage format for Athena query results.
            Valid options are "PARQUET", "ORC", "AVRO", "JSON", "TEXTFILE".
            Defaults to "PARQUET".
        output_compression (str): The compression used for Athena query results.
            Valid options are "GZIP", "SNAPPY", "ZLIB". Defaults to "GZIP".
    """

    # NOTE(review): these snake_case class attributes presumably mirror the
    # CamelCase fields of the boto AthenaDatasetDefinition structure, with
    # conversion handled by ApiObject/_boto_functions — confirm there.
    catalog = None
    database = None
    query_string = None
    output_s3_uri = None
    work_group = None
    kms_key_id = None
    # Defaults chosen by the SDK; callers may override via keyword arguments.
    output_format = "PARQUET"
    output_compression = "GZIP"
82+
83+
84+
class DatasetDefinition(ApiObject):
    """DatasetDefinition input.

    Wraps either a Redshift or an Athena dataset definition for use as a
    Processing job input.

    Attributes:
        data_distribution_type (str): Valid options are "FullyReplicated" or
            "ShardedByS3Key". Defaults to "ShardedByS3Key".
        input_mode (str): Valid options are "Pipe" or "File". Defaults to "Pipe".
        local_path (str): The path to a local directory. If not provided, skips data
            download by SageMaker platform.
        redshift_dataset_definition
            (:class:`~sagemaker.dataset_definition.RedshiftDatasetDefinition`): Redshift
            dataset definition.
        athena_dataset_definition (:class:`~sagemaker.dataset_definition.AthenaDatasetDefinition`):
            Athena dataset definition.
    """

    # Routes these members through custom-type deserialization: when the boto
    # value is a dict, _boto_functions.from_boto builds a single instance of
    # the given ApiObject subclass from it. NOTE(review): the True flag is the
    # "is_collection" slot of the (api_type, is_collection) pair unpacked in
    # _boto_functions.from_boto — confirm that single-object dicts are intended
    # to take this path.
    _custom_boto_types = {
        "redshift_dataset_definition": (RedshiftDatasetDefinition, True),
        "athena_dataset_definition": (AthenaDatasetDefinition, True),
    }

    # Defaults chosen by the SDK; callers may override via keyword arguments.
    data_distribution_type = "ShardedByS3Key"
    input_mode = "Pipe"
    local_path = None
    # Exactly one of the two nested definitions is expected to be set —
    # presumably mutually exclusive; verify against the Processing job API.
    redshift_dataset_definition = None
    athena_dataset_definition = None
109+
110+
111+
class S3Input(ApiObject):
    """Metadata of data objects stored in S3.

    Two options are provided: specifying an S3 prefix or explicitly listing the files
    in a manifest file and referencing the manifest file's S3 path.

    Note: Strong consistency is not guaranteed if S3Prefix is provided here.
    S3 list operations are not strongly consistent.
    Use ManifestFile if strong consistency is required.

    Attributes:
        s3_uri (str): The path to a specific S3 object or an S3 prefix.
        local_path (str): The path to a local directory. If not provided, skips data
            download by SageMaker platform.
        s3_data_type (str): Valid options are "ManifestFile" or "S3Prefix".
            Defaults to "S3Prefix".
        s3_input_mode (str): Valid options are "Pipe" or "File". Defaults to "File".
        s3_data_distribution_type (str): Valid options are "FullyReplicated" or
            "ShardedByS3Key". Defaults to "FullyReplicated".
        s3_compression_type (str): Valid options are "None" or "Gzip".
    """

    # NOTE(review): these snake_case class attributes presumably mirror the
    # CamelCase fields of the boto S3Input structure, with conversion handled
    # by ApiObject/_boto_functions — confirm there.
    s3_uri = None
    local_path = None
    # Defaults chosen by the SDK; callers may override via keyword arguments.
    s3_data_type = "S3Prefix"
    s3_input_mode = "File"
    s3_data_distribution_type = "FullyReplicated"
    s3_compression_type = None

0 commit comments

Comments
 (0)