Skip to content

fix: forbid extras in Configs #5042

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/sagemaker/modules/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from __future__ import absolute_import

from typing import Optional, Union
from pydantic import BaseModel, model_validator
from pydantic import BaseModel, model_validator, ConfigDict

import sagemaker_core.shapes as shapes

Expand Down Expand Up @@ -74,7 +74,13 @@
]


class SourceCode(BaseModel):
class BaseConfig(BaseModel):
    """Common pydantic base for the configuration models in this module.

    The shared ``model_config`` makes every subclass strict:

    * ``extra="forbid"`` - unknown keyword arguments are rejected instead of
      being silently ignored.
    * ``validate_assignment=True`` - values assigned to fields after
      construction are validated as well.
    """

    model_config = ConfigDict(extra="forbid", validate_assignment=True)


class SourceCode(BaseConfig):
"""SourceCode.

The SourceCode class allows the user to specify the source code location, dependencies,
Expand Down Expand Up @@ -194,7 +200,7 @@ def _to_vpc_config(self) -> shapes.VpcConfig:
return shapes.VpcConfig(**filtered_dict)


class InputData(BaseModel):
class InputData(BaseConfig):
"""InputData.

This config allows the user to specify an input data source for the training job.
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/modules/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
from __future__ import absolute_import

from typing import Optional, Dict, Any, List
from pydantic import BaseModel, PrivateAttr
from pydantic import PrivateAttr
from sagemaker.modules.utils import safe_serialize
from sagemaker.modules.configs import BaseConfig


class SMP(BaseModel):
class SMP(BaseConfig):
"""SMP.

This class is used for configuring the SageMaker Model Parallelism v2 parameters.
Expand Down Expand Up @@ -72,7 +73,7 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]:
return hyperparameters


class DistributedConfig(BaseModel):
class DistributedConfig(BaseConfig):
"""Base class for distributed training configurations."""

_type: str = PrivateAttr()
Expand Down
22 changes: 16 additions & 6 deletions src/sagemaker/modules/train/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ class ModelTrainer(BaseModel):
"LOCAL_CONTAINER" mode.
"""

model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
model_config = ConfigDict(
arbitrary_types_allowed=True, validate_assignment=True, extra="forbid"
)

training_mode: Mode = Mode.SAGEMAKER_TRAINING_JOB
sagemaker_session: Optional[Session] = None
Expand Down Expand Up @@ -363,9 +365,10 @@ def _populate_intelligent_defaults_from_model_trainer_space(self):

def __del__(self):
    """Destructor method to clean up the temporary directory.

    ``__del__`` can run on an instance whose construction did not complete
    (for example when validation raised inside ``__init__``).  On such an
    instance the private attributes may not be available, so the cleanup is
    skipped unless the instance carries ``__pydantic_fields_set__``, which is
    used here as the marker that the model was initialized.
    """
    # Not initialized: touching ``_temp_recipe_train_dir`` could itself
    # raise inside the destructor, so bail out early.
    if not hasattr(self, "__pydantic_fields_set__"):
        return

    # Clean up the temporary directory if one was created.
    if self._temp_recipe_train_dir is not None:
        self._temp_recipe_train_dir.cleanup()

def _validate_training_image_and_algorithm_name(
self, training_image: Optional[str], algorithm_name: Optional[str]
Expand Down Expand Up @@ -792,14 +795,14 @@ def _prepare_train_script(
"""Prepare the training script to be executed in the training job container.

Args:
source_code (SourceCodeConfig): The source code configuration.
source_code (SourceCode): The source code configuration.
"""

base_command = ""
if source_code.command:
if source_code.entry_script:
logger.warning(
"Both 'command' and 'entry_script' are provided in the SourceCodeConfig. "
"Both 'command' and 'entry_script' are provided in the SourceCode. "
+ "Defaulting to 'command'."
)
base_command = source_code.command.split()
Expand Down Expand Up @@ -831,6 +834,13 @@ def _prepare_train_script(
+ "Only .py and .sh scripts are supported."
)
execute_driver = EXECUTE_BASIC_SCRIPT_DRIVER
else:
# This should never be reached, as the source_code should have been validated.
raise ValueError(
f"Unsupported SourceCode or DistributedConfig: {source_code}, {distributed}."
+ "Please provide a valid configuration with atleast one of 'command'"
+ " or entry_script'."
)

train_script = TRAIN_SCRIPT_TEMPLATE.format(
working_dir=working_dir,
Expand Down
36 changes: 35 additions & 1 deletion tests/unit/sagemaker/modules/train/test_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import os
import pytest
from pydantic import ValidationError
from unittest.mock import patch, MagicMock, ANY

from sagemaker import image_uris
Expand Down Expand Up @@ -438,7 +439,7 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_
{
"source_code": DEFAULT_SOURCE_CODE,
"distributed": MPI(
custom_mpi_options=["-x", "VAR1", "-x", "VAR2"],
mpi_additional_options=["-x", "VAR1", "-x", "VAR2"],
),
"expected_template": EXECUTE_MPI_DRIVER,
"expected_hyperparameters": {},
Expand Down Expand Up @@ -1059,3 +1060,36 @@ def mock_upload_data(path, bucket, key_prefix):
hyper_parameters=hyperparameters,
environment=environment,
)


def test_safe_configs():
    """Config models must reject unknown fields and wrongly-typed values.

    The expected exception is pydantic's ``ValidationError`` rather than the
    broader ``ValueError`` it derives from, so the test only passes when the
    failure really comes from model validation.
    """
    # An unknown field ("entry_point" is not a SourceCode field) is rejected.
    with pytest.raises(ValidationError):
        SourceCode(entry_point="train.py")

    # A value of the wrong type for a known field is rejected.
    with pytest.raises(ValidationError):
        SourceCode(entry_script=1)


@patch("sagemaker.modules.train.model_trainer.TemporaryDirectory")
def test_destructor_cleanup(mock_tmp_dir, modules_session):
    """Check the ModelTrainer destructor in the failed-init and normal cases.

    A ModelTrainer whose construction fails validation must not trigger a
    temp-dir cleanup; a successfully built one must clean up its temporary
    recipe directory exactly once when it is deleted.
    """
    shared_kwargs = {
        "training_image": DEFAULT_IMAGE,
        "role": DEFAULT_ROLE,
        "sagemaker_session": modules_session,
    }

    # Construction fails validation ("test" is not a valid compute config).
    with pytest.raises(ValidationError):
        ModelTrainer(compute="test", **shared_kwargs)
    mock_tmp_dir.cleanup.assert_not_called()

    # A valid trainer cleans up the directory it holds once it is deleted.
    trainer = ModelTrainer(compute=DEFAULT_COMPUTE_CONFIG, **shared_kwargs)
    trainer._temp_recipe_train_dir = mock_tmp_dir
    mock_tmp_dir.assert_not_called()
    del trainer
    mock_tmp_dir.cleanup.assert_called_once()