Skip to content

Commit 0b9be17

Browse files
committed
Improve default handling
1 parent 40432b3 commit 0b9be17

File tree

2 files changed

+130
-21
lines changed

2 files changed

+130
-21
lines changed

src/sagemaker/modules/configs.py

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from sagemaker_core.shapes import (
3131
StoppingCondition,
3232
RetryStrategy,
33-
OutputDataConfig,
3433
Channel,
3534
ShuffleConfig,
3635
DataSource,
@@ -42,9 +41,7 @@
4241
InfraCheckConfig,
4342
RemoteDebugConfig,
4443
SessionChainingConfig,
45-
InstanceGroup,
46-
TensorBoardOutputConfig,
47-
CheckpointConfig,
44+
InstanceGroup
4845
)
4946

5047
from sagemaker.modules.utils import convert_unassigned_to_none
@@ -131,6 +128,8 @@ class Compute(shapes.ResourceConfig):
131128
subsequent training jobs.
132129
instance_groups (Optional[List[InstanceGroup]]):
133130
A list of instance groups for heterogeneous clusters to be used in the training job.
131+
training_plan_arn (Optional[str]):
132+
The Amazon Resource Name (ARN) of the training plan to use for this resource configuration.
134133
enable_managed_spot_training (Optional[bool]):
135134
To train models using managed spot training, choose True. Managed spot training
136135
provides a fully managed and scalable infrastructure for training machine learning
@@ -224,3 +223,64 @@ class InputData(BaseConfig):
224223

225224
channel_name: str = None
226225
data_source: Union[str, FileSystemDataSource, S3DataSource] = None
226+
227+
228+
class OutputDataConfig(shapes.OutputDataConfig):
229+
"""OutputDataConfig.
230+
231+
The OutputDataConfig class is a subclass of ``sagemaker_core.shapes.OutputDataConfig``
232+
and allows the user to specify the output data configuration for the training job.
233+
234+
Parameters:
235+
s3_output_path (Optional[str]):
236+
The S3 URI where the output data will be stored. This is the location where the
237+
training job will save its output data, such as model artifacts and logs.
238+
kms_key_id (Optional[str]):
239+
The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that
240+
SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side
241+
encryption.
242+
compression_type (Optional[str]):
243+
The model output compression type. Select None to output an uncompressed model,
244+
recommended for large model outputs. Defaults to gzip.
245+
"""
246+
247+
s3_output_path: Optional[str] = None
248+
kms_key_id: Optional[str] = None
249+
compression_type: Optional[str] = None
250+
251+
252+
class TensorBoardOutputConfig(shapes.TensorBoardOutputConfig):
253+
"""TensorBoardOutputConfig.
254+
255+
The TensorBoardOutputConfig class is a subclass of ``sagemaker_core.shapes.TensorBoardOutputConfig``
256+
and allows the user to specify the storage locations for the Amazon SageMaker
257+
Debugger TensorBoard.
258+
259+
Parameters:
260+
s3_output_path (Optional[str]):
261+
Path to Amazon S3 storage location for TensorBoard output. If not specified, will default to the default artifact location for the training job.
262+
``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/``
263+
local_path (Optional[str]):
264+
Path to local storage location for tensorBoard output. Defaults to /opt/ml/output/tensorboard.
265+
"""
266+
267+
s3_output_path: Optional[str] = None
268+
local_path: Optional[str] = "/opt/ml/output/tensorboard"
269+
270+
271+
class CheckpointConfig(shapes.CheckpointConfig):
272+
"""CheckpointConfig.
273+
274+
The CheckpointConfig class is a subclass of ``sagemaker_core.shapes.CheckpointConfig``
275+
and allows the user to specify the checkpoint configuration for the training job.
276+
277+
Parameters:
278+
s3_uri (Optional[str]):
279+
Path to Amazon S3 storage location for the Checkpoint data. If not specified, will default to the default artifact location for the training job.
280+
``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/``
281+
local_path (Optional[str]):
282+
The local directory where checkpoints are written. The default directory is /opt/ml/checkpoints.
283+
"""
284+
285+
s3_uri: Optional[str] = None
286+
local_path: Optional[str] = "/opt/ml/checkpoints"

src/sagemaker/modules/train/model_trainer.py

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class ModelTrainer(BaseModel):
181181
The output data configuration. This is used to specify the output data location
182182
for the training job.
183183
If not specified in the session, will default to
184-
``s3://<default_bucket>/<default_prefix>/<base_job_name>/``.
184+
s3://<default_bucket>/<default_prefix>/<base_job_name>/.
185185
input_data_config (Optional[List[Union[Channel, InputData]]]):
186186
The input data config for the training job.
187187
Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
@@ -477,6 +477,20 @@ def model_post_init(self, __context: Any):
477477
)
478478
logger.warning(f"Compute not provided. Using default:\n{self.compute}")
479479

480+
if self.compute.instance_type is None:
481+
self.compute.instance_type = DEFAULT_INSTANCE_TYPE
482+
logger.warning(f"Instance type not provided. Using default:\n{DEFAULT_INSTANCE_TYPE}")
483+
if self.compute.instance_count is None:
484+
self.compute.instance_count = 1
485+
logger.warning(
486+
f"Instance count not provided. Using default:\n{self.compute.instance_count}"
487+
)
488+
if self.compute.volume_size_in_gb is None:
489+
self.compute.volume_size_in_gb = 30
490+
logger.warning(
491+
f"Volume size not provided. Using default:\n{self.compute.volume_size_in_gb}"
492+
)
493+
480494
if self.stopping_condition is None:
481495
self.stopping_condition = StoppingCondition(
482496
max_runtime_in_seconds=3600,
@@ -486,6 +500,12 @@ def model_post_init(self, __context: Any):
486500
logger.warning(
487501
f"StoppingCondition not provided. Using default:\n{self.stopping_condition}"
488502
)
503+
if self.stopping_condition.max_runtime_in_seconds is None:
504+
self.stopping_condition.max_runtime_in_seconds = 3600
505+
logger.info(
506+
"Max runtime not provided. Using default:\n"
507+
f"{self.stopping_condition.max_runtime_in_seconds}"
508+
)
489509

490510
if self.hyperparameters and isinstance(self.hyperparameters, str):
491511
if not os.path.exists(self.hyperparameters):
@@ -511,23 +531,40 @@ def model_post_init(self, __context: Any):
511531
)
512532

513533
if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None:
514-
session = self.sagemaker_session
515-
base_job_name = self.base_job_name
516-
self.output_data_config = OutputDataConfig(
517-
s3_output_path=f"s3://{self._fetch_bucket_name_and_prefix(session)}"
518-
f"/{base_job_name}",
519-
compression_type="GZIP",
520-
kms_key_id=None,
521-
)
522-
logger.warning(
523-
f"OutputDataConfig not provided. Using default:\n{self.output_data_config}"
524-
)
534+
if self.output_data_config is None:
535+
session = self.sagemaker_session
536+
base_job_name = self.base_job_name
537+
self.output_data_config = OutputDataConfig(
538+
s3_output_path=f"s3://{self._fetch_bucket_name_and_prefix(session)}"
539+
f"/{base_job_name}",
540+
compression_type="GZIP",
541+
kms_key_id=None,
542+
)
543+
logger.warning(
544+
f"OutputDataConfig not provided. Using default:\n{self.output_data_config}"
545+
)
546+
if self.output_data_config.s3_output_path is None:
547+
session = self.sagemaker_session
548+
base_job_name = self.base_job_name
549+
self.output_data_config.s3_output_path = (
550+
f"s3://{self._fetch_bucket_name_and_prefix(session)}/{base_job_name}"
551+
)
552+
logger.warning(
553+
f"OutputDataConfig s3_output_path not provided. Using default:\n"
554+
f"{self.output_data_config.s3_output_path}"
555+
)
556+
if self.output_data_config.compression_type is None:
557+
self.output_data_config.compression_type = "GZIP"
558+
logger.warning(
559+
f"OutputDataConfig compression type not provided. Using default:\n"
560+
f"{self.output_data_config.compression_type}"
561+
)
525562

526-
# TODO: Autodetect which image to use if source_code is provided
527563
if self.training_image:
528564
logger.info(f"Training image URI: {self.training_image}")
529565

530-
def _fetch_bucket_name_and_prefix(self, session: Session) -> str:
566+
@staticmethod
567+
def _fetch_bucket_name_and_prefix(session: Session) -> str:
531568
"""Helper function to get the bucket name with the corresponding prefix if applicable"""
532569
if session.default_bucket_prefix is not None:
533570
return f"{session.default_bucket()}/{session.default_bucket_prefix}"
@@ -558,16 +595,28 @@ def train(
558595
"""
559596
self._populate_intelligent_defaults()
560597
current_training_job_name = _get_unique_name(self.base_job_name)
561-
input_data_key_prefix = f"{self.base_job_name}/{current_training_job_name}/input"
562-
if input_data_config:
598+
default_artifact_path = f"{self.base_job_name}/{current_training_job_name}"
599+
input_data_key_prefix = f"{default_artifact_path}/input"
600+
if input_data_config and self.input_data_config:
563601
self.input_data_config = input_data_config
602+
# Add missing input data channels to the existing input_data_config
603+
final_input_channel_names = {i.channel_name for i in input_data_config}
604+
for input_data in self.input_data_config:
605+
if input_data.channel_name not in final_input_channel_names:
606+
input_data_config.append(input_data)
607+
608+
self.input_data_config = input_data_config or self.input_data_config or []
564609

565-
input_data_config = []
566610
if self.input_data_config:
567611
input_data_config = self._get_input_data_config(
568612
self.input_data_config, input_data_key_prefix
569613
)
570614

615+
if self.checkpoint_config and not self.checkpoint_config.s3_uri:
616+
self.checkpoint_config.s3_uri = f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/{default_artifact_path}"
617+
if self._tensorboard_output_config and not self._tensorboard_output_config.s3_uri:
618+
self._tensorboard_output_config.s3_uri = f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/{default_artifact_path}"
619+
571620
string_hyper_parameters = {}
572621
if self.hyperparameters:
573622
for hyper_parameter, value in self.hyperparameters.items():

0 commit comments

Comments
 (0)