Update unit tests of kmeans, pca, factorization machines, lda and ntm #103
Conversation
Codecov Report
```
@@            Coverage Diff             @@
##           master     #103      +/-   ##
==========================================
+ Coverage   89.74%   91.32%    +1.57%
==========================================
  Files          34       34
  Lines        2039     2040        +1
==========================================
+ Hits         1830     1863       +33
+ Misses        209      177       -32
==========================================
```
Continue to review full report at Codecov.
```python
num_trials = hp('local_lloyd_num_trials', gt(0), 'An integer greater-than 0', int)
local_init_method = hp('local_lloyd_init_method', isin('random', 'kmeans++'), 'One of "random", "kmeans++"', str)
half_life_time_size = hp('half_life_time_size', ge(0), 'An integer greater-than-or-equal-to 0', int)
epochs = hp('epochs', gt(0), 'An integer greater-than 0', int)
center_factor = hp('extra_center_factor', gt(0), 'An integer greater-than 0', int)
eval_metrics = hp(name='eval_metrics', validation_message='A comma separated list of "msd" or "ssd"',
```
Where did you get comma separated list from? The API docs seem to imply just one value: https://docs.aws.amazon.com/sagemaker/latest/dg/k-means-api-config.html
If the API docs are wrong, can you ask the algorithms team to fix the docs?
I checked with the algorithm owner: `eval_metrics` should indeed be a list. I have asked them to update the API docs.
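Since the value is a list rather than a single string, its validator has to check membership element-wise. A minimal sketch of that idea, assuming a standalone helper (`validate_eval_metrics` and `VALID_EVAL_METRICS` are illustrative names, not the SDK's API):

```python
# Hypothetical list-valued validation for eval_metrics; these names are
# illustrative and not part of the SageMaker Python SDK.
VALID_EVAL_METRICS = {'msd', 'ssd'}

def validate_eval_metrics(value):
    """Accept only a list whose members are known metric names."""
    if not isinstance(value, list):
        return False
    return all(metric in VALID_EVAL_METRICS for metric in value)
```

With this shape, `['msd', 'ssd']` passes while a bare `'msd'` string or an unknown metric name is rejected.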
src/sagemaker/amazon/pca.py
Outdated
```python
algorithm_mode = hp(name='algorithm_mode', validate=lambda x: x in ['regular', 'stable', 'randomized'],
                    validation_message='Value must be one of "regular", "stable", "randomized"', data_type=str)
num_components = hp('num_components', gt(0), 'Value must be an integer greater than zero', int)
algorithm_mode = hp('algorithm_mode', isin('regular', 'stable', 'randomized'),
```
I checked with the algorithm owner. The 'stable' value is not supported for now, so I have removed it.
src/sagemaker/amazon/pca.py
Outdated
```python
subtract_mean = hp(name='subtract_mean', validation_message='Value must be a boolean', data_type=bool)
extra_components = hp(name='extra_components', validate=lambda x: x >= 0,
                      validation_message="Value must be an integer greater than or equal to 0", data_type=int)
extra_components = hp('extra_components', ge(0), "Value must be an integer greater than or equal to 0", int)
```
Maybe edit the description to state that you should leave this unset if you want the behavior for -1: https://docs.aws.amazon.com/sagemaker/latest/dg/PCA-reference.html
Thanks. I removed the validator so that the -1 behavior can be obtained by leaving the value unset.
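The agreed behavior can be sketched as: validate only when the parameter is set, and let an unset value defer to the service-side default (the documented -1 behavior). The helper below is a hypothetical stand-in, not the SDK's validator:

```python
def validate_extra_components(value):
    """Hypothetical check: None (unset) defers to the service default,
    which corresponds to the documented -1 behavior; otherwise require
    a non-negative integer."""
    if value is None:
        return True
    return isinstance(value, int) and not isinstance(value, bool) and value >= 0
```

Under this scheme a user who wants the -1 behavior simply omits `extra_components` rather than passing -1 explicitly.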
tests/unit/test_fm.py
Outdated
```python
@@ -94,3 +108,282 @@ def test_all_hyperparameters(sagemaker_session):
def test_image(sagemaker_session):
    fm = FactorizationMachines(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    assert fm.train_image() == registry(REGION) + '/factorization-machines:1'


def test_num_factors_validation_fail_type(sagemaker_session):
```
I'd recommend using parametrized test functions to make these tests significantly more concise: https://docs.pytest.org/en/latest/parametrize.html
We can chat about the best way to do this if you like.
As discussed offline, all hyper-parameter related unit tests are now in parametrized style.
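As a sketch of what that style looks like, the type- and range-failure cases for a single hyperparameter can be collapsed into one parametrized test. The validator below is an illustrative stand-in for the `hp` descriptor's checks, not the SDK's actual code:

```python
import pytest

# Hypothetical stand-in for the hp descriptor's "integer greater than 0"
# validation on a hyperparameter such as num_factors.
def validate_num_factors(value):
    if isinstance(value, bool) or not isinstance(value, int):
        raise TypeError('num_factors must be an integer')
    if value <= 0:
        raise ValueError('num_factors must be greater than 0')

# One parametrized test replaces a separate function per bad value.
@pytest.mark.parametrize('bad_value, expected_error', [
    ('string', TypeError),  # wrong type
    (0, ValueError),        # out of range
    (-1, ValueError),       # out of range
])
def test_num_factors_validation_fail(bad_value, expected_error):
    with pytest.raises(expected_error):
        validate_num_factors(bad_value)
```

Each tuple in the `parametrize` list becomes its own test case in the pytest report, so failures still pinpoint the exact value that slipped through.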
```python
    assert base_fit.call_args[0][1] == MINI_BATCH_SIZE


def test_call_fit_none_mini_batch_size(sagemaker_session):
```
What is this test asserting on? And does the base fit need to be patched? (Same question about patch applies to the tests below as well)
This test doesn't assert anything; it verifies that fit runs successfully (no exception) when no mini_batch_size is given, since for this algorithm mini_batch_size has a default value.
The other tests below are similar, each covering a different case for the mini_batch_size passed to fit(). For a given algorithm there are usually several cases to check: does it have a default value? a valid range? is it required?
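The pattern those tests exercise can be sketched as follows; the class name and default value are hypothetical stand-ins, not the SDK's actual estimator:

```python
DEFAULT_MINI_BATCH_SIZE = 5000  # hypothetical algorithm default

class SketchEstimator:
    """Minimal stand-in: fit() accepts an optional mini_batch_size and
    falls back to the algorithm's default when none is given."""

    def fit(self, records, mini_batch_size=None):
        # None means "use the algorithm default" rather than an error.
        if mini_batch_size is None:
            mini_batch_size = DEFAULT_MINI_BATCH_SIZE
        if not isinstance(mini_batch_size, int) or mini_batch_size < 1:
            raise ValueError('mini_batch_size must be an integer >= 1')
        self.mini_batch_size = mini_batch_size

# Calling fit() without mini_batch_size succeeds: the default applies.
est = SketchEstimator()
est.fit(records=[])
```

The "none" test in the PR checks exactly this: that the call completes without raising, while sibling tests feed out-of-range or wrongly typed values and expect an exception.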
Description:
1. Remove unused code in the ntm and lda unit tests
2. Add more tests to the factorization machines unit tests
3. Add pca and kmeans unit tests
4. Fix the type of the tol hyperparameter in kmeans
5. Add the missing eval_metrics hyperparameter in kmeans
6. Remove some tests from test_amazon_estimator since they are covered by the pca unit tests
7. Use validator functions in pca instead of lambdas
Test:
tox tests/unit passed