Skip to content

Commit 5f8e532

Browse files
BasilBeiroutiBasil Beirouti
andauthored
feature: adding workgroup functionality to athena query (#3276)
Co-authored-by: Basil Beirouti <[email protected]>
1 parent 3880da9 commit 5f8e532

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

src/sagemaker/feature_store/feature_group.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ class AthenaQuery:
8181
_result_bucket: str = attr.ib(init=False, default=None)
8282
_result_file_prefix: str = attr.ib(init=False, default=None)
8383

84-
def run(self, query_string: str, output_location: str, kms_key: str = None) -> str:
84+
def run(
85+
self, query_string: str, output_location: str, kms_key: str = None, workgroup: str = None
86+
) -> str:
8587
"""Execute a SQL query given a query string, output location and kms key.
8688
8789
This method executes the SQL query using Athena and outputs the results to output_location
@@ -91,6 +93,7 @@ def run(self, query_string: str, output_location: str, kms_key: str = None) -> s
9193
query_string: SQL query string.
9294
output_location: S3 URI of the query result.
9395
kms_key: KMS key id. If set, will be used to encrypt the query result file.
96+
workgroup (str): The name of the workgroup in which the query is being started.
9497
9598
Returns:
9699
Execution id of the query.
@@ -101,6 +104,7 @@ def run(self, query_string: str, output_location: str, kms_key: str = None) -> s
101104
query_string=query_string,
102105
output_location=output_location,
103106
kms_key=kms_key,
107+
workgroup=workgroup,
104108
)
105109
self._current_query_execution_id = response["QueryExecutionId"]
106110
parse_result = urlparse(output_location, allow_fragments=False)

src/sagemaker/session.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4187,6 +4187,7 @@ def start_query_execution(
41874187
query_string: str,
41884188
output_location: str,
41894189
kms_key: str = None,
4190+
workgroup: str = None,
41904191
) -> Dict[str, str]:
41914192
"""Start Athena query execution.
41924193
@@ -4196,6 +4197,8 @@ def start_query_execution(
41964197
query_string (str): SQL expression.
41974198
output_location (str): S3 location of the output file.
41984199
kms_key (str): KMS key id will be used to encrypt the result if given.
4200+
workgroup (str): The name of the workgroup in which the query is being started.
4201+
If the workgroup is not specified, the default workgroup is used.
41994202
42004203
Returns:
42014204
Response dict from the service.
@@ -4210,6 +4213,9 @@ def start_query_execution(
42104213
)
42114214
kwargs.update(ResultConfiguration=result_config)
42124215

4216+
if workgroup:
4217+
kwargs.update(WorkGroup=workgroup)
4218+
42134219
athena_client = self.boto_session.client("athena", region_name=self.boto_region_name)
42144220
return athena_client.start_query_execution(**kwargs)
42154221

tests/unit/sagemaker/feature_store/test_feature_store.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,14 +468,18 @@ def query(sagemaker_session_mock):
468468

469469

470470
def test_athena_query_run(sagemaker_session_mock, query):
471+
WORKGROUP = "workgroup"
471472
sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query_id"}
472-
query.run(query_string="query", output_location="s3://some-bucket/some-path")
473+
query.run(
474+
query_string="query", output_location="s3://some-bucket/some-path", workgroup=WORKGROUP
475+
)
473476
sagemaker_session_mock.start_query_execution.assert_called_with(
474477
catalog="catalog",
475478
database="database",
476479
query_string="query",
477480
output_location="s3://some-bucket/some-path",
478481
kms_key=None,
482+
workgroup=WORKGROUP,
479483
)
480484
assert "some-bucket" == query._result_bucket
481485
assert "some-path" == query._result_file_prefix

0 commit comments

Comments
 (0)