Skip to content

Commit 40b559c

Browse files
committed
Fix tests
1 parent 2099ab9 commit 40b559c

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/sagemaker/estimator.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -862,16 +862,23 @@ def _stage_user_code_in_s3(self):
862862
Returns: s3 uri
863863
864864
"""
865-
if self.code_location is None:
865+
local_mode = self.output_path.startswith('file://')
866+
867+
if self.code_location is None and local_mode:
868+
code_bucket = self.sagemaker_session.default_bucket()
869+
code_s3_prefix = '{}/source'.format(self._current_job_name)
870+
kms_key = None
871+
872+
elif self.code_location is None:
866873
code_bucket, _ = parse_s3_url(self.output_path)
867874
code_s3_prefix = '{}/source'.format(self._current_job_name)
875+
kms_key = self.output_kms_key
868876
else:
869877
code_bucket, key_prefix = parse_s3_url(self.code_location)
870878
code_s3_prefix = '/'.join(filter(None, [key_prefix, self._current_job_name, 'source']))
871879

872-
output_bucket, _ = parse_s3_url(self.output_path)
873-
874-
kms_key = self.output_kms_key if code_bucket == output_bucket else None
880+
output_bucket, _ = parse_s3_url(self.output_path)
881+
kms_key = self.output_kms_key if code_bucket == output_bucket else None
875882

876883
return tar_and_upload_dir(session=self.sagemaker_session.boto_session,
877884
bucket=code_bucket,

0 commit comments

Comments
 (0)