Skip to content

Commit 8e62ab2

Browse files
committed
change: merge method inputs with class inputs
1 parent e697619 commit 8e62ab2

File tree

3 files changed

+63
-10
lines changed

3 files changed

+63
-10
lines changed

src/sagemaker/modules/train/model_trainer.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -596,11 +596,23 @@ def train(
596596
current_training_job_name = _get_unique_name(self.base_job_name)
597597
input_data_key_prefix = f"{self.base_job_name}/{current_training_job_name}/input"
598598

599-
self.input_data_config = input_data_config or self.input_data_config or []
599+
final_input_data_config = self.input_data_config.copy() if self.input_data_config else []
600+
601+
if input_data_config:
602+
# merge the inputs with method parameter taking precedence
603+
existing_channels = {input.channel_name: input for input in final_input_data_config}
604+
new_channels = []
605+
for new_input in input_data_config:
606+
if new_input.channel_name in existing_channels:
607+
existing_channels[new_input.channel_name] = new_input
608+
else:
609+
new_channels.append(new_input)
610+
611+
final_input_data_config = list(existing_channels.values()) + new_channels
600612

601-
if self.input_data_config:
602-
self.input_data_config = self._get_input_data_config(
603-
self.input_data_config, input_data_key_prefix
613+
if final_input_data_config:
614+
final_input_data_config = self._get_input_data_config(
615+
final_input_data_config, input_data_key_prefix
604616
)
605617

606618
if self.checkpoint_config and not self.checkpoint_config.s3_uri:
@@ -643,7 +655,7 @@ def train(
643655
data_source=self.source_code.source_dir,
644656
key_prefix=input_data_key_prefix,
645657
)
646-
self.input_data_config.append(source_code_channel)
658+
final_input_data_config.append(source_code_channel)
647659

648660
self._prepare_train_script(
649661
tmp_dir=tmp_dir,
@@ -664,7 +676,7 @@ def train(
664676
data_source=tmp_dir.name,
665677
key_prefix=input_data_key_prefix,
666678
)
667-
self.input_data_config.append(sm_drivers_channel)
679+
final_input_data_config.append(sm_drivers_channel)
668680

669681
# If source_code is provided, we will always use
670682
# the default container entrypoint and arguments
@@ -691,7 +703,7 @@ def train(
691703
training_job_name=current_training_job_name,
692704
algorithm_specification=algorithm_specification,
693705
hyper_parameters=string_hyper_parameters,
694-
input_data_config=self.input_data_config,
706+
input_data_config=final_input_data_config,
695707
resource_config=resource_config,
696708
vpc_config=vpc_config,
697709
# Public Instance Attributes
@@ -736,7 +748,7 @@ def train(
736748
sagemaker_session=self.sagemaker_session,
737749
container_entrypoint=algorithm_specification.container_entrypoint,
738750
container_arguments=algorithm_specification.container_arguments,
739-
input_data_config=self.input_data_config,
751+
input_data_config=final_input_data_config,
740752
hyper_parameters=string_hyper_parameters,
741753
environment=self.environment,
742754
)

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,3 +1258,44 @@ def mock_upload_data(path, bucket, key_prefix):
12581258

12591259
assert kwargs["tensor_board_output_config"].s3_output_path == default_base_path
12601260
assert kwargs["tensor_board_output_config"].local_path == "/opt/ml/output/tensorboard"
1261+
1262+
1263+
@patch("sagemaker.modules.train.model_trainer.TrainingJob")
1264+
def test_input_merge(mock_training_job, modules_session):
1265+
model_input = InputData(channel_name="model", data_source="s3://bucket/model/model.tar.gz")
1266+
model_trainer = ModelTrainer(
1267+
training_image=DEFAULT_IMAGE,
1268+
role=DEFAULT_ROLE,
1269+
sagemaker_session=modules_session,
1270+
compute=DEFAULT_COMPUTE_CONFIG,
1271+
input_data_config=[model_input],
1272+
)
1273+
1274+
train_input = InputData(channel_name="train", data_source="s3://bucket/data/train")
1275+
model_trainer.train(input_data_config=[train_input])
1276+
1277+
mock_training_job.create.assert_called_once()
1278+
assert mock_training_job.create.call_args.kwargs["input_data_config"] == [
1279+
Channel(
1280+
channel_name="model",
1281+
data_source=DataSource(
1282+
s3_data_source=S3DataSource(
1283+
s3_data_type="S3Prefix",
1284+
s3_uri="s3://bucket/model/model.tar.gz",
1285+
s3_data_distribution_type="FullyReplicated",
1286+
)
1287+
),
1288+
input_mode="File",
1289+
),
1290+
Channel(
1291+
channel_name="train",
1292+
data_source=DataSource(
1293+
s3_data_source=S3DataSource(
1294+
s3_data_type="S3Prefix",
1295+
s3_uri="s3://bucket/data/train",
1296+
s3_data_distribution_type="FullyReplicated",
1297+
)
1298+
),
1299+
input_mode="File",
1300+
),
1301+
]

tox.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ passenv =
8383
commands =
8484
python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')"
8585
pip install 'apache-airflow==2.10.4' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.4/constraints-3.9.txt"
86-
pip install 'torch==2.3.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html'
87-
pip install 'torchvision==0.18.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html'
86+
pip install 'torch==2.3.1' -f 'https://download.pytorch.org/whl/torch_stable.html'
87+
pip install 'torchvision==0.18.1' -f 'https://download.pytorch.org/whl/torch_stable.html'
8888
pip install 'dill>=0.3.9'
8989

9090
pytest {posargs}

0 commit comments

Comments
 (0)