Skip to content

Commit 11994bd

Browse files
pnpnpn authored and jesterhazy committed
feature: add document embedding support to Object2Vec algorithm (#772)
1 parent c76d6a9 commit 11994bd

File tree

3 files changed

+52
-2
lines changed

3 files changed

+52
-2
lines changed

src/sagemaker/amazon/object2vec.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,19 @@
2121
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2222

2323

24+
def _list_check_subset(valid_super_list):
25+
valid_superset = set(valid_super_list)
26+
27+
def validate(value):
28+
if not isinstance(value, str):
29+
return False
30+
31+
val_list = [s.strip() for s in value.split(',')]
32+
return set(val_list).issubset(valid_superset)
33+
34+
return validate
35+
36+
2437
class Object2Vec(AmazonAlgorithmEstimatorBase):
2538

2639
repo_name = 'object2vec'
@@ -57,6 +70,14 @@ class Object2Vec(AmazonAlgorithmEstimatorBase):
5770
'One of "adagrad", "adam", "rmsprop", "sgd", "adadelta"', str)
5871
learning_rate = hp('learning_rate', (ge(1e-06), le(1.0)),
5972
'A float in [1e-06, 1.0]', float)
73+
74+
negative_sampling_rate = hp('negative_sampling_rate', (ge(0), le(100)), 'An integer in [0, 100]', int)
75+
comparator_list = hp('comparator_list', _list_check_subset(["hadamard", "concat", "abs_diff"]),
76+
'Comma-separated of hadamard, concat, abs_diff. E.g. "hadamard,abs_diff"', str)
77+
tied_token_embedding_weight = hp('tied_token_embedding_weight', (), 'Either True or False', bool)
78+
token_embedding_storage_type = hp('token_embedding_storage_type', isin("dense", "row_sparse"),
79+
'One of "dense", "row_sparse"', str)
80+
6081
enc0_network = hp('enc0_network', isin("hcnn", "bilstm", "pooled_embedding"),
6182
'One of "hcnn", "bilstm", "pooled_embedding"', str)
6283
enc1_network = hp('enc1_network', isin("hcnn", "bilstm", "pooled_embedding", "enc0"),
@@ -104,6 +125,10 @@ def __init__(self, role, train_instance_count, train_instance_type,
104125
output_layer=None,
105126
optimizer=None,
106127
learning_rate=None,
128+
negative_sampling_rate=None,
129+
comparator_list=None,
130+
tied_token_embedding_weight=None,
131+
token_embedding_storage_type=None,
107132
enc0_network=None,
108133
enc1_network=None,
109134
enc0_cnn_filter_width=None,
@@ -164,6 +189,10 @@ def __init__(self, role, train_instance_count, train_instance_type,
164189
output_layer(str): Optional. Type of output layer
165190
optimizer(str): Optional. Type of optimizer for training
166191
learning_rate(float): Optional. Learning rate for SGD training
192+
negative_sampling_rate(int): Optional. Negative sampling rate
193+
comparator_list(str): Optional. Customization of comparator operator
194+
tied_token_embedding_weight(bool): Optional. Tying of token embedding layer weight
195+
token_embedding_storage_type(str): Optional. Type of token embedding storage
167196
enc0_network(str): Optional. Network model of encoder "enc0"
168197
enc1_network(str): Optional. Network model of encoder "enc1"
169198
enc0_cnn_filter_width(int): Optional. CNN filter width
@@ -197,6 +226,12 @@ def __init__(self, role, train_instance_count, train_instance_type,
197226
self.output_layer = output_layer
198227
self.optimizer = optimizer
199228
self.learning_rate = learning_rate
229+
230+
self.negative_sampling_rate = negative_sampling_rate
231+
self.comparator_list = comparator_list
232+
self.tied_token_embedding_weight = tied_token_embedding_weight
233+
self.token_embedding_storage_type = token_embedding_storage_type
234+
200235
self.enc0_network = enc0_network
201236
self.enc1_network = enc1_network
202237
self.enc0_cnn_filter_width = enc0_cnn_filter_width

tests/integ/test_object2vec.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def test_object2vec(sagemaker_session):
4343
enc0_vocab_size=45000,
4444
enc_dim=16,
4545
num_classes=3,
46+
negative_sampling_rate=0,
47+
comparator_list='hadamard,concat,abs_diff',
48+
tied_token_embedding_weight=False,
49+
token_embedding_storage_type='dense',
4650
sagemaker_session=sagemaker_session)
4751

4852
record_set = prepare_record_set_from_local_files(data_path, object2vec.data_location,

tests/unit/test_object2vec.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ def test_all_hyperparameters(sagemaker_session):
111111
output_layer='softmax',
112112
optimizer='adam',
113113
learning_rate=0.0001,
114+
negative_sampling_rate=1,
115+
comparator_list='hadamard, abs_diff',
116+
tied_token_embedding_weight=True,
117+
token_embedding_storage_type='row_sparse',
114118
enc0_network='bilstm',
115119
enc1_network='hcnn',
116120
enc0_cnn_filter_width=3,
@@ -161,7 +165,11 @@ def test_required_hyper_parameters_value(sagemaker_session, required_hyper_param
161165
('optimizer', 0),
162166
('enc0_cnn_filter_width', 'string'),
163167
('weight_decay', 'string'),
164-
('learning_rate', 'string')
168+
('learning_rate', 'string'),
169+
('negative_sampling_rate', 'some_string'),
170+
('comparator_list', 0),
171+
('comparator_list', ['foobar']),
172+
('token_embedding_storage_type', 123),
165173
])
166174
def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parameters, value):
167175
with pytest.raises(ValueError):
@@ -182,7 +190,10 @@ def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parame
182190
('weight_decay', 200000),
183191
('enc0_cnn_filter_width', 2000),
184192
('learning_rate', 0),
185-
('learning_rate', 2)
193+
('learning_rate', 2),
194+
('negative_sampling_rate', -1),
195+
('comparator_list', 'hadamard,foobar'),
196+
('token_embedding_storage_type', 'foobar'),
186197
])
187198
def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_parameters, value):
188199
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)