Skip to content

Commit 33772c2

Browse files
authored
Merge pull request #361 from RedisAI/ip_space_roundup_bug_fix
IP space roundup bug fix
2 parents bcf0dc6 + 49ef6bc commit 33772c2

File tree

1 file changed

+61
-23
lines changed

1 file changed

+61
-23
lines changed

hnswlib/space_ip.h

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@ namespace hnswlib {
1010
for (unsigned i = 0; i < qty; i++) {
1111
res += ((float *) pVect1)[i] * ((float *) pVect2)[i];
1212
}
13-
return (1.0f - res);
13+
return res;
1414

1515
}
1616

17+
static float
18+
InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) {
19+
return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr);
20+
}
21+
1722
#if defined(USE_AVX)
1823

1924
// Favor using AVX if available.
@@ -61,8 +66,13 @@ namespace hnswlib {
6166

6267
_mm_store_ps(TmpRes, sum_prod);
6368
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];;
64-
return 1.0f - sum;
65-
}
69+
return sum;
70+
}
71+
72+
static float
73+
InnerProductDistanceSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
74+
return 1.0f - InnerProductSIMD4ExtAVX(pVect1v, pVect2v, qty_ptr);
75+
}
6676

6777
#endif
6878

@@ -121,7 +131,12 @@ namespace hnswlib {
121131
_mm_store_ps(TmpRes, sum_prod);
122132
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
123133

124-
return 1.0f - sum;
134+
return sum;
135+
}
136+
137+
static float
138+
InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
139+
return 1.0f - InnerProductSIMD4ExtSSE(pVect1v, pVect2v, qty_ptr);
125140
}
126141

127142
#endif
@@ -156,7 +171,12 @@ namespace hnswlib {
156171
_mm512_store_ps(TmpRes, sum512);
157172
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15];
158173

159-
return 1.0f - sum;
174+
return sum;
175+
}
176+
177+
static float
178+
InnerProductDistanceSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
179+
return 1.0f - InnerProductSIMD16ExtAVX512(pVect1v, pVect2v, qty_ptr);
160180
}
161181

162182
#endif
@@ -196,15 +216,20 @@ namespace hnswlib {
196216
_mm256_store_ps(TmpRes, sum256);
197217
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
198218

199-
return 1.0f - sum;
219+
return sum;
220+
}
221+
222+
static float
223+
InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
224+
return 1.0f - InnerProductSIMD16ExtAVX(pVect1v, pVect2v, qty_ptr);
200225
}
201226

202227
#endif
203228

204229
#if defined(USE_SSE)
205230

206-
static float
207-
InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
231+
static float
232+
InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
208233
float PORTABLE_ALIGN32 TmpRes[8];
209234
float *pVect1 = (float *) pVect1v;
210235
float *pVect2 = (float *) pVect2v;
@@ -245,17 +270,24 @@ namespace hnswlib {
245270
_mm_store_ps(TmpRes, sum_prod);
246271
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
247272

248-
return 1.0f - sum;
273+
return sum;
274+
}
275+
276+
static float
277+
InnerProductDistanceSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
278+
return 1.0f - InnerProductSIMD16ExtSSE(pVect1v, pVect2v, qty_ptr);
249279
}
250280

251281
#endif
252282

253283
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
254284
DISTFUNC<float> InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE;
255285
DISTFUNC<float> InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE;
286+
DISTFUNC<float> InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE;
287+
DISTFUNC<float> InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE;
256288

257289
static float
258-
InnerProductSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
290+
InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
259291
size_t qty = *((size_t *) qty_ptr);
260292
size_t qty16 = qty >> 4 << 4;
261293
float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16);
@@ -264,11 +296,11 @@ namespace hnswlib {
264296

265297
size_t qty_left = qty - qty16;
266298
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
267-
return res + res_tail - 1.0f;
299+
return 1.0f - (res + res_tail);
268300
}
269301

270302
static float
271-
InnerProductSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
303+
InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
272304
size_t qty = *((size_t *) qty_ptr);
273305
size_t qty4 = qty >> 2 << 2;
274306

@@ -279,7 +311,7 @@ namespace hnswlib {
279311
float *pVect2 = (float *) pVect2v + qty4;
280312
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
281313

282-
return res + res_tail - 1.0f;
314+
return 1.0f - (res + res_tail);
283315
}
284316
#endif
285317

@@ -290,30 +322,37 @@ namespace hnswlib {
290322
size_t dim_;
291323
public:
292324
InnerProductSpace(size_t dim) {
293-
fstdistfunc_ = InnerProduct;
325+
fstdistfunc_ = InnerProductDistance;
294326
#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
295327
#if defined(USE_AVX512)
296-
if (AVX512Capable())
328+
if (AVX512Capable()) {
297329
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
298-
else if (AVXCapable())
330+
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
331+
} else if (AVXCapable()) {
299332
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
333+
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
334+
}
300335
#elif defined(USE_AVX)
301-
if (AVXCapable())
336+
if (AVXCapable()) {
302337
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
338+
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
339+
}
303340
#endif
304341
#if defined(USE_AVX)
305-
if (AVXCapable())
342+
if (AVXCapable()) {
306343
InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
344+
InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
345+
}
307346
#endif
308347

309348
if (dim % 16 == 0)
310-
fstdistfunc_ = InnerProductSIMD16Ext;
349+
fstdistfunc_ = InnerProductDistanceSIMD16Ext;
311350
else if (dim % 4 == 0)
312-
fstdistfunc_ = InnerProductSIMD4Ext;
351+
fstdistfunc_ = InnerProductDistanceSIMD4Ext;
313352
else if (dim > 16)
314-
fstdistfunc_ = InnerProductSIMD16ExtResiduals;
353+
fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
315354
else if (dim > 4)
316-
fstdistfunc_ = InnerProductSIMD4ExtResiduals;
355+
fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
317356
#endif
318357
dim_ = dim;
319358
data_size_ = dim * sizeof(float);
@@ -334,5 +373,4 @@ namespace hnswlib {
334373
~InnerProductSpace() {}
335374
};
336375

337-
338376
}

0 commit comments

Comments
 (0)