
Commit 1c2875b

modify integ test
1 parent e1d0d45 commit 1c2875b

File tree

1 file changed: +173, -45 lines changed

tests/integ/test_workflow.py

Lines changed: 173 additions & 45 deletions
@@ -44,12 +44,7 @@
     ParameterInteger,
     ParameterString,
 )
-from sagemaker.workflow.steps import (
-    CreateModelStep,
-    ProcessingStep,
-    TrainingStep,
-    CacheConfig
-)
+from sagemaker.workflow.steps import CreateModelStep, ProcessingStep, TrainingStep, CacheConfig
 from sagemaker.workflow.step_collections import RegisterModel
 from sagemaker.workflow.pipeline import Pipeline
 from tests.integ import DATA_DIR
@@ -554,76 +549,209 @@ def test_training_job_with_debugger(
     pass


-def test_cache_hit_expired_entry(
-    sagemaker_session,
-    workflow_session,
-    region_name,
-    role,
-    script_dir,
-    pipeline_name,
+def test_cache_hit(
+    sagemaker_session,
+    workflow_session,
+    region_name,
+    role,
+    script_dir,
+    pipeline_name,
+    athena_dataset_definition,
 ):

+    cache_config = CacheConfig(enable_caching=True, expire_after="T30m")
+
+    framework_version = "0.20.0"
     instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
     instance_count = ParameterInteger(name="InstanceCount", default_value=1)

+    input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"

-    estimator =
+    sklearn_processor = SKLearnProcessor(
+        framework_version=framework_version,
+        instance_type=instance_type,
+        instance_count=instance_count,
+        base_job_name="test-sklearn",
+        sagemaker_session=sagemaker_session,
+        role=role,
+    )

-    step_train = TrainingStep(
-        name="my-train",
-        estimator=sklearn_train,
-        inputs=TrainingInput(
-            s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
-                "train_data"
-            ].S3Output.S3Uri
-        ),
-        cache_config=
+    step_process = ProcessingStep(
+        name="my-cache-test",
+        processor=sklearn_processor,
+        inputs=[
+            ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
+            ProcessingInput(dataset_definition=athena_dataset_definition),
+        ],
+        outputs=[
+            ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
+            ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"),
+        ],
+        code=os.path.join(script_dir, "preprocessing.py"),
+        cache_config=cache_config,
     )
+
     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[instance_type, instance_count],
-        steps=[step_train],
+        parameters=[instance_count, instance_type],
+        steps=[step_process],
         sagemaker_session=workflow_session,
     )

     try:
-        # NOTE: We should exercise the case when role used in the pipeline execution is
-        # different than that required of the steps in the pipeline itself. The role in
-        # the pipeline definition needs to create training and processing jobs and other
-        # sagemaker entities. However, the jobs created in the steps themselves execute
-        # under a potentially different role, often requiring access to S3 and other
-        # artifacts not required to during creation of the jobs in the pipeline steps.
         response = pipeline.create(role)
         create_arn = response["PipelineArn"]
+        pytest.set_trace()
+
         assert re.match(
-            fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}",
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
             create_arn,
         )

-        pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
-        response = pipeline.update(role)
-        update_arn = response["PipelineArn"]
+        # Run pipeline for the first time to get an entry in the cache
+        execution1 = pipeline.start(parameters={})
         assert re.match(
-            fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}",
-            update_arn,
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
+            execution1.arn,
         )

-        execution = pipeline.start(parameters={})
+        response = execution1.describe()
+        assert response["PipelineArn"] == create_arn
+
+        try:
+            execution1.wait(delay=30, max_attempts=10)
+        except WaiterError:
+            pass
+        execution1_steps = execution1.list_steps()
+        assert len(execution1_steps) == 1
+        assert execution1_steps[0]["StepName"] == "my-cache-test"
+
+        # Run pipeline for the second time and expect cache hit
+        execution2 = pipeline.start(parameters={})
         assert re.match(
-            fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}/execution/",
-            execution.arn,
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
+            execution2.arn,
         )

-        response = execution.describe()
+        response = execution2.describe()
         assert response["PipelineArn"] == create_arn

         try:
-            execution.wait(delay=30, max_attempts=3)
+            execution2.wait(delay=30, max_attempts=10)
         except WaiterError:
             pass
-        execution_steps = execution.list_steps()
-        assert len(execution_steps) == 1
-        assert execution_steps[0]["StepName"] == "sklearn-process"
+        execution2_steps = execution2.list_steps()
+        assert len(execution2_steps) == 1
+        assert execution2_steps[0]["StepName"] == "my-cache-test"
+
+        assert execution1_steps[0] == execution2_steps[0]
+
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
+def test_cache_expiry(
+    sagemaker_session,
+    workflow_session,
+    region_name,
+    role,
+    script_dir,
+    pipeline_name,
+    athena_dataset_definition,
+):
+
+    cache_config = CacheConfig(enable_caching=True, expire_after="T1m")
+
+    framework_version = "0.20.0"
+    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
+
+    input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"
+
+    sklearn_processor = SKLearnProcessor(
+        framework_version=framework_version,
+        instance_type=instance_type,
+        instance_count=instance_count,
+        base_job_name="test-sklearn",
+        sagemaker_session=sagemaker_session,
+        role=role,
+    )
+
+    step_process = ProcessingStep(
+        name="my-cache-test-expiry",
+        processor=sklearn_processor,
+        inputs=[
+            ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
+            ProcessingInput(dataset_definition=athena_dataset_definition),
+        ],
+        outputs=[
+            ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
+            ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"),
+        ],
+        code=os.path.join(script_dir, "preprocessing.py"),
+        cache_config=cache_config,
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[instance_count, instance_type],
+        steps=[step_process],
+        sagemaker_session=workflow_session,
+    )
+
+    try:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+
+        assert re.match(
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
+        )
+
+        # Run pipeline for the first time to get an entry in the cache
+        execution1 = pipeline.start(parameters={})
+        assert re.match(
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
+            execution1.arn,
+        )
+
+        response = execution1.describe()
+        assert response["PipelineArn"] == create_arn
+
+        try:
+            execution1.wait(delay=30, max_attempts=3)
+        except WaiterError:
+            pass
+        execution1_steps = execution1.list_steps()
+        assert len(execution1_steps) == 1
+        assert execution1_steps[0]["StepName"] == "my-cache-test-expiry"
+
+        # wait 1 minute for cache to expire
+        time.sleep(60)
+
+        # Run pipeline for the second time and expect cache miss
+        execution2 = pipeline.start(parameters={})
+        assert re.match(
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
+            execution2.arn,
+        )
+
+        response = execution2.describe()
+        assert response["PipelineArn"] == create_arn
+
+        try:
+            execution2.wait(delay=30, max_attempts=3)
+        except WaiterError:
+            pass
+        execution2_steps = execution2.list_steps()
+        assert len(execution2_steps) == 1
+        assert execution2_steps[0]["StepName"] == "my-cache-test-expiry"
+
+        assert execution1_steps[0] != execution2_steps[0]
+
     finally:
         try:
             pipeline.delete()
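
For readers of this diff, a minimal sketch (not part of the commit) of the step-caching pattern the new tests exercise, using only the APIs that appear above (CacheConfig, SKLearnProcessor, ProcessingStep, Pipeline). The build_cached_pipeline helper, its arguments, and the single input/output wiring are illustrative assumptions, not code from the test file or the SDK.

import os

from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.workflow.parameters import ParameterInteger, ParameterString
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.steps import CacheConfig, ProcessingStep


def build_cached_pipeline(pipeline_name, role, sagemaker_session, workflow_session,
                          input_data, script_dir, expire_after="T30m"):
    # Caching must be enabled per step; expire_after is a duration string
    # (the tests use "T30m" for 30 minutes and "T1m" for 1 minute) that bounds
    # how long a previous step execution can be reused.
    cache_config = CacheConfig(enable_caching=True, expire_after=expire_after)

    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
    instance_count = ParameterInteger(name="InstanceCount", default_value=1)

    processor = SKLearnProcessor(
        framework_version="0.20.0",
        instance_type=instance_type,
        instance_count=instance_count,
        base_job_name="test-sklearn",
        sagemaker_session=sagemaker_session,
        role=role,
    )

    step_process = ProcessingStep(
        name="my-cache-test",
        processor=processor,
        inputs=[ProcessingInput(source=input_data, destination="/opt/ml/processing/input")],
        outputs=[ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train")],
        code=os.path.join(script_dir, "preprocessing.py"),
        # Identical step inputs within the expire_after window reuse the prior result.
        cache_config=cache_config,
    )

    return Pipeline(
        name=pipeline_name,
        parameters=[instance_count, instance_type],
        steps=[step_process],
        sagemaker_session=workflow_session,
    )

Started twice with identical step inputs inside the expire_after window, the second execution should reuse the first step's result, which is what test_cache_hit asserts; started again only after the window has elapsed, the step runs anew, which is what test_cache_expiry asserts.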
