@@ -181,7 +181,7 @@ class ModelTrainer(BaseModel):
             The output data configuration. This is used to specify the output data location
             for the training job.
             If not specified in the session, will default to
-            ``s3://<default_bucket>/<default_prefix>/<base_job_name>/``.
+            s3://<default_bucket>/<default_prefix>/<base_job_name>/.
         input_data_config (Optional[List[Union[Channel, InputData]]]):
             The input data config for the training job.
             Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
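
For orientation, here is a minimal sketch of how these two fields are typically supplied when constructing the trainer. The `sagemaker.modules` import paths, the image URI, and the S3 locations are assumptions for illustration, not taken from this diff.

```python
# Sketch only: import paths, the image URI, and the S3 locations are assumed.
from sagemaker.modules.configs import InputData
from sagemaker.modules.train import ModelTrainer

trainer = ModelTrainer(
    training_image="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",
    base_job_name="my-base-job",
    input_data_config=[
        # As the docstring notes, an InputDataSource can be an S3 URI.
        InputData(channel_name="train", data_source="s3://my-bucket/train/"),
    ],
    # output_data_config is omitted, so it falls back to the documented default:
    # s3://<default_bucket>/<default_prefix>/<base_job_name>/
)
```
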
@@ -477,6 +477,20 @@ def model_post_init(self, __context: Any):
             )
             logger.warning(f"Compute not provided. Using default:\n{self.compute}")

+        if self.compute.instance_type is None:
+            self.compute.instance_type = DEFAULT_INSTANCE_TYPE
+            logger.warning(f"Instance type not provided. Using default:\n{DEFAULT_INSTANCE_TYPE}")
+        if self.compute.instance_count is None:
+            self.compute.instance_count = 1
+            logger.warning(
+                f"Instance count not provided. Using default:\n{self.compute.instance_count}"
+            )
+        if self.compute.volume_size_in_gb is None:
+            self.compute.volume_size_in_gb = 30
+            logger.warning(
+                f"Volume size not provided. Using default:\n{self.compute.volume_size_in_gb}"
+            )
+
         if self.stopping_condition is None:
             self.stopping_condition = StoppingCondition(
                 max_runtime_in_seconds=3600,
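
The added block back-fills individual `Compute` fields instead of only handling a missing `compute` object. A sketch of the resulting behavior, assuming a partially specified `Compute` (the image URI is a placeholder, other trainer arguments are omitted, and the concrete value behind `DEFAULT_INSTANCE_TYPE` is not shown in this diff):

```python
# Sketch of the field-level back-fill; all constructor values are placeholders.
from sagemaker.modules.configs import Compute
from sagemaker.modules.train import ModelTrainer

compute = Compute(instance_count=2)  # instance_type and volume_size_in_gb left unset

trainer = ModelTrainer(
    training_image="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",
    compute=compute,
)

# After model_post_init runs:
#   trainer.compute.instance_type      -> DEFAULT_INSTANCE_TYPE
#   trainer.compute.instance_count     -> 2 (kept as provided)
#   trainer.compute.volume_size_in_gb  -> 30
```
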
@@ -486,6 +500,12 @@ def model_post_init(self, __context: Any):
             logger.warning(
                 f"StoppingCondition not provided. Using default:\n{self.stopping_condition}"
             )
+        if self.stopping_condition.max_runtime_in_seconds is None:
+            self.stopping_condition.max_runtime_in_seconds = 3600
+            logger.info(
+                "Max runtime not provided. Using default:\n"
+                f"{self.stopping_condition.max_runtime_in_seconds}"
+            )

         if self.hyperparameters and isinstance(self.hyperparameters, str):
             if not os.path.exists(self.hyperparameters):
@@ -511,23 +531,40 @@ def model_post_init(self, __context: Any):
                         )

         if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None:
-            session = self.sagemaker_session
-            base_job_name = self.base_job_name
-            self.output_data_config = OutputDataConfig(
-                s3_output_path=f"s3://{self._fetch_bucket_name_and_prefix(session)}"
-                f"/{base_job_name}",
-                compression_type="GZIP",
-                kms_key_id=None,
-            )
-            logger.warning(
-                f"OutputDataConfig not provided. Using default:\n{self.output_data_config}"
-            )
+            if self.output_data_config is None:
+                session = self.sagemaker_session
+                base_job_name = self.base_job_name
+                self.output_data_config = OutputDataConfig(
+                    s3_output_path=f"s3://{self._fetch_bucket_name_and_prefix(session)}"
+                    f"/{base_job_name}",
+                    compression_type="GZIP",
+                    kms_key_id=None,
+                )
+                logger.warning(
+                    f"OutputDataConfig not provided. Using default:\n{self.output_data_config}"
+                )
+            if self.output_data_config.s3_output_path is None:
+                session = self.sagemaker_session
+                base_job_name = self.base_job_name
+                self.output_data_config.s3_output_path = (
+                    f"s3://{self._fetch_bucket_name_and_prefix(session)}/{base_job_name}"
+                )
+                logger.warning(
+                    f"OutputDataConfig s3_output_path not provided. Using default:\n"
+                    f"{self.output_data_config.s3_output_path}"
+                )
+            if self.output_data_config.compression_type is None:
+                self.output_data_config.compression_type = "GZIP"
+                logger.warning(
+                    f"OutputDataConfig compression type not provided. Using default:\n"
+                    f"{self.output_data_config.compression_type}"
+                )

-        # TODO: Autodetect which image to use if source_code is provided
         if self.training_image:
             logger.info(f"Training image URI: {self.training_image}")

-    def _fetch_bucket_name_and_prefix(self, session: Session) -> str:
+    @staticmethod
+    def _fetch_bucket_name_and_prefix(session: Session) -> str:
         """Helper function to get the bucket name with the corresponding prefix if applicable"""
         if session.default_bucket_prefix is not None:
             return f"{session.default_bucket()}/{session.default_bucket_prefix}"
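
Making `_fetch_bucket_name_and_prefix` a `@staticmethod` matches its body, which no longer reads any instance state. The snippet below illustrates how the default `s3_output_path` is assembled from the session's default bucket and optional bucket prefix; the bucket, prefix, and job name are hypothetical, and the no-prefix case follows from the helper's docstring rather than being shown in this hunk.

```python
# Illustration only: bucket, prefix, and job name are hypothetical values.
base_job_name = "my-base-job"

# When the session has default_bucket() == "my-bucket" and
# default_bucket_prefix == "team-a", the helper returns "my-bucket/team-a":
bucket_and_prefix = "my-bucket/team-a"
assert f"s3://{bucket_and_prefix}/{base_job_name}" == "s3://my-bucket/team-a/my-base-job"

# With no default_bucket_prefix configured, only the bucket name is used:
bucket_only = "my-bucket"
assert f"s3://{bucket_only}/{base_job_name}" == "s3://my-bucket/my-base-job"
```
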
@@ -558,16 +595,28 @@ def train(
         """
         self._populate_intelligent_defaults()
         current_training_job_name = _get_unique_name(self.base_job_name)
-        input_data_key_prefix = f"{self.base_job_name}/{current_training_job_name}/input"
-        if input_data_config:
+        default_artifact_path = f"{self.base_job_name}/{current_training_job_name}"
+        input_data_key_prefix = f"{default_artifact_path}/input"
+        if input_data_config and self.input_data_config:
             self.input_data_config = input_data_config
+            # Add missing input data channels to the existing input_data_config
+            final_input_channel_names = {i.channel_name for i in input_data_config}
+            for input_data in self.input_data_config:
+                if input_data.channel_name not in final_input_channel_names:
+                    input_data_config.append(input_data)
+
+        self.input_data_config = input_data_config or self.input_data_config or []

-        input_data_config = []
         if self.input_data_config:
             input_data_config = self._get_input_data_config(
                 self.input_data_config, input_data_key_prefix
             )

+        if self.checkpoint_config and not self.checkpoint_config.s3_uri:
+            self.checkpoint_config.s3_uri = f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/{default_artifact_path}"
+        if self._tensorboard_output_config and not self._tensorboard_output_config.s3_uri:
+            self._tensorboard_output_config.s3_uri = f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/{default_artifact_path}"
+
         string_hyper_parameters = {}
         if self.hyperparameters:
             for hyper_parameter, value in self.hyperparameters.items():
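
To summarize the `train()` changes: channels passed to the call take precedence, channels already configured on the trainer but missing from the call are appended, and unset checkpoint/TensorBoard S3 URIs default under `<base_job_name>/<training_job_name>`. A sketch, with assumed import paths, URIs, and channel names:

```python
# Sketch of the merge behavior in train(); all URIs and names are placeholders.
from sagemaker.modules.configs import InputData
from sagemaker.modules.train import ModelTrainer

trainer = ModelTrainer(
    training_image="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",
    base_job_name="my-base-job",
    input_data_config=[
        InputData(channel_name="validation", data_source="s3://my-bucket/val/"),
    ],
)

# The "train" channel passed here is used as given; the pre-configured
# "validation" channel is appended because this call does not include it.
# A checkpoint or TensorBoard config without an s3_uri would default to
# s3://<bucket[/prefix]>/<base_job_name>/<training_job_name>.
trainer.train(
    input_data_config=[
        InputData(channel_name="train", data_source="s3://my-bucket/train/"),
    ],
)
```
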