Skip to content

Commit 2ba54d2

Browse files
author
Basil Beirouti
committed
adding workgroup functionality to athena query
1 parent 284ddbe commit 2ba54d2

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

src/sagemaker/feature_store/feature_group.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ class AthenaQuery:
8181
_result_bucket: str = attr.ib(init=False, default=None)
8282
_result_file_prefix: str = attr.ib(init=False, default=None)
8383

84-
def run(self, query_string: str, output_location: str, kms_key: str = None) -> str:
84+
def run(
85+
self, query_string: str, output_location: str, kms_key: str = None, workgroup: str = None
86+
) -> str:
8587
"""Execute a SQL query given a query string, output location and kms key.
8688
8789
This method executes the SQL query using Athena and outputs the results to output_location
@@ -91,6 +93,7 @@ def run(self, query_string: str, output_location: str, kms_key: str = None) -> s
9193
query_string: SQL query string.
9294
output_location: S3 URI of the query result.
9395
kms_key: KMS key id. If set, will be used to encrypt the query result file.
96+
workgroup (str): The name of the workgroup in which the query is being started.
9497
9598
Returns:
9699
Execution id of the query.
@@ -101,6 +104,7 @@ def run(self, query_string: str, output_location: str, kms_key: str = None) -> s
101104
query_string=query_string,
102105
output_location=output_location,
103106
kms_key=kms_key,
107+
workgroup=workgroup,
104108
)
105109
self._current_query_execution_id = response["QueryExecutionId"]
106110
parse_result = urlparse(output_location, allow_fragments=False)

src/sagemaker/session.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4185,6 +4185,7 @@ def start_query_execution(
41854185
query_string: str,
41864186
output_location: str,
41874187
kms_key: str = None,
4188+
workgroup: str = None,
41884189
) -> Dict[str, str]:
41894190
"""Start Athena query execution.
41904191
@@ -4194,6 +4195,8 @@ def start_query_execution(
41944195
query_string (str): SQL expression.
41954196
output_location (str): S3 location of the output file.
41964197
kms_key (str): KMS key id will be used to encrypt the result if given.
4198+
workgroup (str): The name of the workgroup in which the query is being started. If workgroup is not \
4199+
specified, the default workgroup is used.
41974200
41984201
Returns:
41994202
Response dict from the service.
@@ -4208,6 +4211,9 @@ def start_query_execution(
42084211
)
42094212
kwargs.update(ResultConfiguration=result_config)
42104213

4214+
if workgroup:
4215+
kwargs.update(WorkGroup=workgroup)
4216+
42114217
athena_client = self.boto_session.client("athena", region_name=self.boto_region_name)
42124218
return athena_client.start_query_execution(**kwargs)
42134219

tests/unit/sagemaker/feature_store/test_feature_store.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,14 +468,18 @@ def query(sagemaker_session_mock):
468468

469469

470470
def test_athena_query_run(sagemaker_session_mock, query):
471+
WORKGROUP = "workgroup"
471472
sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query_id"}
472-
query.run(query_string="query", output_location="s3://some-bucket/some-path")
473+
query.run(
474+
query_string="query", output_location="s3://some-bucket/some-path", workgroup=WORKGROUP
475+
)
473476
sagemaker_session_mock.start_query_execution.assert_called_with(
474477
catalog="catalog",
475478
database="database",
476479
query_string="query",
477480
output_location="s3://some-bucket/some-path",
478481
kms_key=None,
482+
workgroup=WORKGROUP,
479483
)
480484
assert "some-bucket" == query._result_bucket
481485
assert "some-path" == query._result_file_prefix

0 commit comments

Comments
 (0)