Skip to content

Fix Sniping bug fix #4730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 13, 2024
143 changes: 89 additions & 54 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,43 +631,68 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):

bucket = s3.Bucket(name=bucket_name)
if bucket.creation_date is None:
try:
# trying head bucket call
s3.meta.client.head_bucket(Bucket=bucket.name)
except ClientError as e:
# bucket does not exist or forbidden to access
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True)

elif self._default_bucket_set_by_sdk:
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False)

expected_bucket_owner_id = self.account_id()
self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id)

def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id):
    """Verify that the bucket is owned by the expected account.

    Issues a ``head_bucket`` call with ``ExpectedBucketOwner`` set. Any
    ``ClientError`` from that call is re-raised; when the error is a
    ``403`` / ``Forbidden``, an explanatory message is logged first.

    Args:
        bucket_name (str): Name of the S3 bucket.
        s3: S3 resource from the boto session; only
            ``s3.meta.client.head_bucket`` is used here.
        expected_bucket_owner_id (str): Account ID expected to own the bucket.

    Raises:
        ClientError: The error raised by ``head_bucket``, re-raised unchanged.
    """
    try:
        s3.meta.client.head_bucket(
            Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
        )
    except ClientError as e:
        # Read the error details defensively: a response without an "Error"
        # entry must not replace the original ClientError with a KeyError.
        error = e.response.get("Error", {})
        error_code = error.get("Code")
        message = error.get("Message")
        if error_code == "403" and message == "Forbidden":
            LOGGER.error(
                "Since default_bucket param was not set, SageMaker Python SDK tried to use "
                "%s bucket. "
                "This bucket cannot be configured to use as it is not owned by Account %s. "
                "To unblock it's recommended to use custom default_bucket "
                "parameter in sagemaker.Session",
                bucket_name,
                expected_bucket_owner_id,
            )
        # Always propagate the failure; the log line above is only extra context.
        raise

def general_bucket_check_if_user_has_permission(
self, bucket_name, s3, bucket, region, bucket_creation_date_none
):
"""Checks if the person running has the permissions to the bucket

If calling head bucket fails with any other error, that error is re-raised here.
If the bucket does not exist, it will be created.

Args:
bucket_name (str): Name of the S3 bucket
s3: S3 resource object from the boto session
bucket: S3 bucket object; its name is used in the access-forbidden log message
region (str): The region in which to create the bucket.
bucket_creation_date_none (bool): Indicating whether S3 bucket already exists or not

"""
try:
s3.meta.client.head_bucket(Bucket=bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
# bucket does not exist or forbidden to access
if bucket_creation_date_none:
if error_code == "404" and message == "Not Found":
# bucket does not exist, create one
try:
if region == "us-east-1":
# 'us-east-1' cannot be specified because it is the default region:
# https://github.com/boto/boto3/issues/125
s3.create_bucket(Bucket=bucket_name)
else:
s3.create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": region},
)

logger.info("Created S3 bucket: %s", bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]

if (
error_code == "OperationAborted"
and "conflicting conditional operation" in message
):
# If this bucket is already being concurrently created,
# we don't need to create it again.
pass
else:
raise
self.create_bucket_for_not_exist_error(bucket_name, region, s3)
elif error_code == "403" and message == "Forbidden":
logger.error(
LOGGER.error(
"Bucket %s exists, but access is forbidden. Please try again after "
"adding appropriate access.",
bucket.name,
Expand All @@ -676,27 +701,37 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
else:
raise

if self._default_bucket_set_by_sdk:
# make sure the s3 bucket is configured in users account.
expected_bucket_owner_id = self.account_id()
try:
s3.meta.client.head_bucket(
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
def create_bucket_for_not_exist_error(self, bucket_name, region, s3):
"""Creates the S3 bucket in the given region

Args:
bucket_name (str): Name of the S3 bucket
s3 (str): S3 object from boto session
region (str): The region in which to create the bucket.
"""
# bucket does not exist, create one
try:
if region == "us-east-1":
# 'us-east-1' cannot be specified because it is the default region:
# https://github.com/boto/boto3/issues/125
s3.create_bucket(Bucket=bucket_name)
else:
s3.create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": region},
)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
if error_code == "403" and message == "Forbidden":
LOGGER.error(
"Since default_bucket param was not set, SageMaker Python SDK tried to use "
"%s bucket. "
"This bucket cannot be configured to use as it is not owned by Account %s. "
"To unblock it's recommended to use custom default_bucket "
"parameter in sagemaker.Session",
bucket_name,
expected_bucket_owner_id,
)
raise

logger.info("Created S3 bucket: %s", bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]

if error_code == "OperationAborted" and "conflicting conditional operation" in message:
# If this bucket is already being concurrently created,
# we don't need to create it again.
pass
else:
raise

def _append_sagemaker_config_tags(self, tags: List[TagsDict], config_path_to_tags: str):
"""Appends tags specified in the sagemaker_config to the given list of tags.
Expand Down
13 changes: 10 additions & 3 deletions tests/unit/test_default_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from __future__ import absolute_import

import datetime
from unittest.mock import Mock

import pytest
from botocore.exceptions import ClientError
from mock import MagicMock
Expand Down Expand Up @@ -42,8 +44,14 @@ def test_default_bucket_s3_create_call(sagemaker_session):
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
operation_name="foo",
)
sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error
bucket_name = sagemaker_session.default_bucket()
sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = Mock(
side_effect=error
)

try:
bucket_name = sagemaker_session.default_bucket()
except ClientError:
pass

create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls
_1, _2, create_kwargs = create_calls[0]
Expand All @@ -53,7 +61,6 @@ def test_default_bucket_s3_create_call(sagemaker_session):
"CreateBucketConfiguration": {"LocationConstraint": "us-west-2"},
"Bucket": bucket_name,
}
assert sagemaker_session._default_bucket == bucket_name


def test_default_bucket_s3_needs_access(sagemaker_session, caplog):
Expand Down