Skip to content

change: Add bucket owner check #4150

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def __init__(
self._default_bucket_name_override = default_bucket
# this may also be set again inside :func:`_initialize` if it is None
self.default_bucket_prefix = default_bucket_prefix
self._default_bucket_set_by_sdk = False

self.s3_resource = None
self.s3_client = None
Expand Down Expand Up @@ -545,8 +546,12 @@ def default_bucket(self):
default_bucket = self._default_bucket_name_override
if not default_bucket:
default_bucket = generate_default_sagemaker_bucket_name(self.boto_session)
self._default_bucket_set_by_sdk = True

self._create_s3_bucket_if_it_does_not_exist(bucket_name=default_bucket, region=region)
self._create_s3_bucket_if_it_does_not_exist(
bucket_name=default_bucket,
region=region,
)

self._default_bucket = default_bucket

Expand Down Expand Up @@ -620,6 +625,28 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
else:
raise

if self._default_bucket_set_by_sdk:
# make sure the s3 bucket is configured in users account.
expected_bucket_owner_id = self.account_id()
try:
s3.meta.client.head_bucket(
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
if error_code == "403" and message == "Forbidden":
LOGGER.error(
"Since default_bucket param was not set, SageMaker Python SDK tried to use "
"%s bucket. "
"This bucket cannot be configured to use as it is not owned by Account %s. "
"To unblock it's recommended to use custom default_bucket "
"parameter in sagemaker.Session",
bucket_name,
expected_bucket_owner_id,
)
raise

def _append_sagemaker_config_tags(self, tags: list, config_path_to_tags: str):
"""Appends tags specified in the sagemaker_config to the given list of tags.

Expand Down
72 changes: 54 additions & 18 deletions tests/unit/test_default_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,22 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import datetime
import pytest
from botocore.exceptions import ClientError
from mock import MagicMock, patch
from mock import MagicMock
import sagemaker

ACCOUNT_ID = "123"
REGION = "us-west-2"
DEFAULT_BUCKET_NAME = "sagemaker-{}-{}".format(REGION, ACCOUNT_ID)


@pytest.fixture()
def datetime_obj():
    """Return a fixed datetime that tests assign as a mocked bucket ``creation_date``."""
    return datetime.datetime(2017, 6, 16, 15, 55, 0)


@pytest.fixture()
def sagemaker_session():
boto_mock = MagicMock(name="boto_session", region_name=REGION)
Expand Down Expand Up @@ -50,23 +56,53 @@ def test_default_bucket_s3_create_call(sagemaker_session):
assert sagemaker_session._default_bucket == bucket_name


def test_default_bucket_s3_needs_access(sagemaker_session):
    """A 403/Forbidden ``head_bucket`` error is logged, re-raised, and nothing is cached."""
    error = ClientError(
        error_response={"Error": {"Code": "403", "Message": "Forbidden"}},
        operation_name="foo",
    )
    sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error

    with patch("logging.Logger.error") as mocked_error_log:
        # Keep only the call that is expected to fail inside ``pytest.raises`` so a
        # ClientError raised by the setup above cannot satisfy the assertion.
        with pytest.raises(ClientError):
            sagemaker_session.default_bucket()

        mocked_error_log.assert_called_once_with(
            "Bucket %s exists, but access is forbidden. Please try again after "
            "adding appropriate access.",
            DEFAULT_BUCKET_NAME,
        )

    # The failed lookup must not leave a default bucket cached on the session.
    assert sagemaker_session._default_bucket is None
def test_default_bucket_s3_needs_access(sagemaker_session, caplog):
    """A 403/Forbidden ``head_bucket`` error is logged, re-raised, and nothing is cached."""
    error = ClientError(
        error_response={"Error": {"Code": "403", "Message": "Forbidden"}},
        operation_name="foo",
    )
    sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error

    # Keep only the call that is expected to fail inside ``pytest.raises`` so a
    # ClientError raised by the setup above cannot satisfy the assertion.
    with pytest.raises(ClientError):
        sagemaker_session.default_bucket()

    # Match on the message tail only; the leading bucket name is interpolated at log time.
    error_message = (
        " exists, but access is forbidden. Please try again after adding appropriate access."
    )
    assert error_message in caplog.text
    # The failed lookup must not leave a default bucket cached on the session.
    assert sagemaker_session._default_bucket is None


def test_default_bucket_s3_needs_bucket_owner_access(sagemaker_session, datetime_obj, caplog):
    """An existing SDK-chosen bucket that fails the owner check is logged and re-raised."""
    error = ClientError(
        error_response={"Error": {"Code": "403", "Message": "Forbidden"}},
        operation_name="foo",
    )
    s3_mock = sagemaker_session.boto_session.resource("s3")
    s3_mock.meta.client.head_bucket.side_effect = error
    # bucket exists
    s3_mock.Bucket(name=DEFAULT_BUCKET_NAME).creation_date = datetime_obj

    # Keep only the call that is expected to fail inside ``pytest.raises`` so a
    # ClientError raised by the setup above cannot satisfy the assertion.
    with pytest.raises(ClientError):
        sagemaker_session.default_bucket()

    error_message = "This bucket cannot be configured to use as it is not owned by Account"
    assert error_message in caplog.text
    # The failed ownership check must not leave a default bucket cached on the session.
    assert sagemaker_session._default_bucket is None


def test_default_bucket_s3_custom_bucket_input(sagemaker_session, datetime_obj, caplog):
    """A caller-supplied default bucket is used as-is even when ``head_bucket`` would fail."""
    sagemaker_session._default_bucket_name_override = "custom-bucket-override"
    error = ClientError(
        error_response={"Error": {"Code": "403", "Message": "Forbidden"}},
        operation_name="foo",
    )
    s3_mock = sagemaker_session.boto_session.resource("s3")
    s3_mock.meta.client.head_bucket.side_effect = error
    # bucket exists
    # NOTE(review): this sets creation_date via DEFAULT_BUCKET_NAME rather than the
    # override name; presumably the mocked ``Bucket`` returns the same object for any
    # arguments - confirm against the ``sagemaker_session`` fixture.
    s3_mock.Bucket(name=DEFAULT_BUCKET_NAME).creation_date = datetime_obj

    # This should not raise ClientError as no head_bucket call is expected for custom bucket
    sagemaker_session.default_bucket()
    assert sagemaker_session._default_bucket == "custom-bucket-override"


def test_default_already_cached(sagemaker_session):
Expand Down