Skip to content

Commit ca2077e

Browse files
committed
fix: black check
1 parent c788a7e commit ca2077e

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

src/sagemaker/feature_store/feature_utils.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,16 @@ def get_session_from_role(region: str, assume_role: str = None) -> Session:
4747
# It will try to assume the role specified
4848
if assume_role:
4949
sts = boto_session.client(
50-
"sts", region_name=region, endpoint_url="https://sts.eu-west-1.amazonaws.com"
50+
"sts", region_name=region, endpoint_url=f"https://sts.{region}.amazonaws.com"
5151
)
5252

53-
metadata = sts.assume_role(RoleArn=assume_role, RoleSessionName="SagemakerExecution")
53+
credentials = sts.assume_role(
54+
RoleArn=assume_role, RoleSessionName="SagemakerExecution"
55+
).get("Credentials", {})
5456

55-
access_key_id = metadata["Credentials"]["AccessKeyId"]
56-
secret_access_key = metadata["Credentials"]["SecretAccessKey"]
57-
session_token = metadata["Credentials"]["SessionToken"]
57+
access_key_id = credentials.get("AccessKeyId", None)
58+
secret_access_key = credentials.get("SecretAccessKey", None)
59+
session_token = credentials.get("SessionToken", None)
5860

5961
boto_session = boto3.session.Session(
6062
region_name=region,
@@ -63,15 +65,13 @@ def get_session_from_role(region: str, assume_role: str = None) -> Session:
6365
aws_session_token=session_token,
6466
)
6567

66-
# Sessions
67-
sagemaker_client = boto_session.client("sagemaker")
68-
sagemaker_runtime = boto_session.client("sagemaker-runtime")
69-
runtime_client = boto_session.client(service_name="sagemaker-featurestore-runtime")
7068
sagemaker_session = Session(
7169
boto_session=boto_session,
72-
sagemaker_client=sagemaker_client,
73-
sagemaker_runtime_client=sagemaker_runtime,
74-
sagemaker_featurestore_runtime_client=runtime_client,
70+
sagemaker_client=boto_session.client("sagemaker"),
71+
sagemaker_runtime_client=boto_session.client("sagemaker-runtime"),
72+
sagemaker_featurestore_runtime_client=boto_session.client(
73+
service_name="sagemaker-featurestore-runtime"
74+
),
7575
)
7676

7777
return sagemaker_session
@@ -81,7 +81,7 @@ def get_feature_group_as_dataframe(
8181
feature_group_name: str,
8282
athena_bucket: str,
8383
query: str = """SELECT * FROM "sagemaker_featurestore"."#{table}"
84-
WHERE is_deleted=False """,
84+
WHERE is_deleted=False """,
8585
role: str = None,
8686
region: str = None,
8787
session=None,

0 commit comments

Comments
 (0)