Skip to content

Commit 98fa76d

Browse files
authored
feature: allow setting the default bucket in Session (#1168)
1 parent 04a2e75 commit 98fa76d

12 files changed

+267
-19
lines changed

src/sagemaker/local/local_session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def __init__(self, boto_session=None):
379379
if platform.system() == "Windows":
380380
logger.warning("Windows Support for Local Mode is Experimental")
381381

382-
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
382+
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
383383
"""Initialize this Local SageMaker Session.
384384
385385
Args:
@@ -413,6 +413,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
413413

414414
self.config = yaml.load(open(sagemaker_config_file, "r"))
415415

416+
self._default_bucket = None
417+
self._desired_default_bucket_name = default_bucket
418+
416419
def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"):
417420
"""
418421

src/sagemaker/session.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,13 @@ class Session(object): # pylint: disable=too-many-public-methods
7676
bucket based on a naming convention which includes the current AWS account ID.
7777
"""
7878

79-
def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None):
79+
def __init__(
80+
self,
81+
boto_session=None,
82+
sagemaker_client=None,
83+
sagemaker_runtime_client=None,
84+
default_bucket=None,
85+
):
8086
"""Initialize a SageMaker ``Session``.
8187
8288
Args:
@@ -91,15 +97,23 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c
9197
``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created
9298
using this ``Session`` use this client. If not provided, one will be created using
9399
this instance's ``boto_session``.
100+
default_bucket (str): The default s3 bucket to be used by this session.
101+
Ex: "sagemaker-us-west-2"
102+
94103
"""
95104
self._default_bucket = None
96105

97106
# currently is used for local_code in local mode
98107
self.config = None
99108

100-
self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client)
109+
self._initialize(
110+
boto_session=boto_session,
111+
sagemaker_client=sagemaker_client,
112+
sagemaker_runtime_client=sagemaker_runtime_client,
113+
default_bucket=default_bucket,
114+
)
101115

102-
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
116+
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
103117
"""Initialize this SageMaker Session.
104118
105119
Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
@@ -126,6 +140,12 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
126140

127141
prepend_user_agent(self.sagemaker_runtime_client)
128142

143+
self._default_bucket = None
144+
self._desired_default_bucket_name = default_bucket
145+
146+
# Create default bucket on session init to verify that desired name, if specified, is valid
147+
self.default_bucket()
148+
129149
self.local_mode = False
130150

131151
@property
@@ -314,11 +334,14 @@ def default_bucket(self):
314334
if self._default_bucket:
315335
return self._default_bucket
316336

337+
default_bucket = self._desired_default_bucket_name
317338
region = self.boto_session.region_name
318-
account = self.boto_session.client(
319-
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
320-
).get_caller_identity()["Account"]
321-
default_bucket = "sagemaker-{}-{}".format(region, account)
339+
340+
if not default_bucket:
341+
account = self.boto_session.client(
342+
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
343+
).get_caller_identity()["Account"]
344+
default_bucket = "sagemaker-{}-{}".format(region, account)
322345

323346
s3 = self.boto_session.resource("s3")
324347
try:

tests/integ/test_local_mode.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class LocalNoS3Session(LocalSession):
4343
def __init__(self):
4444
super(LocalSession, self).__init__()
4545

46-
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
46+
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
4747
self.boto_session = boto3.Session(region_name=DEFAULT_REGION)
4848
if self.config is None:
4949
self.config = {"local": {"local_code": True, "region_name": DEFAULT_REGION}}
@@ -53,6 +53,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
5353
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
5454
self.local_mode = True
5555

56+
self._default_bucket = None
57+
self._desired_default_bucket_name = default_bucket
58+
5659

5760
@pytest.fixture(scope="module")
5861
def mxnet_model(sagemaker_local_session, mxnet_full_version):

tests/integ/test_processing.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515
import os
1616

17+
import boto3
1718
import pytest
19+
from botocore.config import Config
20+
from sagemaker import Session
1821
from sagemaker.fw_registry import default_framework_uri
1922

2023
from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor, Processor
@@ -23,6 +26,35 @@
2326
from tests.integ.kms_utils import get_or_create_kms_key
2427

2528
ROLE = "SageMakerRole"
29+
DEFAULT_REGION = "us-west-2"
30+
CUSTOM_BUCKET_PATH = "sagemaker-custom-bucket"
31+
32+
33+
@pytest.fixture(scope="module")
34+
def sagemaker_session_with_custom_bucket(
35+
boto_config, sagemaker_client_config, sagemaker_runtime_config
36+
):
37+
boto_session = (
38+
boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)
39+
)
40+
sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10)))
41+
sagemaker_client = (
42+
boto_session.client("sagemaker", **sagemaker_client_config)
43+
if sagemaker_client_config
44+
else None
45+
)
46+
runtime_client = (
47+
boto_session.client("sagemaker-runtime", **sagemaker_runtime_config)
48+
if sagemaker_runtime_config
49+
else None
50+
)
51+
52+
return Session(
53+
boto_session=boto_session,
54+
sagemaker_client=sagemaker_client,
55+
sagemaker_runtime_client=runtime_client,
56+
default_bucket=CUSTOM_BUCKET_PATH,
57+
)
2658

2759

2860
@pytest.fixture(scope="module")
@@ -170,6 +202,90 @@ def test_sklearn_with_customizations(
170202
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
171203

172204

205+
def test_sklearn_with_custom_default_bucket(
206+
sagemaker_session_with_custom_bucket,
207+
image_uri,
208+
sklearn_full_version,
209+
cpu_instance_type,
210+
output_kms_key,
211+
):
212+
213+
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
214+
215+
sklearn_processor = SKLearnProcessor(
216+
framework_version=sklearn_full_version,
217+
role=ROLE,
218+
command=["python3"],
219+
instance_type=cpu_instance_type,
220+
instance_count=1,
221+
volume_size_in_gb=100,
222+
volume_kms_key=None,
223+
output_kms_key=output_kms_key,
224+
max_runtime_in_seconds=3600,
225+
base_job_name="test-sklearn-with-customizations",
226+
env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"},
227+
tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}],
228+
sagemaker_session=sagemaker_session_with_custom_bucket,
229+
)
230+
231+
sklearn_processor.run(
232+
code=os.path.join(DATA_DIR, "dummy_script.py"),
233+
inputs=[
234+
ProcessingInput(
235+
source=input_file_path,
236+
destination="/opt/ml/processing/input/container/path/",
237+
input_name="dummy_input",
238+
s3_data_type="S3Prefix",
239+
s3_input_mode="File",
240+
s3_data_distribution_type="FullyReplicated",
241+
s3_compression_type="None",
242+
)
243+
],
244+
outputs=[
245+
ProcessingOutput(
246+
source="/opt/ml/processing/output/container/path/",
247+
output_name="dummy_output",
248+
s3_upload_mode="EndOfJob",
249+
)
250+
],
251+
arguments=["-v"],
252+
wait=True,
253+
logs=True,
254+
)
255+
256+
job_description = sklearn_processor.latest_job.describe()
257+
258+
assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input"
259+
assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]
260+
261+
assert job_description["ProcessingInputs"][1]["InputName"] == "code"
262+
assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][1]["S3Input"]["S3Uri"]
263+
264+
assert job_description["ProcessingJobName"].startswith("test-sklearn-with-customizations")
265+
266+
assert job_description["ProcessingJobStatus"] == "Completed"
267+
268+
assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key
269+
assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output"
270+
271+
assert job_description["ProcessingResources"] == {
272+
"ClusterConfig": {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 100}
273+
}
274+
275+
assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
276+
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
277+
"python3",
278+
"/opt/ml/processing/input/code/dummy_script.py",
279+
]
280+
assert job_description["AppSpecification"]["ImageUri"] == image_uri
281+
282+
assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}
283+
284+
assert ROLE in job_description["RoleArn"]
285+
286+
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
287+
288+
173289
def test_sklearn_with_no_inputs_or_outputs(
174290
sagemaker_session, image_uri, sklearn_full_version, cpu_instance_type
175291
):

tests/unit/test_create_deploy_entities.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@
3434
@pytest.fixture()
3535
def sagemaker_session():
3636
boto_mock = Mock(name="boto_session", region_name=REGION)
37+
client_mock = Mock()
38+
client_mock.get_caller_identity.return_value = {
39+
"UserId": "mock_user_id",
40+
"Account": "012345678910",
41+
"Arn": "arn:aws:iam::012345678910:user/mock-user",
42+
}
43+
boto_mock.client.return_value = client_mock
3744
ims = sagemaker.Session(boto_session=boto_mock)
3845
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
3946
return ims

tests/unit/test_default_bucket.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
@pytest.fixture()
2626
def sagemaker_session():
2727
boto_mock = Mock(name="boto_session", region_name=REGION)
28+
client_mock = Mock()
29+
client_mock.get_caller_identity.return_value = {
30+
"UserId": "mock_user_id",
31+
"Account": "012345678910",
32+
"Arn": "arn:aws:iam::012345678910:user/mock-user",
33+
}
34+
boto_mock.client.return_value = client_mock
2835
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}
2936
ims = sagemaker.Session(boto_session=boto_mock)
3037
return ims
@@ -48,11 +55,13 @@ def test_default_already_cached(sagemaker_session):
4855
existing_default = "mydefaultbucket"
4956
sagemaker_session._default_bucket = existing_default
5057

58+
before_create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls
59+
5160
bucket_name = sagemaker_session.default_bucket()
5261

53-
create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls
62+
after_create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls
5463
assert bucket_name == existing_default
55-
assert create_calls == []
64+
assert before_create_calls == after_create_calls
5665

5766

5867
def test_default_bucket_exists(sagemaker_session):
@@ -78,22 +87,42 @@ def test_concurrent_bucket_modification(sagemaker_session):
7887
assert bucket_name == DEFAULT_BUCKET_NAME
7988

8089

81-
def test_bucket_creation_client_error(sagemaker_session):
90+
def test_bucket_creation_client_error():
8291
with pytest.raises(ClientError):
92+
boto_mock = Mock(name="boto_session", region_name=REGION)
93+
client_mock = Mock()
94+
client_mock.get_caller_identity.return_value = {
95+
"UserId": "mock_user_id",
96+
"Account": "012345678910",
97+
"Arn": "arn:aws:iam::012345678910:user/mock-user",
98+
}
99+
boto_mock.client.return_value = client_mock
100+
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}
101+
83102
error = ClientError(
84103
error_response={"Error": {"Code": "SomethingWrong", "Message": "message"}},
85104
operation_name="foo",
86105
)
87-
sagemaker_session.boto_session.resource().create_bucket.side_effect = error
106+
boto_mock.resource().create_bucket.side_effect = error
88107

89-
sagemaker_session.default_bucket()
90-
assert sagemaker_session._default_bucket is None
108+
session = sagemaker.Session(boto_session=boto_mock)
109+
assert session._default_bucket is None
91110

92111

93-
def test_bucket_creation_other_error(sagemaker_session):
112+
def test_bucket_creation_other_error():
94113
with pytest.raises(RuntimeError):
114+
boto_mock = Mock(name="boto_session", region_name=REGION)
115+
client_mock = Mock()
116+
client_mock.get_caller_identity.return_value = {
117+
"UserId": "mock_user_id",
118+
"Account": "012345678910",
119+
"Arn": "arn:aws:iam::012345678910:user/mock-user",
120+
}
121+
boto_mock.client.return_value = client_mock
122+
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}
123+
95124
error = RuntimeError()
96-
sagemaker_session.boto_session.resource().create_bucket.side_effect = error
125+
boto_mock.resource().create_bucket.side_effect = error
97126

98-
sagemaker_session.default_bucket()
99-
assert sagemaker_session._default_bucket is None
127+
session = sagemaker.Session(boto_session=boto_mock)
128+
assert session._default_bucket is None

tests/unit/test_endpoint_from_job.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@
4343
@pytest.fixture()
4444
def sagemaker_session():
4545
boto_mock = Mock(name="boto_session", region_name=REGION)
46+
client_mock = Mock()
47+
client_mock.get_caller_identity.return_value = {
48+
"UserId": "mock_user_id",
49+
"Account": "012345678910",
50+
"Arn": "arn:aws:iam::012345678910:user/mock-user",
51+
}
52+
boto_mock.client.return_value = client_mock
4653
ims = sagemaker.Session(sagemaker_client=Mock(name="sagemaker_client"), boto_session=boto_mock)
4754
ims.sagemaker_client.describe_training_job = Mock(
4855
name="describe_training_job", return_value=TRAINING_JOB_RESPONSE

tests/unit/test_endpoint_from_model_data.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@
3636
@pytest.fixture()
3737
def sagemaker_session():
3838
boto_mock = Mock(name="boto_session", region_name=REGION)
39+
client_mock = Mock()
40+
client_mock.get_caller_identity.return_value = {
41+
"UserId": "mock_user_id",
42+
"Account": "012345678910",
43+
"Arn": "arn:aws:iam::012345678910:user/mock-user",
44+
}
45+
boto_mock.client.return_value = client_mock
3946
ims = sagemaker.Session(sagemaker_client=Mock(name="sagemaker_client"), boto_session=boto_mock)
4047
ims.sagemaker_client.describe_model = Mock(
4148
name="describe_model", side_effect=_raise_does_not_exist_client_error

tests/unit/test_exception_on_bad_status.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,13 @@ def get_sagemaker_session(returns_status):
2929
client_mock.describe_model_package = MagicMock(
3030
return_value={"ModelPackageStatus": returns_status}
3131
)
32+
client_mock.get_caller_identity.return_value = {
33+
"UserId": "mock_user_id",
34+
"Account": "012345678910",
35+
"Arn": "arn:aws:iam::012345678910:user/mock-user",
36+
}
3237
client_mock.describe_endpoint = MagicMock(return_value={"EndpointStatus": returns_status})
38+
boto_mock.client.return_value = client_mock
3339
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=client_mock)
3440
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
3541
return ims

tests/unit/test_local_session.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,5 +473,12 @@ def test_file_input_content_type():
473473

474474
def test_local_session_is_set_to_local_mode():
475475
boto_session = Mock(region_name="us-west-2")
476+
client_mock = Mock()
477+
client_mock.get_caller_identity.return_value = {
478+
"UserId": "mock_user_id",
479+
"Account": "012345678910",
480+
"Arn": "arn:aws:iam::012345678910:user/mock-user",
481+
}
482+
boto_session.client.return_value = client_mock
476483
local_session = sagemaker.local.local_session.LocalSession(boto_session=boto_session)
477484
assert local_session.local_mode

0 commit comments

Comments
 (0)