@@ -425,7 +425,7 @@ def compile_model(
425
425
LOGGER .info ("Creating compilation-job with name: %s" , job_name )
426
426
self .sagemaker_client .create_compilation_job (** compilation_job_request )
427
427
428
- def tune (
428
+ def tune ( # noqa: C901
429
429
self ,
430
430
job_name ,
431
431
strategy ,
@@ -450,6 +450,9 @@ def tune(
450
450
early_stopping_type = "Off" ,
451
451
encrypt_inter_container_traffic = False ,
452
452
vpc_config = None ,
453
+ train_use_spot_instances = False ,
454
+ checkpoint_s3_uri = None ,
455
+ checkpoint_local_path = None ,
453
456
):
454
457
"""Create an Amazon SageMaker hyperparameter tuning job
455
458
@@ -512,6 +515,18 @@ def tune(
512
515
The key in vpc_config is 'Subnets'.
513
516
* security_group_ids (list[str]): List of security group ids.
514
517
The key in vpc_config is 'SecurityGroupIds'.
518
+ train_use_spot_instances (bool): Whether to use managed spot instances for training (default: ``False``).
519
+ checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
520
+ that the algorithm writes (if any) during training. (default:
521
+ ``None``).
522
+ checkpoint_local_path (str): The local path that the algorithm
523
+ writes its checkpoints to. SageMaker will persist all files
524
+ under this path to `checkpoint_s3_uri` continually during
525
+ training. On job startup the reverse happens - data from the
526
+ s3 location is downloaded to this path before the algorithm is
527
+ started. If the path is unset then SageMaker assumes the
528
+ checkpoints will be provided under `/opt/ml/checkpoints/`.
529
+ (default: ``None``).
515
530
516
531
"""
517
532
tune_request = {
@@ -569,6 +584,15 @@ def tune(
569
584
if encrypt_inter_container_traffic :
570
585
tune_request ["TrainingJobDefinition" ]["EnableInterContainerTrafficEncryption" ] = True
571
586
587
+ if train_use_spot_instances :
588
+ tune_request ["TrainingJobDefinition" ]["EnableManagedSpotTraining" ] = True
589
+
590
+ if checkpoint_s3_uri :
591
+ checkpoint_config = {"S3Uri" : checkpoint_s3_uri }
592
+ if checkpoint_local_path :
593
+ checkpoint_config ["LocalPath" ] = checkpoint_local_path
594
+ tune_request ["TrainingJobDefinition" ]["CheckpointConfig" ] = checkpoint_config
595
+
572
596
LOGGER .info ("Creating hyperparameter tuning job with name: %s" , job_name )
573
597
LOGGER .debug ("tune request: %s" , json .dumps (tune_request , indent = 4 ))
574
598
self .sagemaker_client .create_hyper_parameter_tuning_job (** tune_request )
0 commit comments