vendor/hnswlib/space_l2.h in umappp-0.1.6 vs vendor/hnswlib/space_l2.h in umappp-0.2.0
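In short: 0.2.0 re-syncs this vendored file with upstream hnswlib's restyled source (namespace contents are de-indented, so most lines shift, and the file now ends with a "// namespace hnswlib" comment plus a trailing newline) and adds an AVX-512 path. The compile-time-selected L2SqrSIMD16Ext function becomes a DISTFUNC<float> pointer that defaults to the SSE kernel and is rebound at run time in the L2Space constructor via AVX512Capable()/AVXCapable(); a sketch of that dispatch pattern follows the diff.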

- old
+ new
unmarked lines: unchanged context

@@ -1,281 +1,324 @@
#pragma once
#include "hnswlib.h"

namespace hnswlib {

-    static float
-    L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
-        float *pVect1 = (float *) pVect1v;
-        float *pVect2 = (float *) pVect2v;
-        size_t qty = *((size_t *) qty_ptr);
+static float
+L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    float *pVect1 = (float *) pVect1v;
+    float *pVect2 = (float *) pVect2v;
+    size_t qty = *((size_t *) qty_ptr);

-        float res = 0;
-        for (size_t i = 0; i < qty; i++) {
-            float t = *pVect1 - *pVect2;
-            pVect1++;
-            pVect2++;
-            res += t * t;
-        }
-        return (res);
+    float res = 0;
+    for (size_t i = 0; i < qty; i++) {
+        float t = *pVect1 - *pVect2;
+        pVect1++;
+        pVect2++;
+        res += t * t;
    }
+    return (res);
+}

-#if defined(USE_AVX)
+#if defined(USE_AVX512)

-    // Favor using AVX if available.
-    static float
-    L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
-        float *pVect1 = (float *) pVect1v;
-        float *pVect2 = (float *) pVect2v;
-        size_t qty = *((size_t *) qty_ptr);
-        float PORTABLE_ALIGN32 TmpRes[8];
-        size_t qty16 = qty >> 4;
+// Favor using AVX512 if available.
+static float
+L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    float *pVect1 = (float *) pVect1v;
+    float *pVect2 = (float *) pVect2v;
+    size_t qty = *((size_t *) qty_ptr);
+    float PORTABLE_ALIGN64 TmpRes[16];
+    size_t qty16 = qty >> 4;

-        const float *pEnd1 = pVect1 + (qty16 << 4);
+    const float *pEnd1 = pVect1 + (qty16 << 4);

-        __m256 diff, v1, v2;
-        __m256 sum = _mm256_set1_ps(0);
+    __m512 diff, v1, v2;
+    __m512 sum = _mm512_set1_ps(0);

-        while (pVect1 < pEnd1) {
-            v1 = _mm256_loadu_ps(pVect1);
-            pVect1 += 8;
-            v2 = _mm256_loadu_ps(pVect2);
-            pVect2 += 8;
-            diff = _mm256_sub_ps(v1, v2);
-            sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
+    while (pVect1 < pEnd1) {
+        v1 = _mm512_loadu_ps(pVect1);
+        pVect1 += 16;
+        v2 = _mm512_loadu_ps(pVect2);
+        pVect2 += 16;
+        diff = _mm512_sub_ps(v1, v2);
+        // sum = _mm512_fmadd_ps(diff, diff, sum);
+        sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff));
+    }

-            v1 = _mm256_loadu_ps(pVect1);
-            pVect1 += 8;
-            v2 = _mm256_loadu_ps(pVect2);
-            pVect2 += 8;
-            diff = _mm256_sub_ps(v1, v2);
-            sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
-        }
+    _mm512_store_ps(TmpRes, sum);
+    float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] +
+            TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] +
+            TmpRes[13] + TmpRes[14] + TmpRes[15];

-        _mm256_store_ps(TmpRes, sum);
-        return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
+    return (res);
+}
+#endif
+
+#if defined(USE_AVX)
+
+// Favor using AVX if available.
+static float
+L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    float *pVect1 = (float *) pVect1v;
+    float *pVect2 = (float *) pVect2v;
+    size_t qty = *((size_t *) qty_ptr);
+    float PORTABLE_ALIGN32 TmpRes[8];
+    size_t qty16 = qty >> 4;
+
+    const float *pEnd1 = pVect1 + (qty16 << 4);
+
+    __m256 diff, v1, v2;
+    __m256 sum = _mm256_set1_ps(0);
+
+    while (pVect1 < pEnd1) {
+        v1 = _mm256_loadu_ps(pVect1);
+        pVect1 += 8;
+        v2 = _mm256_loadu_ps(pVect2);
+        pVect2 += 8;
+        diff = _mm256_sub_ps(v1, v2);
+        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
+
+        v1 = _mm256_loadu_ps(pVect1);
+        pVect1 += 8;
+        v2 = _mm256_loadu_ps(pVect2);
+        pVect2 += 8;
+        diff = _mm256_sub_ps(v1, v2);
+        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
    }

-#elif defined(USE_SSE)
+    _mm256_store_ps(TmpRes, sum);
+    return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
+}

-    static float
-    L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
-        float *pVect1 = (float *) pVect1v;
-        float *pVect2 = (float *) pVect2v;
-        size_t qty = *((size_t *) qty_ptr);
-        float PORTABLE_ALIGN32 TmpRes[8];
-        size_t qty16 = qty >> 4;
+#endif

-        const float *pEnd1 = pVect1 + (qty16 << 4);
+#if defined(USE_SSE)

-        __m128 diff, v1, v2;
-        __m128 sum = _mm_set1_ps(0);
+static float
+L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    float *pVect1 = (float *) pVect1v;
+    float *pVect2 = (float *) pVect2v;
+    size_t qty = *((size_t *) qty_ptr);
+    float PORTABLE_ALIGN32 TmpRes[8];
+    size_t qty16 = qty >> 4;

-        while (pVect1 < pEnd1) {
-            //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
-            v1 = _mm_loadu_ps(pVect1);
-            pVect1 += 4;
-            v2 = _mm_loadu_ps(pVect2);
-            pVect2 += 4;
-            diff = _mm_sub_ps(v1, v2);
-            sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
+    const float *pEnd1 = pVect1 + (qty16 << 4);

-            v1 = _mm_loadu_ps(pVect1);
-            pVect1 += 4;
-            v2 = _mm_loadu_ps(pVect2);
-            pVect2 += 4;
-            diff = _mm_sub_ps(v1, v2);
-            sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
+    __m128 diff, v1, v2;
+    __m128 sum = _mm_set1_ps(0);

-            v1 = _mm_loadu_ps(pVect1);
-            pVect1 += 4;
-            v2 = _mm_loadu_ps(pVect2);
-            pVect2 += 4;
-            diff = _mm_sub_ps(v1, v2);
-            sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
+    while (pVect1 < pEnd1) {
+        //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
+        v1 = _mm_loadu_ps(pVect1);
+        pVect1 += 4;
+        v2 = _mm_loadu_ps(pVect2);
+        pVect2 += 4;
+        diff = _mm_sub_ps(v1, v2);
+        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));

-            v1 = _mm_loadu_ps(pVect1);
-            pVect1 += 4;
-            v2 = _mm_loadu_ps(pVect2);
-            pVect2 += 4;
-            diff = _mm_sub_ps(v1, v2);
-            sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
-        }
+        v1 = _mm_loadu_ps(pVect1);
+        pVect1 += 4;
+        v2 = _mm_loadu_ps(pVect2);
+        pVect2 += 4;
+        diff = _mm_sub_ps(v1, v2);
+        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));

-        _mm_store_ps(TmpRes, sum);
-        return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
+        v1 = _mm_loadu_ps(pVect1);
+        pVect1 += 4;
+        v2 = _mm_loadu_ps(pVect2);
+        pVect2 += 4;
+        diff = _mm_sub_ps(v1, v2);
+        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
+
+        v1 = _mm_loadu_ps(pVect1);
+        pVect1 += 4;
+        v2 = _mm_loadu_ps(pVect2);
+        pVect2 += 4;
+        diff = _mm_sub_ps(v1, v2);
+        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
    }
+
+    _mm_store_ps(TmpRes, sum);
+    return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
+}
#endif

-#if defined(USE_SSE) || defined(USE_AVX)
-    static float
-    L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
-        size_t qty = *((size_t *) qty_ptr);
-        size_t qty16 = qty >> 4 << 4;
-        float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16);
-        float *pVect1 = (float *) pVect1v + qty16;
-        float *pVect2 = (float *) pVect2v + qty16;
+#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
+static DISTFUNC<float> L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE;

-        size_t qty_left = qty - qty16;
-        float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
-        return (res + res_tail);
-    }
+static float
+L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    size_t qty = *((size_t *) qty_ptr);
+    size_t qty16 = qty >> 4 << 4;
+    float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16);
+    float *pVect1 = (float *) pVect1v + qty16;
+    float *pVect2 = (float *) pVect2v + qty16;
+
+    size_t qty_left = qty - qty16;
+    float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
+    return (res + res_tail);
+}
#endif

-#ifdef USE_SSE
-    static float
-    L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
-        float PORTABLE_ALIGN32 TmpRes[8];
-        float *pVect1 = (float *) pVect1v;
-        float *pVect2 = (float *) pVect2v;
-        size_t qty = *((size_t *) qty_ptr);
+#if defined(USE_SSE)
+static float
+L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    float PORTABLE_ALIGN32 TmpRes[8];
+    float *pVect1 = (float *) pVect1v;
+    float *pVect2 = (float *) pVect2v;
+    size_t qty = *((size_t *) qty_ptr);

-        size_t qty4 = qty >> 2;
+    size_t qty4 = qty >> 2;

-        const float *pEnd1 = pVect1 + (qty4 << 2);
+    const float *pEnd1 = pVect1 + (qty4 << 2);

-        __m128 diff, v1, v2;
-        __m128 sum = _mm_set1_ps(0);
+    __m128 diff, v1, v2;
+    __m128 sum = _mm_set1_ps(0);

-        while (pVect1 < pEnd1) {
-            v1 = _mm_loadu_ps(pVect1);
-            pVect1 += 4;
-            v2 = _mm_loadu_ps(pVect2);
-            pVect2 += 4;
-            diff = _mm_sub_ps(v1, v2);
-            sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
-        }
-        _mm_store_ps(TmpRes, sum);
-        return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
+    while (pVect1 < pEnd1) {
+        v1 = _mm_loadu_ps(pVect1);
+        pVect1 += 4;
+        v2 = _mm_loadu_ps(pVect2);
+        pVect2 += 4;
+        diff = _mm_sub_ps(v1, v2);
+        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
    }
+    _mm_store_ps(TmpRes, sum);
+    return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
+}

-    static float
-    L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
-        size_t qty = *((size_t *) qty_ptr);
-        size_t qty4 = qty >> 2 << 2;
+static float
+L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+    size_t qty = *((size_t *) qty_ptr);
+    size_t qty4 = qty >> 2 << 2;

-        float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4);
-        size_t qty_left = qty - qty4;
+    float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4);
+    size_t qty_left = qty - qty4;

-        float *pVect1 = (float *) pVect1v + qty4;
-        float *pVect2 = (float *) pVect2v + qty4;
-        float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
+    float *pVect1 = (float *) pVect1v + qty4;
+    float *pVect2 = (float *) pVect2v + qty4;
+    float res_tail = L2Sqr(pVect1, pVect2, &qty_left);

-        return (res + res_tail);
-    }
+    return (res + res_tail);
+}
#endif

-    class L2Space : public SpaceInterface<float> {
+class L2Space : public SpaceInterface<float> {
+    DISTFUNC<float> fstdistfunc_;
+    size_t data_size_;
+    size_t dim_;

-        DISTFUNC<float> fstdistfunc_;
-        size_t data_size_;
-        size_t dim_;
-    public:
-        L2Space(size_t dim) {
-            fstdistfunc_ = L2Sqr;
-        #if defined(USE_SSE) || defined(USE_AVX)
-            if (dim % 16 == 0)
-                fstdistfunc_ = L2SqrSIMD16Ext;
-            else if (dim % 4 == 0)
-                fstdistfunc_ = L2SqrSIMD4Ext;
-            else if (dim > 16)
-                fstdistfunc_ = L2SqrSIMD16ExtResiduals;
-            else if (dim > 4)
-                fstdistfunc_ = L2SqrSIMD4ExtResiduals;
-        #endif
-            dim_ = dim;
-            data_size_ = dim * sizeof(float);
-        }
+ public:
+    L2Space(size_t dim) {
+        fstdistfunc_ = L2Sqr;
+#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
+    #if defined(USE_AVX512)
+        if (AVX512Capable())
+            L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
+        else if (AVXCapable())
+            L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
+    #elif defined(USE_AVX)
+        if (AVXCapable())
+            L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
+    #endif

-        size_t get_data_size() {
-            return data_size_;
-        }
+        if (dim % 16 == 0)
+            fstdistfunc_ = L2SqrSIMD16Ext;
+        else if (dim % 4 == 0)
+            fstdistfunc_ = L2SqrSIMD4Ext;
+        else if (dim > 16)
+            fstdistfunc_ = L2SqrSIMD16ExtResiduals;
+        else if (dim > 4)
+            fstdistfunc_ = L2SqrSIMD4ExtResiduals;
+#endif
+        dim_ = dim;
+        data_size_ = dim * sizeof(float);
+    }

-        DISTFUNC<float> get_dist_func() {
-            return fstdistfunc_;
-        }
+    size_t get_data_size() {
+        return data_size_;
+    }

-        void *get_dist_func_param() {
-            return &dim_;
-        }
+    DISTFUNC<float> get_dist_func() {
+        return fstdistfunc_;
+    }

-        ~L2Space() {}
-    };
+    void *get_dist_func_param() {
+        return &dim_;
+    }

-    static int
-    L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) {
+    ~L2Space() {}
+};

-        size_t qty = *((size_t *) qty_ptr);
-        int res = 0;
-        unsigned char *a = (unsigned char *) pVect1;
-        unsigned char *b = (unsigned char *) pVect2;
+static int
+L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) {
+    size_t qty = *((size_t *) qty_ptr);
+    int res = 0;
+    unsigned char *a = (unsigned char *) pVect1;
+    unsigned char *b = (unsigned char *) pVect2;

-        qty = qty >> 2;
-        for (size_t i = 0; i < qty; i++) {
-
-            res += ((*a) - (*b)) * ((*a) - (*b));
-            a++;
-            b++;
-            res += ((*a) - (*b)) * ((*a) - (*b));
-            a++;
-            b++;
-            res += ((*a) - (*b)) * ((*a) - (*b));
-            a++;
-            b++;
-            res += ((*a) - (*b)) * ((*a) - (*b));
-            a++;
-            b++;
-        }
-        return (res);
+    qty = qty >> 2;
+    for (size_t i = 0; i < qty; i++) {
+        res += ((*a) - (*b)) * ((*a) - (*b));
+        a++;
+        b++;
+        res += ((*a) - (*b)) * ((*a) - (*b));
+        a++;
+        b++;
+        res += ((*a) - (*b)) * ((*a) - (*b));
+        a++;
+        b++;
+        res += ((*a) - (*b)) * ((*a) - (*b));
+        a++;
+        b++;
    }
+    return (res);
+}

-    static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
-        size_t qty = *((size_t*)qty_ptr);
-        int res = 0;
-        unsigned char* a = (unsigned char*)pVect1;
-        unsigned char* b = (unsigned char*)pVect2;
+static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
+    size_t qty = *((size_t*)qty_ptr);
+    int res = 0;
+    unsigned char* a = (unsigned char*)pVect1;
+    unsigned char* b = (unsigned char*)pVect2;

-        for(size_t i = 0; i < qty; i++)
-        {
-            res += ((*a) - (*b)) * ((*a) - (*b));
-            a++;
-            b++;
-        }
-        return (res);
+    for (size_t i = 0; i < qty; i++) {
+        res += ((*a) - (*b)) * ((*a) - (*b));
+        a++;
+        b++;
    }
+    return (res);
+}

-    class L2SpaceI : public SpaceInterface<int> {
+class L2SpaceI : public SpaceInterface<int> {
+    DISTFUNC<int> fstdistfunc_;
+    size_t data_size_;
+    size_t dim_;

-        DISTFUNC<int> fstdistfunc_;
-        size_t data_size_;
-        size_t dim_;
-    public:
-        L2SpaceI(size_t dim) {
-            if(dim % 4 == 0) {
-                fstdistfunc_ = L2SqrI4x;
-            }
-            else {
-                fstdistfunc_ = L2SqrI;
-            }
-            dim_ = dim;
-            data_size_ = dim * sizeof(unsigned char);
+ public:
+    L2SpaceI(size_t dim) {
+        if (dim % 4 == 0) {
+            fstdistfunc_ = L2SqrI4x;
+        } else {
+            fstdistfunc_ = L2SqrI;
        }
+        dim_ = dim;
+        data_size_ = dim * sizeof(unsigned char);
+    }

-        size_t get_data_size() {
-            return data_size_;
-        }
+    size_t get_data_size() {
+        return data_size_;
+    }

-        DISTFUNC<int> get_dist_func() {
-            return fstdistfunc_;
-        }
+    DISTFUNC<int> get_dist_func() {
+        return fstdistfunc_;
+    }

-        void *get_dist_func_param() {
-            return &dim_;
-        }
+    void *get_dist_func_param() {
+        return &dim_;
+    }

-        ~L2SpaceI() {}
-    };
-
-
-}
\ No newline at end of file
+    ~L2SpaceI() {}
+};
+}  // namespace hnswlib
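The run-time dispatch is the one behavioural change beyond the new kernel: a file-scope function pointer starts at the baseline SSE kernel and is promoted once, when an L2Space is constructed, to the widest kernel the CPU reports. Below is a minimal self-contained sketch of the same pattern; the cpu_supports_avx512() probe and the kernel names are hypothetical stand-ins (hnswlib's own probes are the AVX512Capable()/AVXCapable() calls visible in the constructor above, and its kernels take const void * arguments):

#include <cstddef>
#include <cstdio>

// Mirrors hnswlib's DISTFUNC<float>: a plain function pointer, so the hot
// loop pays one indirect call and needs no virtual dispatch.
using DistFn = float (*)(const float *, const float *, std::size_t);

// Portable scalar baseline (the role L2SqrSIMD16ExtSSE plays in the diff
// above, where SSE is the floor).
static float l2sqr_baseline(const float *a, const float *b, std::size_t n) {
    float res = 0;
    for (std::size_t i = 0; i < n; i++) {
        float t = a[i] - b[i];
        res += t * t;
    }
    return res;
}

// Stand-in for a wider kernel (L2SqrSIMD16ExtAVX512); a real build would
// guard an intrinsics body with #if defined(USE_AVX512).
static float l2sqr_wide(const float *a, const float *b, std::size_t n) {
    return l2sqr_baseline(a, b, n);  // placeholder body
}

// Hypothetical CPUID probe standing in for AVX512Capable().
static bool cpu_supports_avx512() { return false; }

// File-scope pointer initialised to the lowest common denominator, exactly
// like: static DISTFUNC<float> L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE;
static DistFn l2sqr = l2sqr_baseline;

int main() {
    // Rebound at most once, up front; hnswlib does this in the L2Space
    // constructor rather than in the distance call itself.
    if (cpu_supports_avx512())
        l2sqr = l2sqr_wide;

    float x[4] = {1, 2, 3, 4};
    float y[4] = {2, 2, 3, 0};
    std::printf("%f\n", l2sqr(x, y, 4));  // (1-2)^2 + (4-0)^2 = 17.000000
}

Caller-facing behaviour is unchanged between 0.1.6 and 0.2.0, since the chosen kernel stays behind get_dist_func(). A usage sketch, assuming only the interface visible in this diff and that the vendored header is on the include path:

#include <cstddef>
#include <cstdio>
#include <vector>
#include "hnswlib.h"  // vendored header defining L2Space and DISTFUNC

int main() {
    // 50 is > 16 but not a multiple of 4 or 16, so with a SIMD macro defined
    // the constructor picks L2SqrSIMD16ExtResiduals; with none it falls back
    // to the scalar L2Sqr. Either way the result is identical.
    std::size_t dim = 50;
    hnswlib::L2Space space(dim);
    hnswlib::DISTFUNC<float> dist = space.get_dist_func();

    std::vector<float> a(dim, 0.0f), b(dim, 1.0f);
    // The third argument is the dimension, handed back as an opaque pointer.
    float d2 = dist(a.data(), b.data(), space.get_dist_func_param());
    std::printf("%f\n", d2);  // 50.000000: squared L2, no square root taken
}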