Skip to content

Commit 5987cc8

Browse files
kusumbhattKusumlata Bhatt
andauthored
fix: Extracted profile_name directly from sagemaker.Session if None (#3660)
* fix: Extracted profile_name directly from sagemaker.Session if None * Extracted profile_name only if profile_name used in sagemaker_session is not default --------- Co-authored-by: Kusumlata Bhatt <[email protected]>
1 parent a776dc6 commit 5987cc8

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

src/sagemaker/feature_store/feature_group.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,9 @@ def ingest(
805805
if max_workers <= 0:
806806
raise RuntimeError("max_workers must be greater than 0.")
807807

808+
if profile_name is None and self.sagemaker_session.boto_session.profile_name != "default":
809+
profile_name = self.sagemaker_session.boto_session.profile_name
810+
808811
manager = IngestionManagerPandas(
809812
feature_group_name=self.name,
810813
sagemaker_session=self.sagemaker_session,

tests/unit/sagemaker/feature_store/test_feature_group.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_clien
311311
sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
312312
max_workers=10,
313313
max_processes=1,
314-
profile_name=None,
314+
profile_name=sagemaker_session_mock.boto_session.profile_name,
315315
)
316316
mock_ingestion_manager_instance.run.assert_called_once_with(
317317
data_frame=df, wait=True, timeout=None
@@ -323,6 +323,7 @@ def test_ingest_default(ingestion_manager_init, sagemaker_session_mock):
323323
sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = (
324324
fs_runtime_client_config_mock
325325
)
326+
sagemaker_session_mock.boto_session.profile_name = "default"
326327

327328
feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
328329
df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300)))

0 commit comments

Comments
 (0)