@@ -695,34 +695,38 @@ def _stage_user_code_in_s3(self) -> str:
695
695
696
696
Returns: S3 URI
697
697
"""
698
- local_mode = not is_pipeline_variable (self .output_path ) and self .output_path .startswith (
699
- "file://"
700
- )
701
-
702
- if self .code_location is None and local_mode :
698
+ if is_pipeline_variable (self .output_path ):
703
699
code_bucket = self .sagemaker_session .default_bucket ()
704
- code_s3_prefix = "{}/{}" .format (self ._current_job_name , "source" )
705
- kms_key = None
706
- elif self .code_location is None :
707
- if is_pipeline_variable (self .output_path ):
708
- code_bucket = self .sagemaker_session .default_bucket ()
700
+ if self .code_location is None :
701
+ code_s3_prefix = "{}/{}" .format (self ._current_job_name , "source" )
702
+ kms_key = self .output_kms_key
709
703
else :
710
- code_bucket , _ = parse_s3_url (self .output_path )
711
- code_s3_prefix = "{}/{}" .format (self ._current_job_name , "source" )
712
- kms_key = self .output_kms_key
713
- elif local_mode :
714
- code_bucket , key_prefix = parse_s3_url (self .code_location )
715
- code_s3_prefix = "/" .join (filter (None , [key_prefix , self ._current_job_name , "source" ]))
716
- kms_key = None
704
+ code_bucket , key_prefix = parse_s3_url (self .code_location )
705
+ code_s3_prefix = "/" .join (
706
+ filter (None , [key_prefix , self ._current_job_name , "source" ])
707
+ )
708
+ kms_key = (
709
+ self .output_kms_key
710
+ if code_bucket == self .sagemaker_session .default_bucket ()
711
+ else None
712
+ )
717
713
else :
718
- code_bucket , key_prefix = parse_s3_url (self .code_location )
719
- code_s3_prefix = "/" .join (filter (None , [key_prefix , self ._current_job_name , "source" ]))
714
+ if self .code_location is None :
715
+ code_bucket , _ = parse_s3_url (self .output_path )
716
+ code_s3_prefix = "{}/{}" .format (self ._current_job_name , "source" )
717
+ else :
718
+ code_bucket , key_prefix = parse_s3_url (self .code_location )
719
+ code_s3_prefix = "/" .join (
720
+ filter (None , [key_prefix , self ._current_job_name , "source" ])
721
+ )
720
722
721
- if is_pipeline_variable (self .output_path ):
722
- output_bucket = self .sagemaker_session .default_bucket ()
723
+ local_mode = self .output_path .startswith ("file://" )
724
+
725
+ if local_mode :
726
+ kms_key = None
723
727
else :
724
728
output_bucket , _ = parse_s3_url (self .output_path )
725
- kms_key = self .output_kms_key if code_bucket == output_bucket else None
729
+ kms_key = self .output_kms_key if code_bucket == output_bucket else None
726
730
727
731
return tar_and_upload_dir (
728
732
session = self .sagemaker_session .boto_session ,
0 commit comments