Skip to content

Commit 3128278

Browse files
committed
Create BaseConfig
1 parent 3b05bbd commit 3128278

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

src/sagemaker/modules/configs.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,13 @@
7474
]
7575

7676

77-
class SourceCode(BaseModel):
77+
class BaseConfig(BaseModel):
78+
"""BaseConfig"""
79+
80+
model_config = ConfigDict(validate_assignment=True, extra="forbid")
81+
82+
83+
class SourceCode(BaseConfig):
7884
"""SourceCode.
7985
8086
The SourceCode class allows the user to specify the source code location, dependencies,
@@ -94,8 +100,6 @@ class SourceCode(BaseModel):
94100
If not specified, entry_script must be provided.
95101
"""
96102

97-
model_config = ConfigDict(validate_assignment=True, extra="forbid")
98-
99103
source_dir: Optional[str] = None
100104
requirements: Optional[str] = None
101105
entry_script: Optional[str] = None
@@ -196,7 +200,7 @@ def _to_vpc_config(self) -> shapes.VpcConfig:
196200
return shapes.VpcConfig(**filtered_dict)
197201

198202

199-
class InputData(BaseModel):
203+
class InputData(BaseConfig):
200204
"""InputData.
201205
202206
This config allows the user to specify an input data source for the training job.
@@ -217,7 +221,5 @@ class InputData(BaseModel):
217221
S3DataSource object, or FileSystemDataSource object.
218222
"""
219223

220-
model_config = ConfigDict(validate_assignment=True, extra="forbid")
221-
222224
channel_name: str = None
223225
data_source: Union[str, FileSystemDataSource, S3DataSource] = None

src/sagemaker/modules/distributed.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414
from __future__ import absolute_import
1515

1616
from typing import Optional, Dict, Any, List
17-
from pydantic import BaseModel, PrivateAttr, ConfigDict
17+
from pydantic import PrivateAttr
1818
from sagemaker.modules.utils import safe_serialize
19+
from sagemaker.modules.configs import BaseConfig
1920

2021

21-
class SMP(BaseModel):
22+
class SMP(BaseConfig):
2223
"""SMP.
2324
2425
This class is used for configuring the SageMaker Model Parallelism v2 parameters.
@@ -53,8 +54,6 @@ class SMP(BaseModel):
5354
parallelism or expert parallelism.
5455
"""
5556

56-
model_config = ConfigDict(validate_assignment=True, extra="forbid")
57-
5857
hybrid_shard_degree: Optional[int] = None
5958
sm_activation_offloading: Optional[bool] = None
6059
activation_loading_horizon: Optional[int] = None
@@ -74,11 +73,9 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]:
7473
return hyperparameters
7574

7675

77-
class DistributedConfig(BaseModel):
76+
class DistributedConfig(BaseConfig):
7877
"""Base class for distributed training configurations."""
7978

80-
model_config = ConfigDict(validate_assignment=True, extra="forbid")
81-
8279
_type: str = PrivateAttr()
8380

8481
def model_dump(self, *args, **kwargs):

0 commit comments

Comments
 (0)