13
13
from sagemaker .amazon .amazon_estimator import AmazonAlgorithmEstimatorBase , registry
14
14
from sagemaker .amazon .common import numpy_to_record_serializer , record_deserializer
15
15
from sagemaker .amazon .hyperparameter import Hyperparameter as hp # noqa
16
- from sagemaker .amazon .validation import gt , isin , ge
16
+ from sagemaker .amazon .validation import gt , isin , ge , le
17
17
from sagemaker .predictor import RealTimePredictor
18
18
from sagemaker .model import Model
19
19
from sagemaker .session import Session
@@ -27,16 +27,18 @@ class KMeans(AmazonAlgorithmEstimatorBase):
27
27
k = hp ('k' , gt (1 ), 'An integer greater-than 1' , int )
28
28
init_method = hp ('init_method' , isin ('random' , 'kmeans++' ), 'One of "random", "kmeans++"' , str )
29
29
max_iterations = hp ('local_lloyd_max_iterations' , gt (0 ), 'An integer greater-than 0' , int )
30
- tol = hp ('local_lloyd_tol' , gt ( 0 ), 'An integer greater-than 0 ' , int )
30
+ tol = hp ('local_lloyd_tol' , ( ge ( 0 ), le ( 1 )), 'An float in [0, 1] ' , float )
31
31
num_trials = hp ('local_lloyd_num_trials' , gt (0 ), 'An integer greater-than 0' , int )
32
32
local_init_method = hp ('local_lloyd_init_method' , isin ('random' , 'kmeans++' ), 'One of "random", "kmeans++"' , str )
33
33
half_life_time_size = hp ('half_life_time_size' , ge (0 ), 'An integer greater-than-or-equal-to 0' , int )
34
34
epochs = hp ('epochs' , gt (0 ), 'An integer greater-than 0' , int )
35
35
center_factor = hp ('extra_center_factor' , gt (0 ), 'An integer greater-than 0' , int )
36
+ eval_metrics = hp (name = 'eval_metrics' , validation_message = 'A comma separated list of "msd" or "ssd"' ,
37
+ data_type = list )
36
38
37
39
def __init__ (self , role , train_instance_count , train_instance_type , k , init_method = None ,
38
40
max_iterations = None , tol = None , num_trials = None , local_init_method = None ,
39
- half_life_time_size = None , epochs = None , center_factor = None , ** kwargs ):
41
+ half_life_time_size = None , epochs = None , center_factor = None , eval_metrics = None , ** kwargs ):
40
42
"""
41
43
A k-means clustering :class:`~sagemaker.amazon.AmazonAlgorithmEstimatorBase`. Finds k clusters of data in an
42
44
unlabeled dataset.
@@ -70,7 +72,7 @@ def __init__(self, role, train_instance_count, train_instance_type, k, init_meth
70
72
k (int): The number of clusters to produce.
71
73
init_method (str): How to initialize cluster locations. One of 'random' or 'kmeans++'.
72
74
max_iterations (int): Maximum iterations for Lloyds EM procedure in the local kmeans used in finalize stage.
73
- tol (int ): Tolerance for change in ssd for early stopping in local kmeans.
75
+ tol (float ): Tolerance for change in ssd for early stopping in local kmeans.
74
76
num_trials (int): Local version is run multiple times and the one with the best loss is chosen. This
75
77
determines how many times.
76
78
local_init_method (str): Initialization method for local version. One of 'random', 'kmeans++'
@@ -82,6 +84,9 @@ def __init__(self, role, train_instance_count, train_instance_type, k, init_meth
82
84
epochs (int): Number of passes done over the training data.
83
85
center_factor(int): The algorithm will create ``num_clusters * extra_center_factor`` as it runs and
84
86
reduce the number of centers to ``k`` when finalizing
87
+ eval_metrics(list): JSON list of metrics types to be used for reporting the score for the model.
88
+ Allowed values are "msd" Means Square Error, "ssd": Sum of square distance. If test data is provided,
89
+ the score shall be reported in terms of all requested metrics.
85
90
**kwargs: base class keyword argument values.
86
91
"""
87
92
super (KMeans , self ).__init__ (role , train_instance_count , train_instance_type , ** kwargs )
@@ -94,6 +99,7 @@ def __init__(self, role, train_instance_count, train_instance_type, k, init_meth
94
99
self .half_life_time_size = half_life_time_size
95
100
self .epochs = epochs
96
101
self .center_factor = center_factor
102
+ self .eval_metrics = eval_metrics
97
103
98
104
def create_model (self ):
99
105
"""Return a :class:`~sagemaker.amazon.kmeans.KMeansModel` referencing the latest
0 commit comments