Skip to content

Commit 7b41eca

Browse files
committed
Added wrapper functions and moved the `1.0f -` subtraction into them
1 parent 14cabd0 commit 7b41eca

File tree

1 file changed

+41
-21
lines changed

1 file changed

+41
-21
lines changed

hnswlib/space_ip.h

Lines changed: 41 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -4,21 +4,27 @@
44
namespace hnswlib {
55

66
static float
7-
InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) {
7+
InnerProduct_impl(const void *pVect1, const void *pVect2, const void *qty_ptr) {
88
size_t qty = *((size_t *) qty_ptr);
99
float res = 0;
1010
for (unsigned i = 0; i < qty; i++) {
1111
res += ((float *) pVect1)[i] * ((float *) pVect2)[i];
1212
}
13-
return (1.0f - res);
13+
return res;
1414

1515
}
1616

17+
static float
18+
InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) {
19+
return 1.0f - InnerProduct_impl(pVect1, pVect2, qty_ptr);
20+
}
21+
22+
#if defined(USE_AVX) || defined(USE_SSE)
1723
#if defined(USE_AVX)
1824

1925
// Favor using AVX if available.
2026
static float
21-
InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
27+
InnerProductSIMD4Ext_impl(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
2228
float PORTABLE_ALIGN32 TmpRes[8];
2329
float *pVect1 = (float *) pVect1v;
2430
float *pVect2 = (float *) pVect2v;
@@ -61,13 +67,13 @@ namespace hnswlib {
6167

6268
_mm_store_ps(TmpRes, sum_prod);
6369
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];;
64-
return 1.0f - sum;
65-
}
70+
return sum;
71+
}
6672

6773
#elif defined(USE_SSE)
6874

6975
static float
70-
InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
76+
InnerProductSIMD4Ext_impl(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
7177
float PORTABLE_ALIGN32 TmpRes[8];
7278
float *pVect1 = (float *) pVect1v;
7379
float *pVect2 = (float *) pVect2v;
@@ -119,16 +125,24 @@ namespace hnswlib {
119125
_mm_store_ps(TmpRes, sum_prod);
120126
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
121127

122-
return 1.0f - sum;
128+
return sum;
123129
}
124130

125131
#endif
132+
133+
static float
134+
InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
135+
return 1.0f - InnerProductSIMD4Ext_impl(pVect1v, pVect2v, qty_ptr);
136+
}
126137

138+
#endif
127139

140+
141+
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
128142
#if defined(USE_AVX512)
129143

130144
static float
131-
InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
145+
InnerProductSIMD16Ext_impl(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
132146
float PORTABLE_ALIGN64 TmpRes[16];
133147
float *pVect1 = (float *) pVect1v;
134148
float *pVect2 = (float *) pVect2v;
@@ -154,13 +168,13 @@ namespace hnswlib {
154168
_mm512_store_ps(TmpRes, sum512);
155169
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15];
156170

157-
return 1.0f - sum;
171+
return sum;
158172
}
159173

160174
#elif defined(USE_AVX)
161175

162176
static float
163-
InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
177+
InnerProductSIMD16Ext_impl(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
164178
float PORTABLE_ALIGN32 TmpRes[8];
165179
float *pVect1 = (float *) pVect1v;
166180
float *pVect2 = (float *) pVect2v;
@@ -192,13 +206,13 @@ namespace hnswlib {
192206
_mm256_store_ps(TmpRes, sum256);
193207
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
194208

195-
return 1.0f - sum;
209+
return sum;
196210
}
197211

198212
#elif defined(USE_SSE)
199213

200-
static float
201-
InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
214+
static float
215+
InnerProductSIMD16Ext_impl(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
202216
float PORTABLE_ALIGN32 TmpRes[8];
203217
float *pVect1 = (float *) pVect1v;
204218
float *pVect2 = (float *) pVect2v;
@@ -239,7 +253,14 @@ namespace hnswlib {
239253
_mm_store_ps(TmpRes, sum_prod);
240254
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
241255

242-
return 1.0f - sum;
256+
return sum;
257+
}
258+
259+
#endif
260+
261+
static float
262+
InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
263+
return 1.0f - InnerProductSIMD16Ext_impl(pVect1v, pVect2v, qty_ptr);
243264
}
244265

245266
#endif
@@ -249,28 +270,28 @@ namespace hnswlib {
249270
InnerProductSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
250271
size_t qty = *((size_t *) qty_ptr);
251272
size_t qty16 = qty >> 4 << 4;
252-
float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16);
273+
float res = InnerProductSIMD16Ext_impl(pVect1v, pVect2v, &qty16);
253274
float *pVect1 = (float *) pVect1v + qty16;
254275
float *pVect2 = (float *) pVect2v + qty16;
255276

256277
size_t qty_left = qty - qty16;
257-
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
258-
return res + res_tail - 1.0f;
278+
float res_tail = InnerProduct_impl(pVect1, pVect2, &qty_left);
279+
return 1.0f - (res + res_tail);
259280
}
260281

261282
static float
262283
InnerProductSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
263284
size_t qty = *((size_t *) qty_ptr);
264285
size_t qty4 = qty >> 2 << 2;
265286

266-
float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4);
287+
float res = InnerProductSIMD4Ext_impl(pVect1v, pVect2v, &qty4);
267288
size_t qty_left = qty - qty4;
268289

269290
float *pVect1 = (float *) pVect1v + qty4;
270291
float *pVect2 = (float *) pVect2v + qty4;
271-
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
292+
float res_tail = InnerProduct_impl(pVect1, pVect2, &qty_left);
272293

273-
return res + res_tail - 1.0f;
294+
return 1.0f - (res + res_tail);
274295
}
275296
#endif
276297

@@ -311,5 +332,4 @@ namespace hnswlib {
311332
~InnerProductSpace() {}
312333
};
313334

314-
315335
}

0 commit comments

Comments
 (0)