Skip to content

Commit b898f76

Browse files
committed
fix: make configs safer
1 parent 13ad978 commit b898f76

File tree

4 files changed

+26
-5
lines changed

4 files changed

+26
-5
lines changed

src/sagemaker/modules/configs.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
from __future__ import absolute_import
2323

2424
from typing import Optional, Union
25-
from pydantic import BaseModel, model_validator
25+
from pydantic import BaseModel, model_validator, ConfigDict
2626

2727
import sagemaker_core.shapes as shapes
2828

@@ -94,6 +94,8 @@ class SourceCode(BaseModel):
9494
If not specified, entry_script must be provided.
9595
"""
9696

97+
model_config = ConfigDict(validate_assignment=True, extra="forbid")
98+
9799
source_dir: Optional[str] = None
98100
requirements: Optional[str] = None
99101
entry_script: Optional[str] = None
@@ -215,5 +217,7 @@ class InputData(BaseModel):
215217
S3DataSource object, or FileSystemDataSource object.
216218
"""
217219

220+
model_config = ConfigDict(validate_assignment=True, extra="forbid")
221+
218222
channel_name: str = None
219223
data_source: Union[str, FileSystemDataSource, S3DataSource] = None

src/sagemaker/modules/distributed.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from typing import Optional, Dict, Any, List
17-
from pydantic import BaseModel, PrivateAttr
17+
from pydantic import BaseModel, PrivateAttr, ConfigDict
1818
from sagemaker.modules.utils import safe_serialize
1919

2020

@@ -53,6 +53,8 @@ class SMP(BaseModel):
5353
parallelism or expert parallelism.
5454
"""
5555

56+
model_config = ConfigDict(validate_assignment=True, extra="forbid")
57+
5658
hybrid_shard_degree: Optional[int] = None
5759
sm_activation_offloading: Optional[bool] = None
5860
activation_loading_horizon: Optional[int] = None
@@ -75,6 +77,8 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]:
7577
class DistributedConfig(BaseModel):
7678
"""Base class for distributed training configurations."""
7779

80+
model_config = ConfigDict(validate_assignment=True, extra="forbid")
81+
7882
_type: str = PrivateAttr()
7983

8084
def model_dump(self, *args, **kwargs):

src/sagemaker/modules/train/model_trainer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -792,14 +792,14 @@ def _prepare_train_script(
792792
"""Prepare the training script to be executed in the training job container.
793793
794794
Args:
795-
source_code (SourceCodeConfig): The source code configuration.
795+
source_code (SourceCode): The source code configuration.
796796
"""
797797

798798
base_command = ""
799799
if source_code.command:
800800
if source_code.entry_script:
801801
logger.warning(
802-
"Both 'command' and 'entry_script' are provided in the SourceCodeConfig. "
802+
"Both 'command' and 'entry_script' are provided in the SourceCode. "
803803
+ "Defaulting to 'command'."
804804
)
805805
base_command = source_code.command.split()
@@ -831,6 +831,10 @@ def _prepare_train_script(
831831
+ "Only .py and .sh scripts are supported."
832832
)
833833
execute_driver = EXECUTE_BASIC_SCRIPT_DRIVER
834+
else:
835+
raise ValueError(
836+
f"Invalid configuration, please provide a valid SourceCode: {source_code}"
837+
)
834838

835839
train_script = TRAIN_SCRIPT_TEMPLATE.format(
836840
working_dir=working_dir,

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -438,7 +438,7 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_
438438
{
439439
"source_code": DEFAULT_SOURCE_CODE,
440440
"distributed": MPI(
441-
custom_mpi_options=["-x", "VAR1", "-x", "VAR2"],
441+
mpi_additional_options=["-x", "VAR1", "-x", "VAR2"],
442442
),
443443
"expected_template": EXECUTE_MPI_DRIVER,
444444
"expected_hyperparameters": {},
@@ -1059,3 +1059,12 @@ def mock_upload_data(path, bucket, key_prefix):
10591059
hyper_parameters=hyperparameters,
10601060
environment=environment,
10611061
)
1062+
1063+
1064+
def test_safe_configs():
1065+
# Test extra fails
1066+
with pytest.raises(ValueError):
1067+
SourceCode(entry_point="train.py")
1068+
# Test invalid type fails
1069+
with pytest.raises(ValueError):
1070+
SourceCode(entry_script=1)

0 commit comments

Comments (0)