|
44 | 44 | ParameterInteger,
|
45 | 45 | ParameterString,
|
46 | 46 | )
|
47 |
| -from sagemaker.workflow.steps import ( |
48 |
| - CreateModelStep, |
49 |
| - ProcessingStep, |
50 |
| - TrainingStep, |
51 |
| - CacheConfig |
52 |
| -) |
| 47 | +from sagemaker.workflow.steps import CreateModelStep, ProcessingStep, TrainingStep, CacheConfig |
53 | 48 | from sagemaker.workflow.step_collections import RegisterModel
|
54 | 49 | from sagemaker.workflow.pipeline import Pipeline
|
55 | 50 | from tests.integ import DATA_DIR
|
@@ -554,76 +549,209 @@ def test_training_job_with_debugger(
|
554 | 549 | pass
|
555 | 550 |
|
556 | 551 |
|
557 |
def test_cache_hit(
    sagemaker_session,
    workflow_session,
    region_name,
    role,
    script_dir,
    pipeline_name,
    athena_dataset_definition,
):
    """Run the same processing pipeline twice and verify the second run hits the step cache.

    The single processing step is configured with caching enabled and a
    30-minute expiry window, so the second execution should reuse the cached
    step result. The test asserts both executions report exactly one step
    with the same name and an identical step entry.
    """
    # Cache entries stay valid for 30 minutes ("T30m" duration format).
    cache_config = CacheConfig(enable_caching=True, expire_after="T30m")

    framework_version = "0.20.0"
    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
    instance_count = ParameterInteger(name="InstanceCount", default_value=1)

    input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"

    sklearn_processor = SKLearnProcessor(
        framework_version=framework_version,
        instance_type=instance_type,
        instance_count=instance_count,
        base_job_name="test-sklearn",
        sagemaker_session=sagemaker_session,
        role=role,
    )

    step_process = ProcessingStep(
        name="my-cache-test",
        processor=sklearn_processor,
        inputs=[
            ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
            ProcessingInput(dataset_definition=athena_dataset_definition),
        ],
        outputs=[
            ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
            ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"),
        ],
        code=os.path.join(script_dir, "preprocessing.py"),
        cache_config=cache_config,
    )

    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_count, instance_type],
        steps=[step_process],
        sagemaker_session=workflow_session,
    )

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]
        # FIX: removed a leftover `pytest.set_trace()` debugger breakpoint here;
        # it would block forever on unattended/CI runs waiting for pdb input.
        assert re.match(
            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            create_arn,
        )

        # Run pipeline for the first time to get an entry in the cache.
        execution1 = pipeline.start(parameters={})
        assert re.match(
            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
            execution1.arn,
        )

        response = execution1.describe()
        assert response["PipelineArn"] == create_arn

        try:
            execution1.wait(delay=30, max_attempts=10)
        except WaiterError:
            # Best-effort wait: a waiter timeout still lets us inspect the steps.
            pass
        execution1_steps = execution1.list_steps()
        assert len(execution1_steps) == 1
        assert execution1_steps[0]["StepName"] == "my-cache-test"

        # Run pipeline for the second time and expect a cache hit.
        execution2 = pipeline.start(parameters={})
        assert re.match(
            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
            execution2.arn,
        )

        response = execution2.describe()
        assert response["PipelineArn"] == create_arn

        try:
            execution2.wait(delay=30, max_attempts=10)
        except WaiterError:
            pass
        execution2_steps = execution2.list_steps()
        assert len(execution2_steps) == 1
        assert execution2_steps[0]["StepName"] == "my-cache-test"

        # A cache hit reuses the first execution's step entry verbatim.
        assert execution1_steps[0] == execution2_steps[0]

    finally:
        try:
            pipeline.delete()
        except Exception:
            # Cleanup is best-effort; the pipeline may never have been created.
            pass
| 655 | + |
| 656 | +def test_cache_expiry( |
| 657 | + sagemaker_session, |
| 658 | + workflow_session, |
| 659 | + region_name, |
| 660 | + role, |
| 661 | + script_dir, |
| 662 | + pipeline_name, |
| 663 | + athena_dataset_definition, |
| 664 | +): |
| 665 | + |
| 666 | + cache_config = CacheConfig(enable_caching=True, expire_after="T1m") |
| 667 | + |
| 668 | + framework_version = "0.20.0" |
| 669 | + instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") |
| 670 | + instance_count = ParameterInteger(name="InstanceCount", default_value=1) |
| 671 | + |
| 672 | + input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv" |
| 673 | + |
| 674 | + sklearn_processor = SKLearnProcessor( |
| 675 | + framework_version=framework_version, |
| 676 | + instance_type=instance_type, |
| 677 | + instance_count=instance_count, |
| 678 | + base_job_name="test-sklearn", |
| 679 | + sagemaker_session=sagemaker_session, |
| 680 | + role=role, |
| 681 | + ) |
| 682 | + |
| 683 | + step_process = ProcessingStep( |
| 684 | + name="my-cache-test-expiry", |
| 685 | + processor=sklearn_processor, |
| 686 | + inputs=[ |
| 687 | + ProcessingInput(source=input_data, destination="/opt/ml/processing/input"), |
| 688 | + ProcessingInput(dataset_definition=athena_dataset_definition), |
| 689 | + ], |
| 690 | + outputs=[ |
| 691 | + ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"), |
| 692 | + ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"), |
| 693 | + ], |
| 694 | + code=os.path.join(script_dir, "preprocessing.py"), |
| 695 | + cache_config=cache_config, |
| 696 | + ) |
| 697 | + |
| 698 | + pipeline = Pipeline( |
| 699 | + name=pipeline_name, |
| 700 | + parameters=[instance_count, instance_type], |
| 701 | + steps=[step_process], |
| 702 | + sagemaker_session=workflow_session, |
| 703 | + ) |
| 704 | + |
| 705 | + try: |
| 706 | + response = pipeline.create(role) |
| 707 | + create_arn = response["PipelineArn"] |
| 708 | + |
| 709 | + assert re.match( |
| 710 | + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", |
| 711 | + create_arn, |
| 712 | + ) |
| 713 | + |
| 714 | + # Run pipeline for the first time to get an entry in the cache |
| 715 | + execution1 = pipeline.start(parameters={}) |
| 716 | + assert re.match( |
| 717 | + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/", |
| 718 | + execution1.arn, |
| 719 | + ) |
| 720 | + |
| 721 | + response = execution1.describe() |
| 722 | + assert response["PipelineArn"] == create_arn |
| 723 | + |
| 724 | + try: |
| 725 | + execution1.wait(delay=30, max_attempts=3) |
| 726 | + except WaiterError: |
| 727 | + pass |
| 728 | + execution1_steps = execution1.list_steps() |
| 729 | + assert len(execution1_steps) == 1 |
| 730 | + assert execution1_steps[0]["StepName"] == "my-cache-test-expiry" |
| 731 | + |
| 732 | + # wait 1 minute for cache to expire |
| 733 | + time.sleep(60) |
| 734 | + |
| 735 | + # Run pipeline for the second time and expect cache miss |
| 736 | + execution2 = pipeline.start(parameters={}) |
| 737 | + assert re.match( |
| 738 | + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/", |
| 739 | + execution2.arn, |
| 740 | + ) |
| 741 | + |
| 742 | + response = execution2.describe() |
| 743 | + assert response["PipelineArn"] == create_arn |
| 744 | + |
| 745 | + try: |
| 746 | + execution2.wait(delay=30, max_attempts=3) |
| 747 | + except WaiterError: |
| 748 | + pass |
| 749 | + execution2_steps = execution2.list_steps() |
| 750 | + assert len(execution2_steps) == 1 |
| 751 | + assert execution2_steps[0]["StepName"] == "my-cache-test-expiry" |
| 752 | + |
| 753 | + assert execution1_steps[0] != execution2_steps[0] |
| 754 | + |
627 | 755 | finally:
|
628 | 756 | try:
|
629 | 757 | pipeline.delete()
|
|
0 commit comments