Skip to content

AVX runtime check #357

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 76 additions & 1 deletion hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,25 @@
#ifdef _MSC_VER
#include <intrin.h>
#include <stdexcept>
#include "cpu_x86.h"
void cpu_x86::cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
__cpuidex(out, eax, ecx);
}
__int64 xgetbv(unsigned int x) {
return _xgetbv(x);
}
#else
#include <x86intrin.h>
#include <cpuid.h>
#include <stdint.h>
void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) {
__cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]);
}
uint64_t xgetbv(unsigned int index) {
uint32_t eax, edx;
__asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index));
return ((uint64_t)edx << 32) | eax;
}
#endif

#if defined(USE_AVX512)
Expand All @@ -30,6 +47,65 @@
#define PORTABLE_ALIGN32 __declspec(align(32))
#define PORTABLE_ALIGN64 __declspec(align(64))
#endif

// Adapted from https://github.com/Mysticial/FeatureDetector
#define _XCR_XFEATURE_ENABLED_MASK 0

bool AVXCapable() {
int cpuInfo[4];

// CPU support
cpuid(cpuInfo, 0, 0);
int nIds = cpuInfo[0];

bool HW_AVX = false;
if (nIds >= 0x00000001) {
cpuid(cpuInfo, 0x00000001, 0);
HW_AVX = (cpuInfo[2] & ((int)1 << 28)) != 0;
}

// OS support
cpuid(cpuInfo, 1, 0);

bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;

bool avxSupported = false;
if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
avxSupported = (xcrFeatureMask & 0x6) == 0x6;
}
return HW_AVX && avxSupported;
}

bool AVX512Capable() {
if (!AVXCapable()) return false;

int cpuInfo[4];

// CPU support
cpuid(cpuInfo, 0, 0);
int nIds = cpuInfo[0];

bool HW_AVX512F = false;
if (nIds >= 0x00000007) { // AVX512 Foundation
cpuid(cpuInfo, 0x00000007, 0);
HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0;
}

// OS support
cpuid(cpuInfo, 1, 0);

bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;

bool avx512Supported = false;
if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
avx512Supported = (xcrFeatureMask & 0xe6) == 0xe6;
}
return HW_AVX512F && avx512Supported;
}
#endif

#include <queue>
Expand Down Expand Up @@ -108,7 +184,6 @@ namespace hnswlib {

return result;
}

}

#include "space_l2.h"
Expand Down
39 changes: 31 additions & 8 deletions hnswlib/space_ip.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace hnswlib {

// Favor using AVX if available.
static float
InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
Expand Down Expand Up @@ -64,10 +64,12 @@ namespace hnswlib {
return 1.0f - sum;
}

#elif defined(USE_SSE)
#endif

#if defined(USE_SSE)

static float
InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
Expand Down Expand Up @@ -128,7 +130,7 @@ namespace hnswlib {
#if defined(USE_AVX512)

static float
InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN64 TmpRes[16];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
Expand Down Expand Up @@ -157,10 +159,12 @@ namespace hnswlib {
return 1.0f - sum;
}

#elif defined(USE_AVX)
#endif

#if defined(USE_AVX)

static float
InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
Expand Down Expand Up @@ -195,10 +199,12 @@ namespace hnswlib {
return 1.0f - sum;
}

#elif defined(USE_SSE)
#endif

#if defined(USE_SSE)

static float
InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
Expand Down Expand Up @@ -245,6 +251,9 @@ namespace hnswlib {
#endif

#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
DISTFUNC<float> InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE;
DISTFUNC<float> InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE;

static float
InnerProductSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
Expand Down Expand Up @@ -283,6 +292,20 @@ namespace hnswlib {
InnerProductSpace(size_t dim) {
fstdistfunc_ = InnerProduct;
#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
#if defined(USE_AVX512)
if (AVX512Capable())
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
else if (AVXCapable())
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
#elif defined(USE_AVX)
if (AVXCapable())
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
#endif
#if defined(USE_AVX)
if (AVXCapable())
InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
#endif

if (dim % 16 == 0)
fstdistfunc_ = InnerProductSIMD16Ext;
else if (dim % 4 == 0)
Expand Down
31 changes: 23 additions & 8 deletions hnswlib/space_l2.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace hnswlib {

// Favor using AVX512 if available.
static float
L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
Expand Down Expand Up @@ -52,12 +52,13 @@ namespace hnswlib {

return (res);
}
#endif

#elif defined(USE_AVX)
#if defined(USE_AVX)

// Favor using AVX if available.
static float
L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
Expand Down Expand Up @@ -89,10 +90,12 @@ namespace hnswlib {
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
}

#elif defined(USE_SSE)
#endif

#if defined(USE_SSE)

static float
L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
Expand Down Expand Up @@ -141,6 +144,8 @@ namespace hnswlib {
#endif

#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
DISTFUNC<float> L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE;

static float
L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
Expand All @@ -156,7 +161,7 @@ namespace hnswlib {
#endif


#ifdef USE_SSE
#if defined(USE_SSE)
static float
L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
Expand Down Expand Up @@ -208,7 +213,17 @@ namespace hnswlib {
public:
L2Space(size_t dim) {
fstdistfunc_ = L2Sqr;
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
#if defined(USE_AVX512)
if (AVX512Capable())
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
else if (AVXCapable())
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
#elif defined(USE_AVX)
if (AVXCapable())
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
#endif

if (dim % 16 == 0)
fstdistfunc_ = L2SqrSIMD16Ext;
else if (dim % 4 == 0)
Expand All @@ -217,7 +232,7 @@ namespace hnswlib {
fstdistfunc_ = L2SqrSIMD16ExtResiduals;
else if (dim > 4)
fstdistfunc_ = L2SqrSIMD4ExtResiduals;
#endif
#endif
dim_ = dim;
data_size_ = dim * sizeof(float);
}
Expand Down
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,12 @@ class BuildExt(build_ext):
"""A custom build extension for adding compiler-specific options."""
c_opts = {
'msvc': ['/EHsc', '/openmp', '/O2'],
'unix': ['-O3', '-march=native'], # , '-w'
#'unix': ['-O3', '-march=native'], # , '-w'
'unix': ['-O3'], # , '-w'
}
if not os.environ.get("HNSWLIB_NO_NATIVE"):
c_opts['unix'].append('-march=native')

link_opts = {
'unix': [],
'msvc': [],
Expand Down