-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Update unit tests of kmeans, pca, factorization machines, lda and ntm #103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
09f70e2
a4dbd15
ef3e66a
ea59ae8
b495043
034583d
19f4640
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,17 +11,29 @@ | |
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
import pytest | ||
from mock import Mock | ||
from mock import Mock, patch | ||
|
||
from sagemaker.amazon.factorization_machines import FactorizationMachines | ||
from sagemaker.amazon.amazon_estimator import registry | ||
from sagemaker.amazon.factorization_machines import FactorizationMachines, FactorizationMachinesPredictor | ||
from sagemaker.amazon.amazon_estimator import registry, RecordSet | ||
|
||
ROLE = 'myrole' | ||
TRAIN_INSTANCE_COUNT = 1 | ||
TRAIN_INSTANCE_TYPE = 'ml.c4.xlarge' | ||
NUM_FACTORS = 3 | ||
PREDICTOR_TYPE = 'regressor' | ||
|
||
COMMON_TRAIN_ARGS = {'role': 'myrole', 'train_instance_count': 1, 'train_instance_type': 'ml.c4.xlarge'} | ||
ALL_REQ_ARGS = dict({'num_factors': 3, 'predictor_type': 'regressor'}, **COMMON_TRAIN_ARGS) | ||
COMMON_TRAIN_ARGS = {'role': ROLE, 'train_instance_count': TRAIN_INSTANCE_COUNT, | ||
'train_instance_type': TRAIN_INSTANCE_TYPE} | ||
ALL_REQ_ARGS = dict({'num_factors': NUM_FACTORS, 'predictor_type': PREDICTOR_TYPE}, **COMMON_TRAIN_ARGS) | ||
|
||
REGION = "us-west-2" | ||
BUCKET_NAME = "Some-Bucket" | ||
REGION = 'us-west-2' | ||
BUCKET_NAME = 'Some-Bucket' | ||
|
||
DESCRIBE_TRAINING_JOB_RESULT = { | ||
'ModelArtifacts': { | ||
'S3ModelArtifacts': 's3://bucket/model.tar.gz' | ||
} | ||
} | ||
|
||
|
||
@pytest.fixture() | ||
|
@@ -30,6 +42,8 @@ def sagemaker_session(): | |
sms = Mock(name='sagemaker_session', boto_session=boto_mock) | ||
sms.boto_region_name = REGION | ||
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) | ||
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', | ||
return_value=DESCRIBE_TRAINING_JOB_RESULT) | ||
return sms | ||
|
||
|
||
|
@@ -94,3 +108,146 @@ def test_all_hyperparameters(sagemaker_session): | |
def test_image(sagemaker_session):
    """Check train_image() against the registry path for factorization-machines:1."""
    estimator = FactorizationMachines(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    actual_image = estimator.train_image()
    expected_image = registry(REGION) + '/factorization-machines:1'
    assert actual_image == expected_image
|
||
|
||
@pytest.mark.parametrize('required_hyper_parameters, value', [
    ('num_factors', 'string'),
    ('predictor_type', 0),
])
def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parameters, value):
    """Expect ValueError when a required hyperparameter receives the parametrized bad value."""
    # Assemble the constructor arguments first so that only the constructor
    # call itself sits inside the pytest.raises context.
    bad_kwargs = dict(ALL_REQ_ARGS, **{required_hyper_parameters: value})
    with pytest.raises(ValueError):
        FactorizationMachines(sagemaker_session=sagemaker_session, **bad_kwargs)
|
||
|
||
@pytest.mark.parametrize('required_hyper_parameters, value', [
    ('num_factors', 0),
    ('predictor_type', 'string'),
])
def test_required_hyper_parameters_value(sagemaker_session, required_hyper_parameters, value):
    """Expect ValueError when a required hyperparameter receives the parametrized invalid value."""
    # Only the constructor call is expected to raise, so the kwargs are
    # prepared outside the pytest.raises context.
    bad_kwargs = dict(ALL_REQ_ARGS, **{required_hyper_parameters: value})
    with pytest.raises(ValueError):
        FactorizationMachines(sagemaker_session=sagemaker_session, **bad_kwargs)
|
||
|
||
@pytest.mark.parametrize('optional_hyper_parameters, value', [
    ('epochs', 'string'),
    ('clip_gradient', 'string'),
    ('eps', 'string'),
    ('rescale_grad', 'string'),
    ('bias_lr', 'string'),
    ('linear_lr', 'string'),
    ('factors_lr', 'string'),
    ('bias_wd', 'string'),
    ('linear_wd', 'string'),
    ('factors_wd', 'string'),
    ('bias_init_method', 0),
    ('bias_init_scale', 'string'),
    ('bias_init_sigma', 'string'),
    ('bias_init_value', 'string'),
    ('linear_init_method', 0),
    ('linear_init_scale', 'string'),
    ('linear_init_sigma', 'string'),
    ('linear_init_value', 'string'),
    ('factors_init_method', 0),
    ('factors_init_scale', 'string'),
    ('factors_init_sigma', 'string'),
    ('factors_init_value', 'string'),
])
def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parameters, value):
    """Expect ValueError when an optional hyperparameter receives the parametrized bad value."""
    # The required arguments stay valid; a single optional one is overridden.
    bad_kwargs = dict(ALL_REQ_ARGS, **{optional_hyper_parameters: value})
    with pytest.raises(ValueError):
        FactorizationMachines(sagemaker_session=sagemaker_session, **bad_kwargs)
|
||
|
||
@pytest.mark.parametrize('optional_hyper_parameters, value', [
    ('epochs', 0),
    ('bias_lr', -1),
    ('linear_lr', -1),
    ('factors_lr', -1),
    ('bias_wd', -1),
    ('linear_wd', -1),
    ('factors_wd', -1),
    ('bias_init_method', 'string'),
    ('bias_init_scale', -1),
    ('bias_init_sigma', -1),
    ('linear_init_method', 'string'),
    ('linear_init_scale', -1),
    ('linear_init_sigma', -1),
    ('factors_init_method', 'string'),
    ('factors_init_scale', -1),
    ('factors_init_sigma', -1),
])
def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_parameters, value):
    """Expect ValueError when an optional hyperparameter receives the parametrized invalid value."""
    # The required arguments stay valid; a single optional one is overridden.
    bad_kwargs = dict(ALL_REQ_ARGS, **{optional_hyper_parameters: value})
    with pytest.raises(ValueError):
        FactorizationMachines(sagemaker_session=sagemaker_session, **bad_kwargs)
|
||
|
||
# Shared inputs for the fit()-related tests in this module: PREFIX and FEATURE_DIM
# are used to build RecordSet objects, MINI_BATCH_SIZE is the value passed to fit().
PREFIX = 'prefix' | ||
FEATURE_DIM = 10 | ||
MINI_BATCH_SIZE = 200 | ||
|
||
|
||
@patch('sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit')
def test_call_fit(base_fit, sagemaker_session):
    """Check that fit() hands the record set and mini-batch size to the base-class fit.

    ``AmazonAlgorithmEstimatorBase.fit`` is patched, so no training job is
    started; the test only inspects how the patched method was called.

    The call count is asserted through ``call_count`` instead of
    ``Mock.assert_called_once()``. That helper exists only in Python 3.6+ /
    mock 2.0+; on older ``mock`` releases the name is not a real assertion
    method, so the check could pass vacuously or fail with AttributeError.
    """
    fm = FactorizationMachines(base_job_name='fm', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1,
                     feature_dim=FEATURE_DIM, channel='train')

    fm.fit(data, MINI_BATCH_SIZE)

    # Exactly one delegation to the base class is expected.
    assert base_fit.call_count == 1

    # call_args[0] holds the positional arguments of the most recent call.
    positional_args = base_fit.call_args[0]
    assert len(positional_args) == 2
    assert positional_args[0] == data
    assert positional_args[1] == MINI_BATCH_SIZE
|
||
|
||
def test_call_fit_none_mini_batch_size(sagemaker_session): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What is this test asserting on? And does the base fit need to be patched? (The same question about patching applies to the tests below as well.) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This test doesn't assert anything. It checks that fit() runs successfully (no exception) when no mini_batch_size is given (for this algorithm, mini_batch_size should have a default value). The other tests below are similar; they just cover different cases for the mini_batch_size passed to fit(). Usually, for one algorithm, there are several cases for mini_batch_size: does it have a default value? Does it have a valid range? Is it required? |
||
fm = FactorizationMachines(base_job_name='fm', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) | ||
|
||
data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, | ||
channel='train') | ||
fm.fit(data) | ||
|
||
|
||
def test_call_fit_wrong_type_mini_batch_size(sagemaker_session):
    """Expect TypeError or ValueError when fit() is given the string 'some' as mini_batch_size."""
    estimator = FactorizationMachines(base_job_name='fm', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    s3_location = 's3://{}/{}'.format(BUCKET_NAME, PREFIX)
    records = RecordSet(s3_location, num_records=1, feature_dim=FEATURE_DIM, channel='train')

    # Either exception type is accepted as a rejection of the bad argument.
    accepted_errors = (TypeError, ValueError)
    with pytest.raises(accepted_errors):
        estimator.fit(records, 'some')
|
||
|
||
def test_call_fit_wrong_value_mini_batch_size(sagemaker_session):
    """Expect ValueError when fit() is given 0 as mini_batch_size."""
    estimator = FactorizationMachines(base_job_name='fm', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    s3_location = 's3://{}/{}'.format(BUCKET_NAME, PREFIX)
    records = RecordSet(s3_location, num_records=1, feature_dim=FEATURE_DIM, channel='train')

    with pytest.raises(ValueError):
        estimator.fit(records, 0)
|
||
|
||
def test_model_image(sagemaker_session):
    """Check the image of the model created after fit() against the expected registry path."""
    estimator = FactorizationMachines(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    s3_location = 's3://{}/{}'.format(BUCKET_NAME, PREFIX)
    records = RecordSet(s3_location, num_records=1, feature_dim=FEATURE_DIM, channel='train')
    estimator.fit(records, MINI_BATCH_SIZE)

    actual_image = estimator.create_model().image
    expected_image = registry(REGION, 'factorization-machines') + '/factorization-machines:1'
    assert actual_image == expected_image
|
||
|
||
def test_predictor_type(sagemaker_session):
    """Check that deploying the created model yields a FactorizationMachinesPredictor."""
    estimator = FactorizationMachines(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    s3_location = 's3://{}/{}'.format(BUCKET_NAME, PREFIX)
    records = RecordSet(s3_location, num_records=1, feature_dim=FEATURE_DIM, channel='train')
    estimator.fit(records, MINI_BATCH_SIZE)

    # One instance of the training instance type is requested for the endpoint.
    predictor = estimator.create_model().deploy(1, TRAIN_INSTANCE_TYPE)

    assert isinstance(predictor, FactorizationMachinesPredictor)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where did you get the comma-separated list from? The API docs seem to imply just one value: https://docs.aws.amazon.com/sagemaker/latest/dg/k-means-api-config.html
If the API docs are wrong, can you ask the algorithms team to fix the docs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I checked with the algorithm owner: 'eval_metrics' should be a list. I have asked them to update the API docs.