Skip to content

Commit 8e17478

Browse files
committed
simplify the estimator _stage_user_code_in_s3 condition
1 parent de46f1d commit 8e17478

File tree

1 file changed

+26
-22
lines changed

1 file changed

+26
-22
lines changed

src/sagemaker/estimator.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -695,34 +695,38 @@ def _stage_user_code_in_s3(self) -> str:
695695
696696
Returns: S3 URI
697697
"""
698-
local_mode = not is_pipeline_variable(self.output_path) and self.output_path.startswith(
699-
"file://"
700-
)
701-
702-
if self.code_location is None and local_mode:
698+
if is_pipeline_variable(self.output_path):
703699
code_bucket = self.sagemaker_session.default_bucket()
704-
code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
705-
kms_key = None
706-
elif self.code_location is None:
707-
if is_pipeline_variable(self.output_path):
708-
code_bucket = self.sagemaker_session.default_bucket()
700+
if self.code_location is None:
701+
code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
702+
kms_key = self.output_kms_key
709703
else:
710-
code_bucket, _ = parse_s3_url(self.output_path)
711-
code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
712-
kms_key = self.output_kms_key
713-
elif local_mode:
714-
code_bucket, key_prefix = parse_s3_url(self.code_location)
715-
code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"]))
716-
kms_key = None
704+
code_bucket, key_prefix = parse_s3_url(self.code_location)
705+
code_s3_prefix = "/".join(
706+
filter(None, [key_prefix, self._current_job_name, "source"])
707+
)
708+
kms_key = (
709+
self.output_kms_key
710+
if code_bucket == self.sagemaker_session.default_bucket()
711+
else None
712+
)
717713
else:
718-
code_bucket, key_prefix = parse_s3_url(self.code_location)
719-
code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"]))
714+
if self.code_location is None:
715+
code_bucket, _ = parse_s3_url(self.output_path)
716+
code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
717+
else:
718+
code_bucket, key_prefix = parse_s3_url(self.code_location)
719+
code_s3_prefix = "/".join(
720+
filter(None, [key_prefix, self._current_job_name, "source"])
721+
)
720722

721-
if is_pipeline_variable(self.output_path):
722-
output_bucket = self.sagemaker_session.default_bucket()
723+
local_mode = self.output_path.startswith("file://")
724+
725+
if local_mode:
726+
kms_key = None
723727
else:
724728
output_bucket, _ = parse_s3_url(self.output_path)
725-
kms_key = self.output_kms_key if code_bucket == output_bucket else None
729+
kms_key = self.output_kms_key if code_bucket == output_bucket else None
726730

727731
return tar_and_upload_dir(
728732
session=self.sagemaker_session.boto_session,

0 commit comments

Comments
 (0)