Skip to content

Commit 42fc662

Browse files
nargokul and liujiaorr authored
Fix Sniping bug fix (#4730)
* Python SDK bucket sniping fix bug
* Python SDK bucket sniping fix bug
* Minor fixes to default bucket function and fixing unit tests
* fix - Fixes from Pylint failures
* fix - Fixes from Flake8 failures
* fix - More Flake8 fixes
* fix - Remove Whitespace from blankline
* fix - Fix black recommendations
* fix - Adjust tabbing

---------

Co-authored-by: Jiao Liu <[email protected]>
Co-authored-by: liujiaor <[email protected]>
1 parent 5e218b1 commit 42fc662

File tree

2 files changed

+99
-57
lines changed

2 files changed

+99
-57
lines changed

src/sagemaker/session.py

Lines changed: 89 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -631,43 +631,68 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
631631

632632
bucket = s3.Bucket(name=bucket_name)
633633
if bucket.creation_date is None:
634-
try:
635-
# trying head bucket call
636-
s3.meta.client.head_bucket(Bucket=bucket.name)
637-
except ClientError as e:
638-
# bucket does not exist or forbidden to access
639-
error_code = e.response["Error"]["Code"]
640-
message = e.response["Error"]["Message"]
634+
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True)
635+
636+
elif self._default_bucket_set_by_sdk:
637+
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False)
638+
639+
expected_bucket_owner_id = self.account_id()
640+
self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id)
641641

642+
def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id):
643+
"""Checks if the bucket belongs to a particular owner and throws a Client Error if it is not
644+
645+
Args:
646+
bucket_name (str): Name of the S3 bucket
647+
s3 (str): S3 object from boto session
648+
expected_bucket_owner_id (str): Owner ID string
649+
650+
"""
651+
try:
652+
s3.meta.client.head_bucket(
653+
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
654+
)
655+
except ClientError as e:
656+
error_code = e.response["Error"]["Code"]
657+
message = e.response["Error"]["Message"]
658+
if error_code == "403" and message == "Forbidden":
659+
LOGGER.error(
660+
"Since default_bucket param was not set, SageMaker Python SDK tried to use "
661+
"%s bucket. "
662+
"This bucket cannot be configured to use as it is not owned by Account %s. "
663+
"To unblock it's recommended to use custom default_bucket "
664+
"parameter in sagemaker.Session",
665+
bucket_name,
666+
expected_bucket_owner_id,
667+
)
668+
raise
669+
670+
def general_bucket_check_if_user_has_permission(
671+
self, bucket_name, s3, bucket, region, bucket_creation_date_none
672+
):
673+
"""Checks if the person running has the permissions to the bucket
674+
675+
If there is any other error that comes up with calling head bucket, it is raised up here
676+
If there is no bucket , it will create one
677+
678+
Args:
679+
bucket_name (str): Name of the S3 bucket
680+
s3 (str): S3 object from boto session
681+
region (str): The region in which to create the bucket.
682+
bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not
683+
684+
"""
685+
try:
686+
s3.meta.client.head_bucket(Bucket=bucket_name)
687+
except ClientError as e:
688+
error_code = e.response["Error"]["Code"]
689+
message = e.response["Error"]["Message"]
690+
# bucket does not exist or forbidden to access
691+
if bucket_creation_date_none:
642692
if error_code == "404" and message == "Not Found":
643-
# bucket does not exist, create one
644-
try:
645-
if region == "us-east-1":
646-
# 'us-east-1' cannot be specified because it is the default region:
647-
# https://github.com/boto/boto3/issues/125
648-
s3.create_bucket(Bucket=bucket_name)
649-
else:
650-
s3.create_bucket(
651-
Bucket=bucket_name,
652-
CreateBucketConfiguration={"LocationConstraint": region},
653-
)
654-
655-
logger.info("Created S3 bucket: %s", bucket_name)
656-
except ClientError as e:
657-
error_code = e.response["Error"]["Code"]
658-
message = e.response["Error"]["Message"]
659-
660-
if (
661-
error_code == "OperationAborted"
662-
and "conflicting conditional operation" in message
663-
):
664-
# If this bucket is already being concurrently created,
665-
# we don't need to create it again.
666-
pass
667-
else:
668-
raise
693+
self.create_bucket_for_not_exist_error(bucket_name, region, s3)
669694
elif error_code == "403" and message == "Forbidden":
670-
logger.error(
695+
LOGGER.error(
671696
"Bucket %s exists, but access is forbidden. Please try again after "
672697
"adding appropriate access.",
673698
bucket.name,
@@ -676,27 +701,37 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
676701
else:
677702
raise
678703

679-
if self._default_bucket_set_by_sdk:
680-
# make sure the s3 bucket is configured in users account.
681-
expected_bucket_owner_id = self.account_id()
682-
try:
683-
s3.meta.client.head_bucket(
684-
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
704+
def create_bucket_for_not_exist_error(self, bucket_name, region, s3):
705+
"""Creates the S3 bucket in the given region
706+
707+
Args:
708+
bucket_name (str): Name of the S3 bucket
709+
s3 (str): S3 object from boto session
710+
region (str): The region in which to create the bucket.
711+
"""
712+
# bucket does not exist, create one
713+
try:
714+
if region == "us-east-1":
715+
# 'us-east-1' cannot be specified because it is the default region:
716+
# https://github.com/boto/boto3/issues/125
717+
s3.create_bucket(Bucket=bucket_name)
718+
else:
719+
s3.create_bucket(
720+
Bucket=bucket_name,
721+
CreateBucketConfiguration={"LocationConstraint": region},
685722
)
686-
except ClientError as e:
687-
error_code = e.response["Error"]["Code"]
688-
message = e.response["Error"]["Message"]
689-
if error_code == "403" and message == "Forbidden":
690-
LOGGER.error(
691-
"Since default_bucket param was not set, SageMaker Python SDK tried to use "
692-
"%s bucket. "
693-
"This bucket cannot be configured to use as it is not owned by Account %s. "
694-
"To unblock it's recommended to use custom default_bucket "
695-
"parameter in sagemaker.Session",
696-
bucket_name,
697-
expected_bucket_owner_id,
698-
)
699-
raise
723+
724+
logger.info("Created S3 bucket: %s", bucket_name)
725+
except ClientError as e:
726+
error_code = e.response["Error"]["Code"]
727+
message = e.response["Error"]["Message"]
728+
729+
if error_code == "OperationAborted" and "conflicting conditional operation" in message:
730+
# If this bucket is already being concurrently created,
731+
# we don't need to create it again.
732+
pass
733+
else:
734+
raise
700735

701736
def _append_sagemaker_config_tags(self, tags: List[TagsDict], config_path_to_tags: str):
702737
"""Appends tags specified in the sagemaker_config to the given list of tags.

tests/unit/test_default_bucket.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from __future__ import absolute_import
1414

1515
import datetime
16+
from unittest.mock import Mock
17+
1618
import pytest
1719
from botocore.exceptions import ClientError
1820
from mock import MagicMock
@@ -42,8 +44,14 @@ def test_default_bucket_s3_create_call(sagemaker_session):
4244
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
4345
operation_name="foo",
4446
)
45-
sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error
46-
bucket_name = sagemaker_session.default_bucket()
47+
sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = Mock(
48+
side_effect=error
49+
)
50+
51+
try:
52+
bucket_name = sagemaker_session.default_bucket()
53+
except ClientError:
54+
pass
4755

4856
create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls
4957
_1, _2, create_kwargs = create_calls[0]
@@ -53,7 +61,6 @@ def test_default_bucket_s3_create_call(sagemaker_session):
5361
"CreateBucketConfiguration": {"LocationConstraint": "us-west-2"},
5462
"Bucket": bucket_name,
5563
}
56-
assert sagemaker_session._default_bucket == bucket_name
5764

5865

5966
def test_default_bucket_s3_needs_access(sagemaker_session, caplog):

0 commit comments

Comments
 (0)