@@ -596,11 +596,23 @@ def train(
596
596
current_training_job_name = _get_unique_name (self .base_job_name )
597
597
input_data_key_prefix = f"{ self .base_job_name } /{ current_training_job_name } /input"
598
598
599
- self .input_data_config = input_data_config or self .input_data_config or []
599
+ final_input_data_config = self .input_data_config .copy () if self .input_data_config else []
600
+
601
+ if input_data_config :
602
+ # merge the inputs with method parameter taking precedence
603
+ existing_channels = {input .channel_name : input for input in final_input_data_config }
604
+ new_channels = []
605
+ for new_input in input_data_config :
606
+ if new_input .channel_name in existing_channels :
607
+ existing_channels [new_input .channel_name ] = new_input
608
+ else :
609
+ new_channels .append (new_input )
610
+
611
+ final_input_data_config = list (existing_channels .values ()) + new_channels
600
612
601
- if self . input_data_config :
602
- self . input_data_config = self ._get_input_data_config (
603
- self . input_data_config , input_data_key_prefix
613
+ if final_input_data_config :
614
+ final_input_data_config = self ._get_input_data_config (
615
+ final_input_data_config , input_data_key_prefix
604
616
)
605
617
606
618
if self .checkpoint_config and not self .checkpoint_config .s3_uri :
@@ -643,7 +655,7 @@ def train(
643
655
data_source = self .source_code .source_dir ,
644
656
key_prefix = input_data_key_prefix ,
645
657
)
646
- self . input_data_config .append (source_code_channel )
658
+ final_input_data_config .append (source_code_channel )
647
659
648
660
self ._prepare_train_script (
649
661
tmp_dir = tmp_dir ,
@@ -664,7 +676,7 @@ def train(
664
676
data_source = tmp_dir .name ,
665
677
key_prefix = input_data_key_prefix ,
666
678
)
667
- self . input_data_config .append (sm_drivers_channel )
679
+ final_input_data_config .append (sm_drivers_channel )
668
680
669
681
# If source_code is provided, we will always use
670
682
# the default container entrypoint and arguments
@@ -691,7 +703,7 @@ def train(
691
703
training_job_name = current_training_job_name ,
692
704
algorithm_specification = algorithm_specification ,
693
705
hyper_parameters = string_hyper_parameters ,
694
- input_data_config = self . input_data_config ,
706
+ input_data_config = final_input_data_config ,
695
707
resource_config = resource_config ,
696
708
vpc_config = vpc_config ,
697
709
# Public Instance Attributes
@@ -736,7 +748,7 @@ def train(
736
748
sagemaker_session = self .sagemaker_session ,
737
749
container_entrypoint = algorithm_specification .container_entrypoint ,
738
750
container_arguments = algorithm_specification .container_arguments ,
739
- input_data_config = self . input_data_config ,
751
+ input_data_config = final_input_data_config ,
740
752
hyper_parameters = string_hyper_parameters ,
741
753
environment = self .environment ,
742
754
)
0 commit comments