@@ -244,6 +244,7 @@ def test_to_pipeline(
         step=wrapped_func,
         role=EXECUTION_ROLE_ARN,
         max_retries=1,
+        tags=[("tag_key_1", "tag_value_1"), ("tag_key_2", "tag_value_2")],
         sagemaker_session=session,
     )
     assert pipeline_arn == PIPELINE_ARN
@@ -346,7 +347,11 @@ def test_to_pipeline(
         [
             call(
                 role_arn=EXECUTION_ROLE_ARN,
-                tags=[dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE)],
+                tags=[
+                    dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE),
+                    dict(Key="tag_key_1", Value="tag_value_1"),
+                    dict(Key="tag_key_2", Value="tag_value_2"),
+                ],
             ),
             call(
                 role_arn=EXECUTION_ROLE_ARN,
@@ -527,6 +532,128 @@ def test_to_pipeline_pipeline_name_length_limit_exceeds(
     )
 
 
+@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
+@patch(
+    "sagemaker.remote_function.job._JobSettings._get_default_spark_image",
+    return_value="some_image_uri",
+)
+@patch("sagemaker.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN)
+def test_to_pipeline_too_many_tags(get_execution_role, mock_spark_image, session):
+    session.sagemaker_config = None
+    session.boto_region_name = TEST_REGION
+    session.expand_role.return_value = EXECUTION_ROLE_ARN
+    spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"])
+    job_settings = _JobSettings(
+        spark_config=spark_config,
+        s3_root_uri=S3_URI,
+        role=EXECUTION_ROLE_ARN,
+        include_local_workdir=True,
+        instance_type="ml.m5.large",
+        encrypt_inter_container_traffic=True,
+        sagemaker_session=session,
+    )
+    jobs_container_entrypoint = [
+        "/bin/bash",
+        f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}",
+    ]
+    jobs_container_entrypoint.extend(["--jars", "path_a"])
+    jobs_container_entrypoint.extend(["--py-files", "path_b"])
+    jobs_container_entrypoint.extend(["--files", "path_c"])
+    jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH])
+    container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"]
+    container_args.extend(["--region", session.boto_region_name])
+
+    mock_feature_processor_config = Mock(
+        mode=FeatureProcessorMode.PYSPARK, inputs=[tdh.FEATURE_PROCESSOR_INPUTS], output="some_fg"
+    )
+    mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYSPARK
+
+    wrapped_func = Mock(
+        Callable,
+        feature_processor_config=mock_feature_processor_config,
+        job_settings=job_settings,
+        wrapped_func=job_function,
+    )
+    wrapped_func.feature_processor_config.return_value = mock_feature_processor_config
+    wrapped_func.job_settings.return_value = job_settings
+    wrapped_func.wrapped_func.return_value = job_function
+
+    tags = [("key_" + str(i), "value_" + str(i)) for i in range(50)]
+
+    with pytest.raises(
+        ValueError,
+        match="to_pipeline can only accept up to 47 tags. Please reduce the number of tags provided.",
+    ):
+        to_pipeline(
+            pipeline_name="pipeline_name",
+            step=wrapped_func,
+            role=EXECUTION_ROLE_ARN,
+            max_retries=1,
+            tags=tags,
+            sagemaker_session=session,
+        )
+
+
+@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
+@patch(
+    "sagemaker.remote_function.job._JobSettings._get_default_spark_image",
+    return_value="some_image_uri",
+)
+@patch("sagemaker.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN)
+def test_to_pipeline_used_reserved_tags(get_execution_role, mock_spark_image, session):
+    session.sagemaker_config = None
+    session.boto_region_name = TEST_REGION
+    session.expand_role.return_value = EXECUTION_ROLE_ARN
+    spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"])
+    job_settings = _JobSettings(
+        spark_config=spark_config,
+        s3_root_uri=S3_URI,
+        role=EXECUTION_ROLE_ARN,
+        include_local_workdir=True,
+        instance_type="ml.m5.large",
+        encrypt_inter_container_traffic=True,
+        sagemaker_session=session,
+    )
+    jobs_container_entrypoint = [
+        "/bin/bash",
+        f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}",
+    ]
+    jobs_container_entrypoint.extend(["--jars", "path_a"])
+    jobs_container_entrypoint.extend(["--py-files", "path_b"])
+    jobs_container_entrypoint.extend(["--files", "path_c"])
+    jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH])
+    container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"]
+    container_args.extend(["--region", session.boto_region_name])
+
+    mock_feature_processor_config = Mock(
+        mode=FeatureProcessorMode.PYSPARK, inputs=[tdh.FEATURE_PROCESSOR_INPUTS], output="some_fg"
+    )
+    mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYSPARK
+
+    wrapped_func = Mock(
+        Callable,
+        feature_processor_config=mock_feature_processor_config,
+        job_settings=job_settings,
+        wrapped_func=job_function,
+    )
+    wrapped_func.feature_processor_config.return_value = mock_feature_processor_config
+    wrapped_func.job_settings.return_value = job_settings
+    wrapped_func.wrapped_func.return_value = job_function
+
+    with pytest.raises(
+        ValueError,
+        match="sm-fs-fe:created-from is a reserved tag key for to_pipeline API. Please choose another tag.",
+    ):
+        to_pipeline(
+            pipeline_name="pipeline_name",
+            step=wrapped_func,
+            role=EXECUTION_ROLE_ARN,
+            max_retries=1,
+            tags=[("sm-fs-fe:created-from", "random")],
+            sagemaker_session=session,
+        )
+
+
 @patch(
     "sagemaker.feature_store.feature_processor.feature_scheduler._validate_pipeline_lineage_resources",
     return_value=None,
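For orientation, here is a minimal sketch of the validation behavior the two new tests pin down. The helper name `_validate_tags` and the constant names are assumptions for illustration, not the SDK's actual internals; the 47-tag limit, the reserved key `sm-fs-fe:created-from`, and the error messages are taken directly from the test expectations above.

```python
# Hypothetical sketch, not the SDK source: the tag validation that
# test_to_pipeline_too_many_tags and test_to_pipeline_used_reserved_tags
# exercise. Names below are illustrative; limit, reserved key, and
# messages mirror the pytest.raises assertions in the diff above.
MAX_TO_PIPELINE_TAGS = 47  # limit asserted by the tests above
TO_PIPELINE_RESERVED_TAG_KEYS = {"sm-fs-fe:created-from"}


def _validate_tags(tags):
    """Raise ValueError for too many tags or use of a reserved tag key."""
    if not tags:
        return
    if len(tags) > MAX_TO_PIPELINE_TAGS:
        raise ValueError(
            f"to_pipeline can only accept up to {MAX_TO_PIPELINE_TAGS} tags. "
            "Please reduce the number of tags provided."
        )
    for key, _value in tags:
        if key in TO_PIPELINE_RESERVED_TAG_KEYS:
            raise ValueError(
                f"{key} is a reserved tag key for to_pipeline API. "
                "Please choose another tag."
            )
```

The first hunks show the happy path: user tags passed as `(key, value)` tuples are appended after the API's own `FEATURE_PROCESSOR_TAG_KEY` tag and forwarded to the pipeline-creation call as `dict(Key=..., Value=...)` entries.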