@@ -289,6 +289,8 @@ def test_one_step_sklearn_processing_pipeline(
         ProcessingInput(dataset_definition=athena_dataset_definition),
     ]
 
+    cache_config = CacheConfig(enable_caching=True, expire_after="T30m")
+
     sklearn_processor = SKLearnProcessor(
         framework_version=sklearn_latest_version,
         role=role,
@@ -304,6 +306,7 @@ def test_one_step_sklearn_processing_pipeline(
         processor=sklearn_processor,
         inputs=inputs,
         code=script_path,
+        cache_config=cache_config,
     )
     pipeline = Pipeline(
         name=pipeline_name,
@@ -343,6 +346,11 @@ def test_one_step_sklearn_processing_pipeline(
         response = execution.describe()
         assert response["PipelineArn"] == create_arn
 
+        # Check CacheConfig
+        response = json.loads(pipeline.describe()["PipelineDefinition"])["Steps"][0]["CacheConfig"]
+        assert response["Enabled"] == cache_config.enable_caching
+        assert response["ExpireAfter"] == cache_config.expire_after
+
         try:
             execution.wait(delay=30, max_attempts=3)
         except WaiterError:
@@ -547,213 +555,3 @@ def test_training_job_with_debugger(
             pipeline.delete()
         except Exception:
             pass
-
-
-def test_cache_hit(
-    sagemaker_session,
-    workflow_session,
-    region_name,
-    role,
-    script_dir,
-    pipeline_name,
-    athena_dataset_definition,
-):
-
-    cache_config = CacheConfig(enable_caching=True, expire_after="T30m")
-
-    framework_version = "0.20.0"
-    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
-    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
-
-    input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"
-
-    sklearn_processor = SKLearnProcessor(
-        framework_version=framework_version,
-        instance_type=instance_type,
-        instance_count=instance_count,
-        base_job_name="test-sklearn",
-        sagemaker_session=sagemaker_session,
-        role=role,
-    )
-
-    step_process = ProcessingStep(
-        name="my-cache-test",
-        processor=sklearn_processor,
-        inputs=[
-            ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
-            ProcessingInput(dataset_definition=athena_dataset_definition),
-        ],
-        outputs=[
-            ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
-            ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"),
-        ],
-        code=os.path.join(script_dir, "preprocessing.py"),
-        cache_config=cache_config,
-    )
-
-    pipeline = Pipeline(
-        name=pipeline_name,
-        parameters=[instance_count, instance_type],
-        steps=[step_process],
-        sagemaker_session=workflow_session,
-    )
-
-    try:
-        response = pipeline.create(role)
-        create_arn = response["PipelineArn"]
-        pytest.set_trace()
-
-        assert re.match(
-            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
-            create_arn,
-        )
-
-        # Run pipeline for the first time to get an entry in the cache
-        execution1 = pipeline.start(parameters={})
-        assert re.match(
-            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
-            execution1.arn,
-        )
-
-        response = execution1.describe()
-        assert response["PipelineArn"] == create_arn
-
-        try:
-            execution1.wait(delay=30, max_attempts=10)
-        except WaiterError:
-            pass
-        execution1_steps = execution1.list_steps()
-        assert len(execution1_steps) == 1
-        assert execution1_steps[0]["StepName"] == "my-cache-test"
-
-        # Run pipeline for the second time and expect cache hit
-        execution2 = pipeline.start(parameters={})
-        assert re.match(
-            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
-            execution2.arn,
-        )
-
-        response = execution2.describe()
-        assert response["PipelineArn"] == create_arn
-
-        try:
-            execution2.wait(delay=30, max_attempts=10)
-        except WaiterError:
-            pass
-        execution2_steps = execution2.list_steps()
-        assert len(execution2_steps) == 1
-        assert execution2_steps[0]["StepName"] == "my-cache-test"
-
-        assert execution1_steps[0] == execution2_steps[0]
-
-    finally:
-        try:
-            pipeline.delete()
-        except Exception:
-            pass
-
-
-def test_cache_expiry(
-    sagemaker_session,
-    workflow_session,
-    region_name,
-    role,
-    script_dir,
-    pipeline_name,
-    athena_dataset_definition,
-):
-
-    cache_config = CacheConfig(enable_caching=True, expire_after="T1m")
-
-    framework_version = "0.20.0"
-    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
-    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
-
-    input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"
-
-    sklearn_processor = SKLearnProcessor(
-        framework_version=framework_version,
-        instance_type=instance_type,
-        instance_count=instance_count,
-        base_job_name="test-sklearn",
-        sagemaker_session=sagemaker_session,
-        role=role,
-    )
-
-    step_process = ProcessingStep(
-        name="my-cache-test-expiry",
-        processor=sklearn_processor,
-        inputs=[
-            ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
-            ProcessingInput(dataset_definition=athena_dataset_definition),
-        ],
-        outputs=[
-            ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
-            ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"),
-        ],
-        code=os.path.join(script_dir, "preprocessing.py"),
-        cache_config=cache_config,
-    )
-
-    pipeline = Pipeline(
-        name=pipeline_name,
-        parameters=[instance_count, instance_type],
-        steps=[step_process],
-        sagemaker_session=workflow_session,
-    )
-
-    try:
-        response = pipeline.create(role)
-        create_arn = response["PipelineArn"]
-
-        assert re.match(
-            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
-            create_arn,
-        )
-
-        # Run pipeline for the first time to get an entry in the cache
-        execution1 = pipeline.start(parameters={})
-        assert re.match(
-            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
-            execution1.arn,
-        )
-
-        response = execution1.describe()
-        assert response["PipelineArn"] == create_arn
-
-        try:
-            execution1.wait(delay=30, max_attempts=3)
-        except WaiterError:
-            pass
-        execution1_steps = execution1.list_steps()
-        assert len(execution1_steps) == 1
-        assert execution1_steps[0]["StepName"] == "my-cache-test-expiry"
-
-        # wait 1 minute for cache to expire
-        time.sleep(60)
-
-        # Run pipeline for the second time and expect cache miss
-        execution2 = pipeline.start(parameters={})
-        assert re.match(
-            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
-            execution2.arn,
-        )
-
-        response = execution2.describe()
-        assert response["PipelineArn"] == create_arn
-
-        try:
-            execution2.wait(delay=30, max_attempts=3)
-        except WaiterError:
-            pass
-        execution2_steps = execution2.list_steps()
-        assert len(execution2_steps) == 1
-        assert execution2_steps[0]["StepName"] == "my-cache-test-expiry"
-
-        assert execution1_steps[0] != execution2_steps[0]
-
-    finally:
-        try:
-            pipeline.delete()
-        except Exception:
-            pass
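
For reference, a minimal sketch (not part of the diff) of the caching pattern this commit folds into test_one_step_sklearn_processing_pipeline. The helper name, fixture arguments (processor, inputs, script_path), and the import path are assumptions based on how the rest of this test module is structured, not something the diff itself shows:

# Sketch only: attach a 30-minute cache window to a ProcessingStep,
# mirroring the added lines above. Import path assumed from the
# SageMaker Python SDK v2 workflow module.
from sagemaker.workflow.steps import CacheConfig, ProcessingStep

def build_cached_step(processor, inputs, script_path):
    # Same values as the test: caching on, entries expire after 30 minutes.
    cache_config = CacheConfig(enable_caching=True, expire_after="T30m")
    return ProcessingStep(
        name="sklearn-process",   # illustrative step name
        processor=processor,      # e.g. an SKLearnProcessor fixture
        inputs=inputs,
        code=script_path,
        cache_config=cache_config,
    )

# The settings round-trip through the pipeline definition JSON, which is what the
# new assertions read back:
#   json.loads(pipeline.describe()["PipelineDefinition"])["Steps"][0]["CacheConfig"]
#   -> {"Enabled": True, "ExpireAfter": "T30m"}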