File tree Expand file tree Collapse file tree 2 files changed +36
-2
lines changed Expand file tree Collapse file tree 2 files changed +36
-2
lines changed Original file line number Diff line number Diff line change @@ -530,7 +530,9 @@ def _create_or_update_code_dir(
530
530
"""
531
531
code_dir = os .path .join (model_dir , "code" )
532
532
if os .path .exists (code_dir ):
533
- shutil .rmtree (code_dir , ignore_errors = True )
533
+ for filename in os .listdir (code_dir ):
534
+ if filename .endswith (".py" ):
535
+ os .remove (os .path .join (code_dir , filename ))
534
536
if source_directory and source_directory .lower ().startswith ("s3://" ):
535
537
local_code_path = os .path .join (tmp , "local_code.tar.gz" )
536
538
download_file_from_url (source_directory , local_code_path , sagemaker_session )
@@ -539,9 +541,12 @@ def _create_or_update_code_dir(
539
541
t .extractall (path = code_dir )
540
542
541
543
elif source_directory :
544
+ if os .path .exists (code_dir ):
545
+ shutil .rmtree (code_dir )
542
546
shutil .copytree (source_directory , code_dir )
543
547
else :
544
- os .mkdir (code_dir )
548
+ if not os .path .exists (code_dir ):
549
+ os .mkdir (code_dir )
545
550
shutil .copy2 (inference_script , code_dir )
546
551
547
552
for dependency in dependencies :
Original file line number Diff line number Diff line change @@ -614,6 +614,35 @@ def test_repack_model_from_file_to_folder(tmp):
614
614
}
615
615
616
616
617
+ def test_repack_model_with_inference_code_and_requirements (tmp , fake_s3 ):
618
+ create_file_tree (
619
+ tmp ,
620
+ [
621
+ "new-inference.py" ,
622
+ "model-dir/model" ,
623
+ "model-dir/code/old-inference.py" ,
624
+ "model-dir/code/requirements.txt" ,
625
+ ],
626
+ )
627
+
628
+ fake_s3 .tar_and_upload ("model-dir" , "s3://fake/location" )
629
+
630
+ sagemaker .utils .repack_model (
631
+ os .path .join (tmp , "new-inference.py" ),
632
+ None ,
633
+ None ,
634
+ "s3://fake/location" ,
635
+ "s3://destination-bucket/repacked-model" ,
636
+ fake_s3 .sagemaker_session ,
637
+ )
638
+
639
+ assert list_tar_files (fake_s3 .fake_upload_path , tmp ) == {
640
+ "/code/requirements.txt" ,
641
+ "/code/new-inference.py" ,
642
+ "/model" ,
643
+ }
644
+
645
+
617
646
class FakeS3 (object ):
618
647
def __init__ (self , tmp ):
619
648
self .tmp = tmp
You can’t perform that action at this time.
0 commit comments