4
4
namespace hnswlib {
5
5
6
6
// Scalar (non-SIMD) dot product of two float vectors.
//
// pVect1, pVect2: each is read as an array of at least *qty_ptr floats.
// qty_ptr:        points to a size_t holding the element count.
//
// Returns the raw sum of element-wise products. The "1.0f - x" distance
// conversion is not applied here; the InnerProduct wrapper does that.
static float
InnerProduct_impl(const void *pVect1, const void *pVect2, const void *qty_ptr) {
    const size_t qty = *((const size_t *) qty_ptr);
    const float *lhs = (const float *) pVect1;
    const float *rhs = (const float *) pVect2;

    float res = 0;
    // Index with size_t so the counter has the same width as qty; an
    // `unsigned` counter would wrap before reaching a qty above UINT_MAX.
    for (size_t i = 0; i < qty; i++) {
        res += lhs[i] * rhs[i];
    }
    return res;
}
16
16
17
+ static float
18
+ InnerProduct (const void *pVect1, const void *pVect2, const void *qty_ptr) {
19
+ return 1 .0f - InnerProduct_impl (pVect1, pVect2, qty_ptr);
20
+ }
21
+
22
+ #if defined(USE_AVX) || defined(USE_SSE)
17
23
#if defined(USE_AVX)
18
24
19
25
// Favor using AVX if available.
20
26
static float
21
- InnerProductSIMD4Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
27
+ InnerProductSIMD4Ext_impl (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
22
28
float PORTABLE_ALIGN32 TmpRes[8 ];
23
29
float *pVect1 = (float *) pVect1v;
24
30
float *pVect2 = (float *) pVect2v;
@@ -61,13 +67,13 @@ namespace hnswlib {
61
67
62
68
_mm_store_ps (TmpRes, sum_prod);
63
69
float sum = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];;
64
- return 1 . 0f - sum;
65
- }
70
+ return sum;
71
+ }
66
72
67
73
#elif defined(USE_SSE)
68
74
69
75
static float
70
- InnerProductSIMD4Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
76
+ InnerProductSIMD4Ext_impl (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
71
77
float PORTABLE_ALIGN32 TmpRes[8 ];
72
78
float *pVect1 = (float *) pVect1v;
73
79
float *pVect2 = (float *) pVect2v;
@@ -119,16 +125,24 @@ namespace hnswlib {
119
125
_mm_store_ps (TmpRes, sum_prod);
120
126
float sum = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];
121
127
122
- return 1 . 0f - sum;
128
+ return sum;
123
129
}
124
130
125
131
#endif
132
+
133
+ static float
134
+ InnerProductSIMD4Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
135
+ return 1 .0f - InnerProductSIMD4Ext_impl (pVect1v, pVect2v, qty_ptr);
136
+ }
126
137
138
+ #endif
127
139
140
+
141
+ #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
128
142
#if defined(USE_AVX512)
129
143
130
144
static float
131
- InnerProductSIMD16Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
145
+ InnerProductSIMD16Ext_impl (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
132
146
float PORTABLE_ALIGN64 TmpRes[16 ];
133
147
float *pVect1 = (float *) pVect1v;
134
148
float *pVect2 = (float *) pVect2v;
@@ -154,13 +168,13 @@ namespace hnswlib {
154
168
_mm512_store_ps (TmpRes, sum512);
155
169
float sum = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ] + TmpRes[4 ] + TmpRes[5 ] + TmpRes[6 ] + TmpRes[7 ] + TmpRes[8 ] + TmpRes[9 ] + TmpRes[10 ] + TmpRes[11 ] + TmpRes[12 ] + TmpRes[13 ] + TmpRes[14 ] + TmpRes[15 ];
156
170
157
- return 1 . 0f - sum;
171
+ return sum;
158
172
}
159
173
160
174
#elif defined(USE_AVX)
161
175
162
176
static float
163
- InnerProductSIMD16Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
177
+ InnerProductSIMD16Ext_impl (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
164
178
float PORTABLE_ALIGN32 TmpRes[8 ];
165
179
float *pVect1 = (float *) pVect1v;
166
180
float *pVect2 = (float *) pVect2v;
@@ -192,13 +206,13 @@ namespace hnswlib {
192
206
_mm256_store_ps (TmpRes, sum256);
193
207
float sum = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ] + TmpRes[4 ] + TmpRes[5 ] + TmpRes[6 ] + TmpRes[7 ];
194
208
195
- return 1 . 0f - sum;
209
+ return sum;
196
210
}
197
211
198
212
#elif defined(USE_SSE)
199
213
200
- static float
201
- InnerProductSIMD16Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
214
+ static float
215
+ InnerProductSIMD16Ext_impl (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
202
216
float PORTABLE_ALIGN32 TmpRes[8 ];
203
217
float *pVect1 = (float *) pVect1v;
204
218
float *pVect2 = (float *) pVect2v;
@@ -239,7 +253,14 @@ namespace hnswlib {
239
253
_mm_store_ps (TmpRes, sum_prod);
240
254
float sum = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];
241
255
242
- return 1 .0f - sum;
256
+ return sum;
257
+ }
258
+
259
+ #endif
260
+
261
+ static float
262
+ InnerProductSIMD16Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
263
+ return 1 .0f - InnerProductSIMD16Ext_impl (pVect1v, pVect2v, qty_ptr);
243
264
}
244
265
245
266
#endif
@@ -249,28 +270,28 @@ namespace hnswlib {
249
270
InnerProductSIMD16ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
250
271
size_t qty = *((size_t *) qty_ptr);
251
272
size_t qty16 = qty >> 4 << 4 ;
252
- float res = InnerProductSIMD16Ext (pVect1v, pVect2v, &qty16);
273
+ float res = InnerProductSIMD16Ext_impl (pVect1v, pVect2v, &qty16);
253
274
float *pVect1 = (float *) pVect1v + qty16;
254
275
float *pVect2 = (float *) pVect2v + qty16;
255
276
256
277
size_t qty_left = qty - qty16;
257
- float res_tail = InnerProduct (pVect1, pVect2, &qty_left);
258
- return res + res_tail - 1 . 0f ;
278
+ float res_tail = InnerProduct_impl (pVect1, pVect2, &qty_left);
279
+ return 1 . 0f - ( res + res_tail) ;
259
280
}
260
281
261
282
static float
262
283
InnerProductSIMD4ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
263
284
size_t qty = *((size_t *) qty_ptr);
264
285
size_t qty4 = qty >> 2 << 2 ;
265
286
266
- float res = InnerProductSIMD4Ext (pVect1v, pVect2v, &qty4);
287
+ float res = InnerProductSIMD4Ext_impl (pVect1v, pVect2v, &qty4);
267
288
size_t qty_left = qty - qty4;
268
289
269
290
float *pVect1 = (float *) pVect1v + qty4;
270
291
float *pVect2 = (float *) pVect2v + qty4;
271
- float res_tail = InnerProduct (pVect1, pVect2, &qty_left);
292
+ float res_tail = InnerProduct_impl (pVect1, pVect2, &qty_left);
272
293
273
- return res + res_tail - 1 . 0f ;
294
+ return 1 . 0f - ( res + res_tail) ;
274
295
}
275
296
#endif
276
297
@@ -311,5 +332,4 @@ namespace hnswlib {
311
332
// Destructor with an empty body: no explicit cleanup is performed here.
~InnerProductSpace () {}
312
333
};
313
334
314
-
315
335
}
0 commit comments