Skip to content

Commit a0d3d7b

Browse files
committed
Modify unit tests to set enable_caching to True
1 parent 1c2875b commit a0d3d7b

File tree

2 files changed: 14 additions and 216 deletions

2 files changed

+14
-216
lines changed

tests/integ/test_workflow.py

Lines changed: 8 additions & 210 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,8 @@ def test_one_step_sklearn_processing_pipeline(
289289
ProcessingInput(dataset_definition=athena_dataset_definition),
290290
]
291291

292+
cache_config = CacheConfig(enable_caching=True, expire_after="T30m")
293+
292294
sklearn_processor = SKLearnProcessor(
293295
framework_version=sklearn_latest_version,
294296
role=role,
@@ -304,6 +306,7 @@ def test_one_step_sklearn_processing_pipeline(
304306
processor=sklearn_processor,
305307
inputs=inputs,
306308
code=script_path,
309+
cache_config=cache_config,
307310
)
308311
pipeline = Pipeline(
309312
name=pipeline_name,
@@ -343,6 +346,11 @@ def test_one_step_sklearn_processing_pipeline(
343346
response = execution.describe()
344347
assert response["PipelineArn"] == create_arn
345348

349+
# Check CacheConfig
350+
response = json.loads(pipeline.describe()["PipelineDefinition"])["Steps"][0]["CacheConfig"]
351+
assert response["Enabled"] == cache_config.enable_caching
352+
assert response["ExpireAfter"] == cache_config.expire_after
353+
346354
try:
347355
execution.wait(delay=30, max_attempts=3)
348356
except WaiterError:
@@ -547,213 +555,3 @@ def test_training_job_with_debugger(
547555
pipeline.delete()
548556
except Exception:
549557
pass
550-
551-
552-
def test_cache_hit(
553-
sagemaker_session,
554-
workflow_session,
555-
region_name,
556-
role,
557-
script_dir,
558-
pipeline_name,
559-
athena_dataset_definition,
560-
):
561-
562-
cache_config = CacheConfig(enable_caching=True, expire_after="T30m")
563-
564-
framework_version = "0.20.0"
565-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
566-
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
567-
568-
input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"
569-
570-
sklearn_processor = SKLearnProcessor(
571-
framework_version=framework_version,
572-
instance_type=instance_type,
573-
instance_count=instance_count,
574-
base_job_name="test-sklearn",
575-
sagemaker_session=sagemaker_session,
576-
role=role,
577-
)
578-
579-
step_process = ProcessingStep(
580-
name="my-cache-test",
581-
processor=sklearn_processor,
582-
inputs=[
583-
ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
584-
ProcessingInput(dataset_definition=athena_dataset_definition),
585-
],
586-
outputs=[
587-
ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
588-
ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"),
589-
],
590-
code=os.path.join(script_dir, "preprocessing.py"),
591-
cache_config=cache_config,
592-
)
593-
594-
pipeline = Pipeline(
595-
name=pipeline_name,
596-
parameters=[instance_count, instance_type],
597-
steps=[step_process],
598-
sagemaker_session=workflow_session,
599-
)
600-
601-
try:
602-
response = pipeline.create(role)
603-
create_arn = response["PipelineArn"]
604-
pytest.set_trace()
605-
606-
assert re.match(
607-
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
608-
create_arn,
609-
)
610-
611-
# Run pipeline for the first time to get an entry in the cache
612-
execution1 = pipeline.start(parameters={})
613-
assert re.match(
614-
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
615-
execution1.arn,
616-
)
617-
618-
response = execution1.describe()
619-
assert response["PipelineArn"] == create_arn
620-
621-
try:
622-
execution1.wait(delay=30, max_attempts=10)
623-
except WaiterError:
624-
pass
625-
execution1_steps = execution1.list_steps()
626-
assert len(execution1_steps) == 1
627-
assert execution1_steps[0]["StepName"] == "my-cache-test"
628-
629-
# Run pipeline for the second time and expect cache hit
630-
execution2 = pipeline.start(parameters={})
631-
assert re.match(
632-
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
633-
execution2.arn,
634-
)
635-
636-
response = execution2.describe()
637-
assert response["PipelineArn"] == create_arn
638-
639-
try:
640-
execution2.wait(delay=30, max_attempts=10)
641-
except WaiterError:
642-
pass
643-
execution2_steps = execution2.list_steps()
644-
assert len(execution2_steps) == 1
645-
assert execution2_steps[0]["StepName"] == "my-cache-test"
646-
647-
assert execution1_steps[0] == execution2_steps[0]
648-
649-
finally:
650-
try:
651-
pipeline.delete()
652-
except Exception:
653-
pass
654-
655-
656-
def test_cache_expiry(
657-
sagemaker_session,
658-
workflow_session,
659-
region_name,
660-
role,
661-
script_dir,
662-
pipeline_name,
663-
athena_dataset_definition,
664-
):
665-
666-
cache_config = CacheConfig(enable_caching=True, expire_after="T1m")
667-
668-
framework_version = "0.20.0"
669-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
670-
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
671-
672-
input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"
673-
674-
sklearn_processor = SKLearnProcessor(
675-
framework_version=framework_version,
676-
instance_type=instance_type,
677-
instance_count=instance_count,
678-
base_job_name="test-sklearn",
679-
sagemaker_session=sagemaker_session,
680-
role=role,
681-
)
682-
683-
step_process = ProcessingStep(
684-
name="my-cache-test-expiry",
685-
processor=sklearn_processor,
686-
inputs=[
687-
ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
688-
ProcessingInput(dataset_definition=athena_dataset_definition),
689-
],
690-
outputs=[
691-
ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
692-
ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"),
693-
],
694-
code=os.path.join(script_dir, "preprocessing.py"),
695-
cache_config=cache_config,
696-
)
697-
698-
pipeline = Pipeline(
699-
name=pipeline_name,
700-
parameters=[instance_count, instance_type],
701-
steps=[step_process],
702-
sagemaker_session=workflow_session,
703-
)
704-
705-
try:
706-
response = pipeline.create(role)
707-
create_arn = response["PipelineArn"]
708-
709-
assert re.match(
710-
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
711-
create_arn,
712-
)
713-
714-
# Run pipeline for the first time to get an entry in the cache
715-
execution1 = pipeline.start(parameters={})
716-
assert re.match(
717-
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
718-
execution1.arn,
719-
)
720-
721-
response = execution1.describe()
722-
assert response["PipelineArn"] == create_arn
723-
724-
try:
725-
execution1.wait(delay=30, max_attempts=3)
726-
except WaiterError:
727-
pass
728-
execution1_steps = execution1.list_steps()
729-
assert len(execution1_steps) == 1
730-
assert execution1_steps[0]["StepName"] == "my-cache-test-expiry"
731-
732-
# wait 1 minute for cache to expire
733-
time.sleep(60)
734-
735-
# Run pipeline for the second time and expect cache miss
736-
execution2 = pipeline.start(parameters={})
737-
assert re.match(
738-
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
739-
execution2.arn,
740-
)
741-
742-
response = execution2.describe()
743-
assert response["PipelineArn"] == create_arn
744-
745-
try:
746-
execution2.wait(delay=30, max_attempts=3)
747-
except WaiterError:
748-
pass
749-
execution2_steps = execution2.list_steps()
750-
assert len(execution2_steps) == 1
751-
assert execution2_steps[0]["StepName"] == "my-cache-test-expiry"
752-
753-
assert execution1_steps[0] != execution2_steps[0]
754-
755-
finally:
756-
try:
757-
pipeline.delete()
758-
except Exception:
759-
pass

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_training_step(sagemaker_session):
115115
sagemaker_session=sagemaker_session,
116116
)
117117
inputs = TrainingInput(f"s3://{BUCKET}/train_manifest")
118-
cache_config = CacheConfig(enable_caching=False, expire_after="PT1H")
118+
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
119119
step = TrainingStep(
120120
name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
121121
)
@@ -145,7 +145,7 @@ def test_training_step(sagemaker_session):
145145
"RoleArn": ROLE,
146146
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
147147
},
148-
"CacheConfig": {"Enabled": False, "ExpireAfter": "PT1H"},
148+
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
149149
}
150150
assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
151151

@@ -164,7 +164,7 @@ def test_processing_step(sagemaker_session):
164164
destination="processing_manifest",
165165
)
166166
]
167-
cache_config = CacheConfig(enable_caching=False, expire_after="PT1H")
167+
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
168168
step = ProcessingStep(
169169
name="MyProcessingStep",
170170
processor=processor,
@@ -200,7 +200,7 @@ def test_processing_step(sagemaker_session):
200200
},
201201
"RoleArn": "DummyRole",
202202
},
203-
"CacheConfig": {"Enabled": False, "ExpireAfter": "PT1H"},
203+
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
204204
}
205205
assert step.properties.ProcessingJobName.expr == {
206206
"Get": "Steps.MyProcessingStep.ProcessingJobName"
@@ -242,7 +242,7 @@ def test_transform_step(sagemaker_session):
242242
sagemaker_session=sagemaker_session,
243243
)
244244
inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest")
245-
cache_config = CacheConfig(enable_caching=False, expire_after="PT1H")
245+
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
246246
step = TransformStep(
247247
name="MyTransformStep", transformer=transformer, inputs=inputs, cache_config=cache_config
248248
)
@@ -265,7 +265,7 @@ def test_transform_step(sagemaker_session):
265265
"InstanceType": "c4.4xlarge",
266266
},
267267
},
268-
"CacheConfig": {"Enabled": False, "ExpireAfter": "PT1H"},
268+
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
269269
}
270270
assert step.properties.TransformJobName.expr == {
271271
"Get": "Steps.MyTransformStep.TransformJobName"

0 commit comments

Comments
 (0)