Skip to content

Commit aea4eb2

Browse files
authored
Fix tests (#9)
1 parent 8044572 commit aea4eb2

File tree

14 files changed

+1164
-1268
lines changed

14 files changed

+1164
-1268
lines changed

.github/workflows/pytest.yml

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,38 +11,30 @@ jobs:
1111
- name: Checkout
1212
uses: actions/checkout@v3
1313

14+
- name: Install Poetry
15+
run: |
16+
pipx install poetry
17+
1418
- name: Setup Python
1519
id: setup_py
1620
uses: actions/setup-python@v4
1721
with:
1822
python-version: '3.10'
19-
20-
- name: Install Poetry
21-
uses: snok/install-poetry@v1
22-
with:
23-
virtualenvs-create: true
24-
virtualenvs-in-project: true
25-
26-
- name: Load cached venv
27-
id: cached-poetry-dependencies
28-
uses: actions/cache@v3
29-
with:
30-
path: .venv
31-
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('poetry.lock') }}
23+
cache: 'poetry'
3224

3325
- name: Poetry install
3426
run: |
27+
poetry lock --check
3528
poetry install --extras tests --extras docs --extras annlibs
36-
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
3729
3830
- name: Run test suite
3931
run: |
40-
.venv/bin/pytest
32+
poetry run pytest -v --color=yes
4133
4234
- name: Run black
4335
run: |
44-
.venv/bin/black --check sklearn_ann
36+
poetry run black --check sklearn_ann
4537
4638
- name: Run flake8
4739
run: |
48-
.venv/bin/flake8 sklearn_ann
40+
poetry run flake8 sklearn_ann

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
1+
__pycache__/
12
docs/_build/
23
activate.sh
4+
5+
# IDEs
6+
.vscode/

docs/conf.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919

2020
# -- Project information -----------------------------------------------------
2121

22-
project = 'sklearn-ann'
23-
copyright = '2021, Frankie Robertson'
24-
author = 'Frankie Robertson'
22+
project = "sklearn-ann"
23+
copyright = "2021, Frankie Robertson"
24+
author = "Frankie Robertson"
2525

2626

2727
# -- General configuration ---------------------------------------------------
@@ -30,33 +30,33 @@
3030
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
3131
# ones.
3232
extensions = [
33-
'sphinx.ext.autodoc',
34-
'sphinx.ext.autosummary',
35-
'numpydoc',
36-
'sphinx_issues',
37-
'sphinx.ext.viewcode',
33+
"sphinx.ext.autodoc",
34+
"sphinx.ext.autosummary",
35+
"numpydoc",
36+
"sphinx_issues",
37+
"sphinx.ext.viewcode",
3838
]
3939

4040
# Add any paths that contain templates here, relative to this directory.
41-
templates_path = ['_templates']
41+
templates_path = ["_templates"]
4242

4343
# List of patterns, relative to source directory, that match files and
4444
# directories to ignore when looking for source files.
4545
# This pattern also affects html_static_path and html_extra_path.
46-
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
46+
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
4747

4848

4949
# -- Options for HTML output -------------------------------------------------
5050

5151
# The theme to use for HTML and HTML Help pages. See the documentation for
5252
# a list of builtin themes.
5353
#
54-
html_theme = 'sphinx_rtd_theme'
54+
html_theme = "sphinx_rtd_theme"
5555

5656
# Add any paths that contain custom static files (such as style sheets) here,
5757
# relative to this directory. They are copied after the builtin static files,
5858
# so a file named "default.css" will overwrite the builtin "default.css".
59-
html_static_path = ['_static']
59+
html_static_path = ["_static"]
6060
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
6161

6262
autodoc_mock_imports = ["annoy", "faiss", "pynndescent", "nmslib"]

examples/rnn_dbscan_big.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def fetch_mnist():
2525
return mnist.data / 255, mnist.target
2626

2727

28-
memory = Memory('./mnist')
28+
memory = Memory("./mnist")
2929

3030
X, y = memory.cache(fetch_mnist)()
3131

@@ -44,21 +44,22 @@ def run_rnn_dbscan(neighbor_transformer, n_neighbors, **kwargs):
4444
n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
4545
n_noise_ = list(labels).count(-1)
4646

47-
print('Estimated number of clusters: %d' % n_clusters_)
48-
print('Estimated number of noise points: %d' % n_noise_)
47+
print("Estimated number of clusters: %d" % n_clusters_)
48+
print("Estimated number of noise points: %d" % n_noise_)
4949
print("Homogeneity: %0.3f" % metrics.homogeneity_score(y, labels))
5050
print("Completeness: %0.3f" % metrics.completeness_score(y, labels))
5151
print("V-measure: %0.3f" % metrics.v_measure_score(y, labels))
52-
print("Adjusted Rand Index: %0.3f"
53-
% metrics.adjusted_rand_score(y, labels))
54-
print("Adjusted Mutual Information: %0.3f"
55-
% metrics.adjusted_mutual_info_score(y, labels))
56-
print("Silhouette Coefficient: %0.3f"
57-
% metrics.silhouette_score(X, labels))
52+
print("Adjusted Rand Index: %0.3f" % metrics.adjusted_rand_score(y, labels))
53+
print(
54+
"Adjusted Mutual Information: %0.3f"
55+
% metrics.adjusted_mutual_info_score(y, labels)
56+
)
57+
print("Silhouette Coefficient: %0.3f" % metrics.silhouette_score(X, labels))
5858

5959

6060
if __name__ == "__main__":
6161
import code
62+
6263
print("Now you can import your chosen transformer_cls and run:")
6364
print("run_rnn_dbscan(transformer_cls, n_neighbors, **params)")
6465
print("e.g.")

examples/rnn_dbscan_simple.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
# #############################################################################
2323
# Generate sample data
2424
centers = [[1, 1], [-1, -1], [1, -1]]
25-
X, labels_true = make_blobs(n_samples=750, centers=centers, cluster_std=0.4,
26-
random_state=0)
25+
X, labels_true = make_blobs(
26+
n_samples=750, centers=centers, cluster_std=0.4, random_state=0
27+
)
2728

2829
X = StandardScaler().fit_transform(X)
2930

@@ -38,40 +39,51 @@
3839
n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
3940
n_noise_ = list(labels).count(-1)
4041

41-
print('Estimated number of clusters: %d' % n_clusters_)
42-
print('Estimated number of noise points: %d' % n_noise_)
42+
print("Estimated number of clusters: %d" % n_clusters_)
43+
print("Estimated number of noise points: %d" % n_noise_)
4344
print("Homogeneity: %0.3f" % metrics.homogeneity_score(labels_true, labels))
4445
print("Completeness: %0.3f" % metrics.completeness_score(labels_true, labels))
4546
print("V-measure: %0.3f" % metrics.v_measure_score(labels_true, labels))
46-
print("Adjusted Rand Index: %0.3f"
47-
% metrics.adjusted_rand_score(labels_true, labels))
48-
print("Adjusted Mutual Information: %0.3f"
49-
% metrics.adjusted_mutual_info_score(labels_true, labels))
50-
print("Silhouette Coefficient: %0.3f"
51-
% metrics.silhouette_score(X, labels))
47+
print("Adjusted Rand Index: %0.3f" % metrics.adjusted_rand_score(labels_true, labels))
48+
print(
49+
"Adjusted Mutual Information: %0.3f"
50+
% metrics.adjusted_mutual_info_score(labels_true, labels)
51+
)
52+
print("Silhouette Coefficient: %0.3f" % metrics.silhouette_score(X, labels))
5253

5354
# #############################################################################
5455
# Plot result
5556
import matplotlib.pyplot as plt
5657

5758
# Black removed and is used for noise instead.
5859
unique_labels = set(labels)
59-
colors = [plt.cm.Spectral(each)
60-
for each in np.linspace(0, 1, len(unique_labels))]
60+
colors = [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(unique_labels))]
6161
for k, col in zip(unique_labels, colors):
6262
if k == -1:
6363
# Black used for noise.
6464
col = [0, 0, 0, 1]
6565

66-
class_member_mask = (labels == k)
66+
class_member_mask = labels == k
6767

6868
xy = X[class_member_mask & core_samples_mask]
69-
plt.plot(xy[:, 0], xy[:, 1], 'o', markerfacecolor=tuple(col),
70-
markeredgecolor='k', markersize=14)
69+
plt.plot(
70+
xy[:, 0],
71+
xy[:, 1],
72+
"o",
73+
markerfacecolor=tuple(col),
74+
markeredgecolor="k",
75+
markersize=14,
76+
)
7177

7278
xy = X[class_member_mask & ~core_samples_mask]
73-
plt.plot(xy[:, 0], xy[:, 1], 'o', markerfacecolor=tuple(col),
74-
markeredgecolor='k', markersize=6)
75-
76-
plt.title('Estimated number of clusters: %d' % n_clusters_)
79+
plt.plot(
80+
xy[:, 0],
81+
xy[:, 1],
82+
"o",
83+
markerfacecolor=tuple(col),
84+
markeredgecolor="k",
85+
markersize=6,
86+
)
87+
88+
plt.title("Estimated number of clusters: %d" % n_clusters_)
7789
plt.show()

0 commit comments

Comments
 (0)