vendor/hnswlib/space_l2.h in umappp-0.1.6 vs vendor/hnswlib/space_l2.h in umappp-0.2.0
- old
+ new
@@ -1,281 +1,324 @@
#pragma once
#include "hnswlib.h"
namespace hnswlib {
- static float
- L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
- float *pVect1 = (float *) pVect1v;
- float *pVect2 = (float *) pVect2v;
- size_t qty = *((size_t *) qty_ptr);
+static float
+L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+ float *pVect1 = (float *) pVect1v;
+ float *pVect2 = (float *) pVect2v;
+ size_t qty = *((size_t *) qty_ptr);
- float res = 0;
- for (size_t i = 0; i < qty; i++) {
- float t = *pVect1 - *pVect2;
- pVect1++;
- pVect2++;
- res += t * t;
- }
- return (res);
+ float res = 0;
+ for (size_t i = 0; i < qty; i++) {
+ float t = *pVect1 - *pVect2;
+ pVect1++;
+ pVect2++;
+ res += t * t;
}
+ return (res);
+}
-#if defined(USE_AVX)
+#if defined(USE_AVX512)
- // Favor using AVX if available.
- static float
- L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
- float *pVect1 = (float *) pVect1v;
- float *pVect2 = (float *) pVect2v;
- size_t qty = *((size_t *) qty_ptr);
- float PORTABLE_ALIGN32 TmpRes[8];
- size_t qty16 = qty >> 4;
+// Favor using AVX512 if available.
+static float
+L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+ float *pVect1 = (float *) pVect1v;
+ float *pVect2 = (float *) pVect2v;
+ size_t qty = *((size_t *) qty_ptr);
+ float PORTABLE_ALIGN64 TmpRes[16];
+ size_t qty16 = qty >> 4;
- const float *pEnd1 = pVect1 + (qty16 << 4);
+ const float *pEnd1 = pVect1 + (qty16 << 4);
- __m256 diff, v1, v2;
- __m256 sum = _mm256_set1_ps(0);
+ __m512 diff, v1, v2;
+ __m512 sum = _mm512_set1_ps(0);
- while (pVect1 < pEnd1) {
- v1 = _mm256_loadu_ps(pVect1);
- pVect1 += 8;
- v2 = _mm256_loadu_ps(pVect2);
- pVect2 += 8;
- diff = _mm256_sub_ps(v1, v2);
- sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
+ while (pVect1 < pEnd1) {
+ v1 = _mm512_loadu_ps(pVect1);
+ pVect1 += 16;
+ v2 = _mm512_loadu_ps(pVect2);
+ pVect2 += 16;
+ diff = _mm512_sub_ps(v1, v2);
+ // sum = _mm512_fmadd_ps(diff, diff, sum);
+ sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff));
+ }
- v1 = _mm256_loadu_ps(pVect1);
- pVect1 += 8;
- v2 = _mm256_loadu_ps(pVect2);
- pVect2 += 8;
- diff = _mm256_sub_ps(v1, v2);
- sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
- }
+ _mm512_store_ps(TmpRes, sum);
+ float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] +
+ TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] +
+ TmpRes[13] + TmpRes[14] + TmpRes[15];
- _mm256_store_ps(TmpRes, sum);
- return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
+ return (res);
+}
+#endif
+
+#if defined(USE_AVX)
+
+// Favor using AVX if available.
+static float
+L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+ float *pVect1 = (float *) pVect1v;
+ float *pVect2 = (float *) pVect2v;
+ size_t qty = *((size_t *) qty_ptr);
+ float PORTABLE_ALIGN32 TmpRes[8];
+ size_t qty16 = qty >> 4;
+
+ const float *pEnd1 = pVect1 + (qty16 << 4);
+
+ __m256 diff, v1, v2;
+ __m256 sum = _mm256_set1_ps(0);
+
+ while (pVect1 < pEnd1) {
+ v1 = _mm256_loadu_ps(pVect1);
+ pVect1 += 8;
+ v2 = _mm256_loadu_ps(pVect2);
+ pVect2 += 8;
+ diff = _mm256_sub_ps(v1, v2);
+ sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
+
+ v1 = _mm256_loadu_ps(pVect1);
+ pVect1 += 8;
+ v2 = _mm256_loadu_ps(pVect2);
+ pVect2 += 8;
+ diff = _mm256_sub_ps(v1, v2);
+ sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
}
-#elif defined(USE_SSE)
+ _mm256_store_ps(TmpRes, sum);
+ return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
+}
- static float
- L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
- float *pVect1 = (float *) pVect1v;
- float *pVect2 = (float *) pVect2v;
- size_t qty = *((size_t *) qty_ptr);
- float PORTABLE_ALIGN32 TmpRes[8];
- size_t qty16 = qty >> 4;
+#endif
- const float *pEnd1 = pVect1 + (qty16 << 4);
+#if defined(USE_SSE)
- __m128 diff, v1, v2;
- __m128 sum = _mm_set1_ps(0);
+static float
+L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+ float *pVect1 = (float *) pVect1v;
+ float *pVect2 = (float *) pVect2v;
+ size_t qty = *((size_t *) qty_ptr);
+ float PORTABLE_ALIGN32 TmpRes[8];
+ size_t qty16 = qty >> 4;
- while (pVect1 < pEnd1) {
- //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
- v1 = _mm_loadu_ps(pVect1);
- pVect1 += 4;
- v2 = _mm_loadu_ps(pVect2);
- pVect2 += 4;
- diff = _mm_sub_ps(v1, v2);
- sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
+ const float *pEnd1 = pVect1 + (qty16 << 4);
- v1 = _mm_loadu_ps(pVect1);
- pVect1 += 4;
- v2 = _mm_loadu_ps(pVect2);
- pVect2 += 4;
- diff = _mm_sub_ps(v1, v2);
- sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
+ __m128 diff, v1, v2;
+ __m128 sum = _mm_set1_ps(0);
- v1 = _mm_loadu_ps(pVect1);
- pVect1 += 4;
- v2 = _mm_loadu_ps(pVect2);
- pVect2 += 4;
- diff = _mm_sub_ps(v1, v2);
- sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
+ while (pVect1 < pEnd1) {
+ //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
+ v1 = _mm_loadu_ps(pVect1);
+ pVect1 += 4;
+ v2 = _mm_loadu_ps(pVect2);
+ pVect2 += 4;
+ diff = _mm_sub_ps(v1, v2);
+ sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
- v1 = _mm_loadu_ps(pVect1);
- pVect1 += 4;
- v2 = _mm_loadu_ps(pVect2);
- pVect2 += 4;
- diff = _mm_sub_ps(v1, v2);
- sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
- }
+ v1 = _mm_loadu_ps(pVect1);
+ pVect1 += 4;
+ v2 = _mm_loadu_ps(pVect2);
+ pVect2 += 4;
+ diff = _mm_sub_ps(v1, v2);
+ sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
- _mm_store_ps(TmpRes, sum);
- return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
+ v1 = _mm_loadu_ps(pVect1);
+ pVect1 += 4;
+ v2 = _mm_loadu_ps(pVect2);
+ pVect2 += 4;
+ diff = _mm_sub_ps(v1, v2);
+ sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
+
+ v1 = _mm_loadu_ps(pVect1);
+ pVect1 += 4;
+ v2 = _mm_loadu_ps(pVect2);
+ pVect2 += 4;
+ diff = _mm_sub_ps(v1, v2);
+ sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
}
+
+ _mm_store_ps(TmpRes, sum);
+ return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
+}
#endif
-#if defined(USE_SSE) || defined(USE_AVX)
- static float
- L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
- size_t qty = *((size_t *) qty_ptr);
- size_t qty16 = qty >> 4 << 4;
- float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16);
- float *pVect1 = (float *) pVect1v + qty16;
- float *pVect2 = (float *) pVect2v + qty16;
+#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
+static DISTFUNC<float> L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE;
- size_t qty_left = qty - qty16;
- float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
- return (res + res_tail);
- }
+static float
+L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+ size_t qty = *((size_t *) qty_ptr);
+ size_t qty16 = qty >> 4 << 4;
+ float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16);
+ float *pVect1 = (float *) pVect1v + qty16;
+ float *pVect2 = (float *) pVect2v + qty16;
+
+ size_t qty_left = qty - qty16;
+ float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
+ return (res + res_tail);
+}
#endif
-#ifdef USE_SSE
- static float
- L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
- float PORTABLE_ALIGN32 TmpRes[8];
- float *pVect1 = (float *) pVect1v;
- float *pVect2 = (float *) pVect2v;
- size_t qty = *((size_t *) qty_ptr);
+#if defined(USE_SSE)
+static float
+L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+ float PORTABLE_ALIGN32 TmpRes[8];
+ float *pVect1 = (float *) pVect1v;
+ float *pVect2 = (float *) pVect2v;
+ size_t qty = *((size_t *) qty_ptr);
- size_t qty4 = qty >> 2;
+ size_t qty4 = qty >> 2;
- const float *pEnd1 = pVect1 + (qty4 << 2);
+ const float *pEnd1 = pVect1 + (qty4 << 2);
- __m128 diff, v1, v2;
- __m128 sum = _mm_set1_ps(0);
+ __m128 diff, v1, v2;
+ __m128 sum = _mm_set1_ps(0);
- while (pVect1 < pEnd1) {
- v1 = _mm_loadu_ps(pVect1);
- pVect1 += 4;
- v2 = _mm_loadu_ps(pVect2);
- pVect2 += 4;
- diff = _mm_sub_ps(v1, v2);
- sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
- }
- _mm_store_ps(TmpRes, sum);
- return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
+ while (pVect1 < pEnd1) {
+ v1 = _mm_loadu_ps(pVect1);
+ pVect1 += 4;
+ v2 = _mm_loadu_ps(pVect2);
+ pVect2 += 4;
+ diff = _mm_sub_ps(v1, v2);
+ sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
}
+ _mm_store_ps(TmpRes, sum);
+ return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
+}
- static float
- L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
- size_t qty = *((size_t *) qty_ptr);
- size_t qty4 = qty >> 2 << 2;
+static float
+L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
+ size_t qty = *((size_t *) qty_ptr);
+ size_t qty4 = qty >> 2 << 2;
- float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4);
- size_t qty_left = qty - qty4;
+ float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4);
+ size_t qty_left = qty - qty4;
- float *pVect1 = (float *) pVect1v + qty4;
- float *pVect2 = (float *) pVect2v + qty4;
- float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
+ float *pVect1 = (float *) pVect1v + qty4;
+ float *pVect2 = (float *) pVect2v + qty4;
+ float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
- return (res + res_tail);
- }
+ return (res + res_tail);
+}
#endif
- class L2Space : public SpaceInterface<float> {
+class L2Space : public SpaceInterface<float> {
+ DISTFUNC<float> fstdistfunc_;
+ size_t data_size_;
+ size_t dim_;
- DISTFUNC<float> fstdistfunc_;
- size_t data_size_;
- size_t dim_;
- public:
- L2Space(size_t dim) {
- fstdistfunc_ = L2Sqr;
- #if defined(USE_SSE) || defined(USE_AVX)
- if (dim % 16 == 0)
- fstdistfunc_ = L2SqrSIMD16Ext;
- else if (dim % 4 == 0)
- fstdistfunc_ = L2SqrSIMD4Ext;
- else if (dim > 16)
- fstdistfunc_ = L2SqrSIMD16ExtResiduals;
- else if (dim > 4)
- fstdistfunc_ = L2SqrSIMD4ExtResiduals;
- #endif
- dim_ = dim;
- data_size_ = dim * sizeof(float);
- }
+ public:
+ L2Space(size_t dim) {
+ fstdistfunc_ = L2Sqr;
+#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
+ #if defined(USE_AVX512)
+ if (AVX512Capable())
+ L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
+ else if (AVXCapable())
+ L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
+ #elif defined(USE_AVX)
+ if (AVXCapable())
+ L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
+ #endif
- size_t get_data_size() {
- return data_size_;
- }
+ if (dim % 16 == 0)
+ fstdistfunc_ = L2SqrSIMD16Ext;
+ else if (dim % 4 == 0)
+ fstdistfunc_ = L2SqrSIMD4Ext;
+ else if (dim > 16)
+ fstdistfunc_ = L2SqrSIMD16ExtResiduals;
+ else if (dim > 4)
+ fstdistfunc_ = L2SqrSIMD4ExtResiduals;
+#endif
+ dim_ = dim;
+ data_size_ = dim * sizeof(float);
+ }
- DISTFUNC<float> get_dist_func() {
- return fstdistfunc_;
- }
+ size_t get_data_size() {
+ return data_size_;
+ }
- void *get_dist_func_param() {
- return &dim_;
- }
+ DISTFUNC<float> get_dist_func() {
+ return fstdistfunc_;
+ }
- ~L2Space() {}
- };
+ void *get_dist_func_param() {
+ return &dim_;
+ }
- static int
- L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) {
+ ~L2Space() {}
+};
- size_t qty = *((size_t *) qty_ptr);
- int res = 0;
- unsigned char *a = (unsigned char *) pVect1;
- unsigned char *b = (unsigned char *) pVect2;
+static int
+L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) {
+ size_t qty = *((size_t *) qty_ptr);
+ int res = 0;
+ unsigned char *a = (unsigned char *) pVect1;
+ unsigned char *b = (unsigned char *) pVect2;
- qty = qty >> 2;
- for (size_t i = 0; i < qty; i++) {
-
- res += ((*a) - (*b)) * ((*a) - (*b));
- a++;
- b++;
- res += ((*a) - (*b)) * ((*a) - (*b));
- a++;
- b++;
- res += ((*a) - (*b)) * ((*a) - (*b));
- a++;
- b++;
- res += ((*a) - (*b)) * ((*a) - (*b));
- a++;
- b++;
- }
- return (res);
+ qty = qty >> 2;
+ for (size_t i = 0; i < qty; i++) {
+ res += ((*a) - (*b)) * ((*a) - (*b));
+ a++;
+ b++;
+ res += ((*a) - (*b)) * ((*a) - (*b));
+ a++;
+ b++;
+ res += ((*a) - (*b)) * ((*a) - (*b));
+ a++;
+ b++;
+ res += ((*a) - (*b)) * ((*a) - (*b));
+ a++;
+ b++;
}
+ return (res);
+}
- static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
- size_t qty = *((size_t*)qty_ptr);
- int res = 0;
- unsigned char* a = (unsigned char*)pVect1;
- unsigned char* b = (unsigned char*)pVect2;
+static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
+ size_t qty = *((size_t*)qty_ptr);
+ int res = 0;
+ unsigned char* a = (unsigned char*)pVect1;
+ unsigned char* b = (unsigned char*)pVect2;
- for(size_t i = 0; i < qty; i++)
- {
- res += ((*a) - (*b)) * ((*a) - (*b));
- a++;
- b++;
- }
- return (res);
+ for (size_t i = 0; i < qty; i++) {
+ res += ((*a) - (*b)) * ((*a) - (*b));
+ a++;
+ b++;
}
+ return (res);
+}
- class L2SpaceI : public SpaceInterface<int> {
+class L2SpaceI : public SpaceInterface<int> {
+ DISTFUNC<int> fstdistfunc_;
+ size_t data_size_;
+ size_t dim_;
- DISTFUNC<int> fstdistfunc_;
- size_t data_size_;
- size_t dim_;
- public:
- L2SpaceI(size_t dim) {
- if(dim % 4 == 0) {
- fstdistfunc_ = L2SqrI4x;
- }
- else {
- fstdistfunc_ = L2SqrI;
- }
- dim_ = dim;
- data_size_ = dim * sizeof(unsigned char);
+ public:
+ L2SpaceI(size_t dim) {
+ if (dim % 4 == 0) {
+ fstdistfunc_ = L2SqrI4x;
+ } else {
+ fstdistfunc_ = L2SqrI;
}
+ dim_ = dim;
+ data_size_ = dim * sizeof(unsigned char);
+ }
- size_t get_data_size() {
- return data_size_;
- }
+ size_t get_data_size() {
+ return data_size_;
+ }
- DISTFUNC<int> get_dist_func() {
- return fstdistfunc_;
- }
+ DISTFUNC<int> get_dist_func() {
+ return fstdistfunc_;
+ }
- void *get_dist_func_param() {
- return &dim_;
- }
+ void *get_dist_func_param() {
+ return &dim_;
+ }
- ~L2SpaceI() {}
- };
-
-
-}
\ No newline at end of file
+ ~L2SpaceI() {}
+};
+} // namespace hnswlib
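
A minimal usage sketch of the 0.2.0 behavior (illustrative only; the dimension and data below are hypothetical, while L2Space, DISTFUNC, get_dist_func(), and get_dist_func_param() come from the header diffed above). Constructing an L2Space in the new copy first swaps the L2SqrSIMD16Ext pointer to the widest kernel the CPU reports via AVXCapable()/AVX512Capable(), then picks the per-dimension variant; the returned function computes the squared (not rooted) L2 distance. The same sketch also works without any USE_SSE/USE_AVX/USE_AVX512 macro defined, in which case the scalar L2Sqr is used and the result is identical.

    #include <cstddef>
    #include <vector>
    #include "hnswlib.h"

    int main() {
        // dim is a multiple of 16, so the SIMD16 kernel covers the whole
        // vector and no scalar residual pass is needed.
        const std::size_t dim = 64;
        std::vector<float> a(dim, 1.0f), b(dim, 3.0f);

        hnswlib::L2Space space(dim);
        hnswlib::DISTFUNC<float> dist = space.get_dist_func();
        void *param = space.get_dist_func_param();  // pointer to the dimension

        // Squared L2 distance: 64 * (1 - 3)^2 = 256, exact in float.
        float d = dist(a.data(), b.data(), param);
        return (d == 256.0f) ? 0 : 1;
    }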