Skip to content

Commit 967ec94

Browse files
knakad authored and laurenyu committed
feature: allow setting the default bucket in Session (#1176)
Default bucket not created on init.
1 parent d51eef4 commit 967ec94

File tree

3 files changed

+261
-6
lines changed

3 files changed

+261
-6
lines changed

src/sagemaker/session.py

Lines changed: 27 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,13 @@ class Session(object): # pylint: disable=too-many-public-methods
7676
bucket based on a naming convention which includes the current AWS account ID.
7777
"""
7878

79-
def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None):
79+
def __init__(
80+
self,
81+
boto_session=None,
82+
sagemaker_client=None,
83+
sagemaker_runtime_client=None,
84+
default_bucket=None,
85+
):
8086
"""Initialize a SageMaker ``Session``.
8187
8288
Args:
@@ -91,13 +97,25 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c
9197
``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created
9298
using this ``Session`` use this client. If not provided, one will be created using
9399
this instance's ``boto_session``.
100+
default_bucket (str): The default Amazon S3 bucket to be used by this session.
101+
This will be created the next time an Amazon S3 bucket is needed (by calling
102+
:func:`default_bucket`).
103+
If not provided, a default bucket will be created based on the following format:
104+
"sagemaker-{region}-{aws-account-id}".
105+
Example: "sagemaker-my-custom-bucket".
106+
94107
"""
95108
self._default_bucket = None
109+
self._default_bucket_name_override = default_bucket
96110

97111
# currently is used for local_code in local mode
98112
self.config = None
99113

100-
self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client)
114+
self._initialize(
115+
boto_session=boto_session,
116+
sagemaker_client=sagemaker_client,
117+
sagemaker_runtime_client=sagemaker_runtime_client,
118+
)
101119

102120
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
103121
"""Initialize this SageMaker Session.
@@ -315,10 +333,13 @@ def default_bucket(self):
315333
return self._default_bucket
316334

317335
region = self.boto_session.region_name
318-
account = self.boto_session.client(
319-
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
320-
).get_caller_identity()["Account"]
321-
default_bucket = "sagemaker-{}-{}".format(region, account)
336+
337+
default_bucket = self._default_bucket_name_override
338+
if not default_bucket:
339+
account = self.boto_session.client(
340+
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
341+
).get_caller_identity()["Account"]
342+
default_bucket = "sagemaker-{}-{}".format(region, account)
322343

323344
s3 = self.boto_session.resource("s3")
324345
try:

tests/integ/test_processing.py

Lines changed: 184 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,10 @@
1414

1515
import os
1616

17+
import boto3
1718
import pytest
19+
from botocore.config import Config
20+
from sagemaker import Session
1821
from sagemaker.fw_registry import default_framework_uri
1922

2023
from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor, Processor
@@ -23,6 +26,35 @@
2326
from tests.integ.kms_utils import get_or_create_kms_key
2427

2528
ROLE = "SageMakerRole"
29+
DEFAULT_REGION = "us-west-2"
30+
CUSTOM_BUCKET_PATH = "sagemaker-custom-bucket"
31+
32+
33+
@pytest.fixture(scope="module")
34+
def sagemaker_session_with_custom_bucket(
35+
boto_config, sagemaker_client_config, sagemaker_runtime_config
36+
):
37+
boto_session = (
38+
boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)
39+
)
40+
sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10)))
41+
sagemaker_client = (
42+
boto_session.client("sagemaker", **sagemaker_client_config)
43+
if sagemaker_client_config
44+
else None
45+
)
46+
runtime_client = (
47+
boto_session.client("sagemaker-runtime", **sagemaker_runtime_config)
48+
if sagemaker_runtime_config
49+
else None
50+
)
51+
52+
return Session(
53+
boto_session=boto_session,
54+
sagemaker_client=sagemaker_client,
55+
sagemaker_runtime_client=runtime_client,
56+
default_bucket=CUSTOM_BUCKET_PATH,
57+
)
2658

2759

2860
@pytest.fixture(scope="module")
@@ -170,6 +202,89 @@ def test_sklearn_with_customizations(
170202
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
171203

172204

205+
def test_sklearn_with_custom_default_bucket(
206+
sagemaker_session_with_custom_bucket,
207+
image_uri,
208+
sklearn_full_version,
209+
cpu_instance_type,
210+
output_kms_key,
211+
):
212+
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
213+
214+
sklearn_processor = SKLearnProcessor(
215+
framework_version=sklearn_full_version,
216+
role=ROLE,
217+
command=["python3"],
218+
instance_type=cpu_instance_type,
219+
instance_count=1,
220+
volume_size_in_gb=100,
221+
volume_kms_key=None,
222+
output_kms_key=output_kms_key,
223+
max_runtime_in_seconds=3600,
224+
base_job_name="test-sklearn-with-customizations",
225+
env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"},
226+
tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}],
227+
sagemaker_session=sagemaker_session_with_custom_bucket,
228+
)
229+
230+
sklearn_processor.run(
231+
code=os.path.join(DATA_DIR, "dummy_script.py"),
232+
inputs=[
233+
ProcessingInput(
234+
source=input_file_path,
235+
destination="/opt/ml/processing/input/container/path/",
236+
input_name="dummy_input",
237+
s3_data_type="S3Prefix",
238+
s3_input_mode="File",
239+
s3_data_distribution_type="FullyReplicated",
240+
s3_compression_type="None",
241+
)
242+
],
243+
outputs=[
244+
ProcessingOutput(
245+
source="/opt/ml/processing/output/container/path/",
246+
output_name="dummy_output",
247+
s3_upload_mode="EndOfJob",
248+
)
249+
],
250+
arguments=["-v"],
251+
wait=True,
252+
logs=True,
253+
)
254+
255+
job_description = sklearn_processor.latest_job.describe()
256+
257+
assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input"
258+
assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]
259+
260+
assert job_description["ProcessingInputs"][1]["InputName"] == "code"
261+
assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][1]["S3Input"]["S3Uri"]
262+
263+
assert job_description["ProcessingJobName"].startswith("test-sklearn-with-customizations")
264+
265+
assert job_description["ProcessingJobStatus"] == "Completed"
266+
267+
assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key
268+
assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output"
269+
270+
assert job_description["ProcessingResources"] == {
271+
"ClusterConfig": {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 100}
272+
}
273+
274+
assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
275+
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
276+
"python3",
277+
"/opt/ml/processing/input/code/dummy_script.py",
278+
]
279+
assert job_description["AppSpecification"]["ImageUri"] == image_uri
280+
281+
assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}
282+
283+
assert ROLE in job_description["RoleArn"]
284+
285+
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
286+
287+
173288
def test_sklearn_with_no_inputs_or_outputs(
174289
sagemaker_session, image_uri, sklearn_full_version, cpu_instance_type
175290
):
@@ -405,3 +520,72 @@ def test_processor(sagemaker_session, image_uri, cpu_instance_type, output_kms_k
405520
assert ROLE in job_description["RoleArn"]
406521

407522
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
523+
524+
525+
def test_processor_with_custom_bucket(
526+
sagemaker_session_with_custom_bucket, image_uri, cpu_instance_type, output_kms_key
527+
):
528+
script_path = os.path.join(DATA_DIR, "dummy_script.py")
529+
530+
processor = Processor(
531+
role=ROLE,
532+
image_uri=image_uri,
533+
instance_count=1,
534+
instance_type=cpu_instance_type,
535+
entrypoint=["python3", "/opt/ml/processing/input/code/dummy_script.py"],
536+
volume_size_in_gb=100,
537+
volume_kms_key=None,
538+
output_kms_key=output_kms_key,
539+
max_runtime_in_seconds=3600,
540+
base_job_name="test-processor",
541+
env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"},
542+
tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}],
543+
sagemaker_session=sagemaker_session_with_custom_bucket,
544+
)
545+
546+
processor.run(
547+
inputs=[
548+
ProcessingInput(
549+
source=script_path, destination="/opt/ml/processing/input/code/", input_name="code"
550+
)
551+
],
552+
outputs=[
553+
ProcessingOutput(
554+
source="/opt/ml/processing/output/container/path/",
555+
output_name="dummy_output",
556+
s3_upload_mode="EndOfJob",
557+
)
558+
],
559+
arguments=["-v"],
560+
wait=True,
561+
logs=True,
562+
)
563+
564+
job_description = processor.latest_job.describe()
565+
566+
assert job_description["ProcessingInputs"][0]["InputName"] == "code"
567+
assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]
568+
569+
assert job_description["ProcessingJobName"].startswith("test-processor")
570+
571+
assert job_description["ProcessingJobStatus"] == "Completed"
572+
573+
assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key
574+
assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output"
575+
576+
assert job_description["ProcessingResources"] == {
577+
"ClusterConfig": {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 100}
578+
}
579+
580+
assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
581+
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
582+
"python3",
583+
"/opt/ml/processing/input/code/dummy_script.py",
584+
]
585+
assert job_description["AppSpecification"]["ImageUri"] == image_uri
586+
587+
assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}
588+
589+
assert ROLE in job_description["RoleArn"]
590+
591+
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}

tests/integ/test_session.py

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import boto3
16+
from botocore.config import Config
17+
18+
from sagemaker import Session
19+
20+
DEFAULT_REGION = "us-west-2"
21+
CUSTOM_BUCKET_NAME = "this-bucket-should-not-exist"
22+
23+
24+
def test_sagemaker_session_does_not_create_bucket_on_init(
25+
sagemaker_client_config, sagemaker_runtime_config, boto_config
26+
):
27+
boto_session = (
28+
boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)
29+
)
30+
sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10)))
31+
sagemaker_client = (
32+
boto_session.client("sagemaker", **sagemaker_client_config)
33+
if sagemaker_client_config
34+
else None
35+
)
36+
runtime_client = (
37+
boto_session.client("sagemaker-runtime", **sagemaker_runtime_config)
38+
if sagemaker_runtime_config
39+
else None
40+
)
41+
42+
Session(
43+
boto_session=boto_session,
44+
sagemaker_client=sagemaker_client,
45+
sagemaker_runtime_client=runtime_client,
46+
default_bucket=CUSTOM_BUCKET_NAME,
47+
)
48+
49+
s3 = boto3.resource("s3")
50+
assert s3.Bucket(CUSTOM_BUCKET_NAME).creation_date is None

0 commit comments

Comments
 (0)