@@ -47,14 +47,16 @@ def get_session_from_role(region: str, assume_role: str = None) -> Session:
47
47
# It will try to assume the role specified
48
48
if assume_role :
49
49
sts = boto_session .client (
50
- "sts" , region_name = region , endpoint_url = "https://sts.eu-west-1 .amazonaws.com"
50
+ "sts" , region_name = region , endpoint_url = f "https://sts.{ region } .amazonaws.com"
51
51
)
52
52
53
- metadata = sts .assume_role (RoleArn = assume_role , RoleSessionName = "SagemakerExecution" )
53
+ credentials = sts .assume_role (
54
+ RoleArn = assume_role , RoleSessionName = "SagemakerExecution"
55
+ ).get ("Credentials" , {})
54
56
55
- access_key_id = metadata [ "Credentials" ][ " AccessKeyId"]
56
- secret_access_key = metadata [ "Credentials" ][ " SecretAccessKey"]
57
- session_token = metadata [ "Credentials" ][ " SessionToken"]
57
+ access_key_id = credentials . get ( " AccessKeyId", None )
58
+ secret_access_key = credentials . get ( " SecretAccessKey", None )
59
+ session_token = credentials . get ( " SessionToken", None )
58
60
59
61
boto_session = boto3 .session .Session (
60
62
region_name = region ,
@@ -63,15 +65,13 @@ def get_session_from_role(region: str, assume_role: str = None) -> Session:
63
65
aws_session_token = session_token ,
64
66
)
65
67
66
- # Sessions
67
- sagemaker_client = boto_session .client ("sagemaker" )
68
- sagemaker_runtime = boto_session .client ("sagemaker-runtime" )
69
- runtime_client = boto_session .client (service_name = "sagemaker-featurestore-runtime" )
70
68
sagemaker_session = Session (
71
69
boto_session = boto_session ,
72
- sagemaker_client = sagemaker_client ,
73
- sagemaker_runtime_client = sagemaker_runtime ,
74
- sagemaker_featurestore_runtime_client = runtime_client ,
70
+ sagemaker_client = boto_session .client ("sagemaker" ),
71
+ sagemaker_runtime_client = boto_session .client ("sagemaker-runtime" ),
72
+ sagemaker_featurestore_runtime_client = boto_session .client (
73
+ service_name = "sagemaker-featurestore-runtime"
74
+ ),
75
75
)
76
76
77
77
return sagemaker_session
@@ -81,7 +81,7 @@ def get_feature_group_as_dataframe(
81
81
feature_group_name : str ,
82
82
athena_bucket : str ,
83
83
query : str = """SELECT * FROM "sagemaker_featurestore"."#{table}"
84
- WHERE is_deleted=False """ ,
84
+ WHERE is_deleted=False """ ,
85
85
role : str = None ,
86
86
region : str = None ,
87
87
session = None ,
0 commit comments