Skip to content

Commit e9ed907

Browse files
humanzzjesterhazy
authored andcommitted
Allow code_location to have no key prefix (#227)
* Allow code_location to have no key prefix Optionally, enable matching the behaviour when using sagemaker's default bucket where a key prefix is not permitted * Add missing blank line in test_estimator.py
1 parent 3008a29 commit e9ed907

File tree

3 files changed

+20
-1
lines changed

3 files changed

+20
-1
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ CHANGELOG
77

88
* bug-fix: Unit Tests: Improve unit test runtime
99
* bug-fix: Estimators: Fix attach for LDA
10+
* bug-fix: Estimators: allow code_location to have no key prefix
1011

1112
1.4.1
1213
=====

src/sagemaker/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ def _stage_user_code_in_s3(self):
571571
code_s3_prefix = '{}/source'.format(self._current_job_name)
572572
else:
573573
code_bucket, key_prefix = parse_s3_url(self.code_location)
574-
code_s3_prefix = '{}/{}/source'.format(key_prefix, self._current_job_name)
574+
code_s3_prefix = '/'.join(filter(None, [key_prefix, self._current_job_name, 'source']))
575575

576576
return tar_and_upload_dir(session=self.sagemaker_session.boto_session,
577577
bucket=code_bucket,

tests/unit/test_estimator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,24 @@ def test_custom_code_bucket(time, sagemaker_session):
127127
assert train_kwargs['hyperparameters']['sagemaker_submit_directory'] == json.dumps(expected_submit_dir)
128128

129129

130+
@patch('time.strftime', return_value=TIMESTAMP)
131+
def test_custom_code_bucket_without_prefix(time, sagemaker_session):
132+
code_bucket = 'codebucket'
133+
code_location = 's3://{}'.format(code_bucket)
134+
t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
135+
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
136+
code_location=code_location)
137+
t.fit('s3://bucket/mydata')
138+
139+
expected_key = '{}/source/sourcedir.tar.gz'.format(JOB_NAME)
140+
_, s3_args, _ = sagemaker_session.boto_session.resource('s3').Object.mock_calls[0]
141+
assert s3_args == (code_bucket, expected_key)
142+
143+
expected_submit_dir = 's3://{}/{}'.format(code_bucket, expected_key)
144+
_, _, train_kwargs = sagemaker_session.train.mock_calls[0]
145+
assert train_kwargs['hyperparameters']['sagemaker_submit_directory'] == json.dumps(expected_submit_dir)
146+
147+
130148
def test_invalid_custom_code_bucket(sagemaker_session):
131149
code_location = 'thisllworkright?'
132150
t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)