|
21 | 21 | from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
|
22 | 22 |
|
23 | 23 |
|
| 24 | +def _list_check_subset(valid_super_list): |
| 25 | + valid_superset = set(valid_super_list) |
| 26 | + |
| 27 | + def validate(value): |
| 28 | + if not isinstance(value, str): |
| 29 | + return False |
| 30 | + |
| 31 | + val_list = [s.strip() for s in value.split(',')] |
| 32 | + return set(val_list).issubset(valid_superset) |
| 33 | + |
| 34 | + return validate |
| 35 | + |
| 36 | + |
24 | 37 | class Object2Vec(AmazonAlgorithmEstimatorBase):
|
25 | 38 |
|
26 | 39 | repo_name = 'object2vec'
|
@@ -57,6 +70,14 @@ class Object2Vec(AmazonAlgorithmEstimatorBase):
|
57 | 70 | 'One of "adagrad", "adam", "rmsprop", "sgd", "adadelta"', str)
|
58 | 71 | learning_rate = hp('learning_rate', (ge(1e-06), le(1.0)),
|
59 | 72 | 'A float in [1e-06, 1.0]', float)
|
| 73 | + |
| 74 | + negative_sampling_rate = hp('negative_sampling_rate', (ge(0), le(100)), 'An integer in [0, 100]', int) |
| 75 | + comparator_list = hp('comparator_list', _list_check_subset(["hadamard", "concat", "abs_diff"]), |
| 76 | + 'Comma-separated of hadamard, concat, abs_diff. E.g. "hadamard,abs_diff"', str) |
| 77 | + tied_token_embedding_weight = hp('tied_token_embedding_weight', (), 'Either True or False', bool) |
| 78 | + token_embedding_storage_type = hp('token_embedding_storage_type', isin("dense", "row_sparse"), |
| 79 | + 'One of "dense", "row_sparse"', str) |
| 80 | + |
60 | 81 | enc0_network = hp('enc0_network', isin("hcnn", "bilstm", "pooled_embedding"),
|
61 | 82 | 'One of "hcnn", "bilstm", "pooled_embedding"', str)
|
62 | 83 | enc1_network = hp('enc1_network', isin("hcnn", "bilstm", "pooled_embedding", "enc0"),
|
@@ -104,6 +125,10 @@ def __init__(self, role, train_instance_count, train_instance_type,
|
104 | 125 | output_layer=None,
|
105 | 126 | optimizer=None,
|
106 | 127 | learning_rate=None,
|
| 128 | + negative_sampling_rate=None, |
| 129 | + comparator_list=None, |
| 130 | + tied_token_embedding_weight=None, |
| 131 | + token_embedding_storage_type=None, |
107 | 132 | enc0_network=None,
|
108 | 133 | enc1_network=None,
|
109 | 134 | enc0_cnn_filter_width=None,
|
@@ -164,6 +189,10 @@ def __init__(self, role, train_instance_count, train_instance_type,
|
164 | 189 | output_layer(str): Optional. Type of output layer
|
165 | 190 | optimizer(str): Optional. Type of optimizer for training
|
166 | 191 | learning_rate(float): Optional. Learning rate for SGD training
|
| 192 | + negative_sampling_rate(int): Optional. Negative sampling rate |
| 193 | + comparator_list(str): Optional. Customization of comparator operator |
| 194 | + tied_token_embedding_weight(bool): Optional. Tying of token embedding layer weight |
| 195 | + token_embedding_storage_type(str): Optional. Type of token embedding storage |
167 | 196 | enc0_network(str): Optional. Network model of encoder "enc0"
|
168 | 197 | enc1_network(str): Optional. Network model of encoder "enc1"
|
169 | 198 | enc0_cnn_filter_width(int): Optional. CNN filter width
|
@@ -197,6 +226,12 @@ def __init__(self, role, train_instance_count, train_instance_type,
|
197 | 226 | self.output_layer = output_layer
|
198 | 227 | self.optimizer = optimizer
|
199 | 228 | self.learning_rate = learning_rate
|
| 229 | + |
| 230 | + self.negative_sampling_rate = negative_sampling_rate |
| 231 | + self.comparator_list = comparator_list |
| 232 | + self.tied_token_embedding_weight = tied_token_embedding_weight |
| 233 | + self.token_embedding_storage_type = token_embedding_storage_type |
| 234 | + |
200 | 235 | self.enc0_network = enc0_network
|
201 | 236 | self.enc1_network = enc1_network
|
202 | 237 | self.enc0_cnn_filter_width = enc0_cnn_filter_width
|
|
0 commit comments