Skip to content

IP space roundup bug fix #361

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 61 additions & 23 deletions hnswlib/space_ip.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,15 @@ namespace hnswlib {
for (unsigned i = 0; i < qty; i++) {
res += ((float *) pVect1)[i] * ((float *) pVect2)[i];
}
return (1.0f - res);
return res;

}

static float
InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) {
return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr);
}

#if defined(USE_AVX)

// Favor using AVX if available.
Expand Down Expand Up @@ -61,8 +66,13 @@ namespace hnswlib {

_mm_store_ps(TmpRes, sum_prod);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];;
return 1.0f - sum;
}
return sum;
}

static float
InnerProductDistanceSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
return 1.0f - InnerProductSIMD4ExtAVX(pVect1v, pVect2v, qty_ptr);
}

#endif

Expand Down Expand Up @@ -121,7 +131,12 @@ namespace hnswlib {
_mm_store_ps(TmpRes, sum_prod);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];

return 1.0f - sum;
return sum;
}

static float
InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
return 1.0f - InnerProductSIMD4ExtSSE(pVect1v, pVect2v, qty_ptr);
}

#endif
Expand Down Expand Up @@ -156,7 +171,12 @@ namespace hnswlib {
_mm512_store_ps(TmpRes, sum512);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15];

return 1.0f - sum;
return sum;
}

static float
InnerProductDistanceSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
return 1.0f - InnerProductSIMD16ExtAVX512(pVect1v, pVect2v, qty_ptr);
}

#endif
Expand Down Expand Up @@ -196,15 +216,20 @@ namespace hnswlib {
_mm256_store_ps(TmpRes, sum256);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];

return 1.0f - sum;
return sum;
}

static float
InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
return 1.0f - InnerProductSIMD16ExtAVX(pVect1v, pVect2v, qty_ptr);
}

#endif

#if defined(USE_SSE)

static float
InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
static float
InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
Expand Down Expand Up @@ -245,17 +270,24 @@ namespace hnswlib {
_mm_store_ps(TmpRes, sum_prod);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];

return 1.0f - sum;
return sum;
}

static float
InnerProductDistanceSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
return 1.0f - InnerProductSIMD16ExtSSE(pVect1v, pVect2v, qty_ptr);
}

#endif

#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
DISTFUNC<float> InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE;
DISTFUNC<float> InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE;
DISTFUNC<float> InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE;
DISTFUNC<float> InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE;

static float
InnerProductSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty >> 4 << 4;
float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16);
Expand All @@ -264,11 +296,11 @@ namespace hnswlib {

size_t qty_left = qty - qty16;
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
return res + res_tail - 1.0f;
return 1.0f - (res + res_tail);
}

static float
InnerProductSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
size_t qty4 = qty >> 2 << 2;

Expand All @@ -279,7 +311,7 @@ namespace hnswlib {
float *pVect2 = (float *) pVect2v + qty4;
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);

return res + res_tail - 1.0f;
return 1.0f - (res + res_tail);
}
#endif

Expand All @@ -290,30 +322,37 @@ namespace hnswlib {
size_t dim_;
public:
InnerProductSpace(size_t dim) {
fstdistfunc_ = InnerProduct;
fstdistfunc_ = InnerProductDistance;
#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
#if defined(USE_AVX512)
if (AVX512Capable())
if (AVX512Capable()) {
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
else if (AVXCapable())
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
} else if (AVXCapable()) {
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
}
#elif defined(USE_AVX)
if (AVXCapable())
if (AVXCapable()) {
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
}
#endif
#if defined(USE_AVX)
if (AVXCapable())
if (AVXCapable()) {
InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
}
#endif

if (dim % 16 == 0)
fstdistfunc_ = InnerProductSIMD16Ext;
fstdistfunc_ = InnerProductDistanceSIMD16Ext;
else if (dim % 4 == 0)
fstdistfunc_ = InnerProductSIMD4Ext;
fstdistfunc_ = InnerProductDistanceSIMD4Ext;
else if (dim > 16)
fstdistfunc_ = InnerProductSIMD16ExtResiduals;
fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
else if (dim > 4)
fstdistfunc_ = InnerProductSIMD4ExtResiduals;
fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
#endif
dim_ = dim;
data_size_ = dim * sizeof(float);
Expand All @@ -334,5 +373,4 @@ namespace hnswlib {
~InnerProductSpace() {}
};


}