@@ -627,8 +627,12 @@ class Framework(EstimatorBase):
627
627
such as training/deployment images and predictor instances.
628
628
"""
629
629
630
+ _DISTRIBUTION_SUPPORTED_FRAMEWORKS = ('mxnet' ,)
631
+ LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled'
632
+
630
633
def __init__ (self , entry_point , source_dir = None , hyperparameters = None , enable_cloudwatch_metrics = False ,
631
- container_log_level = logging .INFO , code_location = None , image_name = None , ** kwargs ):
634
+ container_log_level = logging .INFO , code_location = None , image_name = None ,
635
+ distributions = None , ** kwargs ):
632
636
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
633
637
634
638
Args:
@@ -650,6 +654,8 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
650
654
image_name (str): An alternate image name to use instead of the official Sagemaker image
651
655
for the framework. This is useful to run one of the Sagemaker supported frameworks
652
656
with an image containing custom dependencies.
657
+ distributions (dict): A dictionary with information on how to run distributed training
658
+ (default: None).
653
659
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
654
660
"""
655
661
super (Framework , self ).__init__ (** kwargs )
@@ -660,10 +666,27 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
660
666
DeprecationWarning )
661
667
self .enable_cloudwatch_metrics = False
662
668
self .container_log_level = container_log_level
663
- self ._hyperparameters = hyperparameters or {}
664
669
self .code_location = code_location
665
670
self .image_name = image_name
666
671
672
+ self ._hyperparameters = hyperparameters or {}
673
+ self ._configure_distributions (distributions )
674
+
675
def _configure_distributions(self, distributions):
    """Validate the ``distributions`` option and record it in the hyperparameters.

    When ``distributions`` contains a ``'parameter_server'`` entry, the value of
    its ``'enabled'`` key (``False`` when absent) is stored in
    ``self._hyperparameters`` under ``self.LAUNCH_PS_ENV_NAME``.

    Args:
        distributions (dict): A dictionary with information on how to run distributed
            training, or None, in which case this method does nothing.

    Raises:
        ValueError: If ``self.__framework_name__`` is not listed in
            ``self._DISTRIBUTION_SUPPORTED_FRAMEWORKS``, or if ``self.framework_version``
            is lower than ``self._LOWEST_SCRIPT_MODE_VERSION``.
    """
    if distributions is None:
        return

    if self.__framework_name__ not in self._DISTRIBUTION_SUPPORTED_FRAMEWORKS:
        raise ValueError('This framework does not support the distributions option.')

    def _numeric_components(components):
        # Keep only the leading digits of each component so that suffixed parts
        # (e.g. '0rc1') still compare by their numeric value; no digits means 0.
        numbers = []
        for component in components:
            digits = ''
            for char in str(component):
                if not char.isdigit():
                    break
                digits += char
            numbers.append(int(digits) if digits else 0)
        return numbers

    # Compare versions numerically. Comparing the raw string components is
    # lexicographic and would rank '1.10' below '1.3'.
    lowest_version = self._LOWEST_SCRIPT_MODE_VERSION
    current_version = self.framework_version.split('.')
    if _numeric_components(current_version) < _numeric_components(lowest_version):
        raise ValueError('The distributions option is valid for only versions {} and higher'
                         .format('.'.join(str(part) for part in lowest_version)))

    if 'parameter_server' in distributions:
        enabled = distributions['parameter_server'].get('enabled', False)
        self._hyperparameters[self.LAUNCH_PS_ENV_NAME] = enabled
689
+
667
690
def _prepare_for_training (self , job_name = None ):
668
691
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.
669
692
0 commit comments