diff --git a/.kilo/skills/benchmarks/SKILL.md b/.kilo/skills/benchmarks/SKILL.md index 0136280..ea050a4 100644 --- a/.kilo/skills/benchmarks/SKILL.md +++ b/.kilo/skills/benchmarks/SKILL.md @@ -114,9 +114,12 @@ Examples: ### Run with hardware counters (benchmarks-diagnostic build, Linux only) +The `--benchmark_perf_counters` flag requests hardware counter collection via libpfm. Counter names are platform-specific but common ones include `CYCLES`, `INSTRUCTIONS`, `CACHE-MISSES`, `CACHE-REFERENCES`, `BRANCH-MISSES`, `BRANCH-INSTRUCTIONS`. + ```bash build/benchmarks-diagnostic_${BUILD_SUFFIX}/RelWithDebInfo/benchmarks \ --benchmark_filter="${FILTER}" \ + --benchmark_perf_counters=CYCLES,INSTRUCTIONS,CACHE-MISSES \ --benchmark_counters_tabular=true ``` @@ -174,7 +177,8 @@ perf script -F +pid > perf.data.txt | `--benchmark_min_time=` | Minimum run time per benchmark | | `--benchmark_format=json` | Machine-readable output | | `--benchmark_out=` | Save output to file | -| `--benchmark_counters_tabular=true` | Align hardware counter columns | +| `--benchmark_perf_counters=CYCLES,INSTRUCTIONS,...` | Collect hardware perf counters (requires libpfm build) | +| `--benchmark_counters_tabular=true` | Align user/perf counter columns into a table | | `--benchmark_time_unit=ms` | Change display unit (ns/us/ms/s) | ## Best Practices diff --git a/include/pixie/bits.h b/include/pixie/bits.h index aa4b371..50a8cc5 100644 --- a/include/pixie/bits.h +++ b/include/pixie/bits.h @@ -412,7 +412,7 @@ static inline uint16_t lower_bound_delta_8x64(const uint64_t* x, * @brief Compare 32 16-bit numbers of @p x with @p y and * return the count of numbers where @p x is less then @p y */ -uint16_t lower_bound_32x16(const uint16_t* x, uint16_t y) { +static inline uint16_t lower_bound_32x16(const uint16_t* x, uint16_t y) { #ifdef PIXIE_AVX512_SUPPORT auto y_32 = _mm512_set1_epi16(y); @@ -467,10 +467,10 @@ uint16_t lower_bound_32x16(const uint16_t* x, uint16_t y) { * offsets. * @param delta_scalar Shared delta offset. */ -uint16_t lower_bound_delta_32x16(const uint16_t* x, - uint16_t y, - const uint16_t* delta_array, - uint16_t delta_scalar) { +static inline uint16_t lower_bound_delta_32x16(const uint16_t* x, + uint16_t y, + const uint16_t* delta_array, + uint16_t delta_scalar) { #ifdef PIXIE_AVX512_SUPPORT const __m512i dlt_512 = _mm512_loadu_epi64(delta_array); @@ -539,7 +539,7 @@ uint16_t lower_bound_delta_32x16(const uint16_t* x, * @param result Pointer to store the 64 resulting 4-bit popcount values (packed * in 32 bytes) */ -void popcount_64x4(const uint8_t* x, uint8_t* result) { +static inline void popcount_64x4(const uint8_t* x, uint8_t* result) { #ifdef PIXIE_AVX512_SUPPORT __m256i data = _mm256_loadu_si256((__m256i const*)x); @@ -586,7 +586,7 @@ void popcount_64x4(const uint8_t* x, uint8_t* result) { * @param result Pointer to store the 64 resulting 4-bit popcount values * (packed in 32 bytes) */ -void popcount_32x8(const uint8_t* x, uint8_t* result) { +static inline void popcount_32x8(const uint8_t* x, uint8_t* result) { #ifdef PIXIE_AVX512_SUPPORT // Load 64 4-bit integers (256 bits total) __m256i data = _mm256_loadu_si256((__m256i const*)x); @@ -620,47 +620,50 @@ void popcount_32x8(const uint8_t* x, uint8_t* result) { #endif } -/** - * @brief Calculates 32 bit ranks of every 8th bit, result is stored as 32 - - * * 8-bit integers. - * @details Prefix sums are computed modulo 256 (uint8_t - * wraparound). - * - * @param x Pointer to 32 input 8-bit integers - * @param - * result Pointer to store the resulting 32 8-bit integers - */ #ifdef PIXIE_AVX2_SUPPORT -static inline __m256i excess_bit_masks_16x() noexcept { - return _mm256_setr_epi16(0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, - 0x0040, 0x0080, 0x0100, 0x0200, 0x0400, 0x0800, - 0x1000, 0x2000, 0x4000, (int16_t)0x8000); -} - -static inline __m256i excess_prefix_sum_16x_i16(__m256i v) noexcept { - __m256i x = v; - __m256i t = _mm256_slli_si256(x, 2); - x = _mm256_add_epi16(x, t); - t = _mm256_slli_si256(x, 4); - x = _mm256_add_epi16(x, t); - t = _mm256_slli_si256(x, 8); - x = _mm256_add_epi16(x, t); - - __m128i lo = _mm256_extracti128_si256(x, 0); - __m128i hi = _mm256_extracti128_si256(x, 1); - const int16_t carry = (int16_t)_mm_extract_epi16(lo, 7); - hi = _mm_add_epi16(hi, _mm_set1_epi16(carry)); - - __m256i out = _mm256_castsi128_si256(lo); - out = _mm256_inserti128_si256(out, hi, 1); - return out; -} - -static inline int16_t excess_last_prefix_16x_i16(__m256i pref) noexcept { - __m128i hi = _mm256_extracti128_si256(pref, 1); - return (int16_t)_mm_extract_epi16(hi, 7); -} +// clang-format off +// LUT for total excess change across a 4-bit nibble +static inline const __m256i excess_lut_delta = _mm256_setr_epi8( + -4, -2, -2, 0, + -2, 0, 0, 2, + -2, 0, 0, 2, + 0, 2, 2, 4, + -4, -2, -2, 0, + -2, 0, 0, 2, + -2, 0, 0, 2, + 0, 2, 2, 4); + +// LUTs for target relative excess positions +static inline const __m256i excess_lut_pos0 = _mm256_setr_epi8( + -1, 1, -1, 1, + -1, 1, -1, 1, + -1, 1, -1, 1, + -1, 1, -1, 1, + -1, 1, -1, 1, + -1, 1, -1, 1, + -1, 1, -1, 1, + -1, 1, -1, 1); + +static inline const __m256i excess_lut_pos1 = _mm256_setr_epi8( + -2, 0, 0, 2, + -2, 0, 0, 2, + -2, 0, 0, 2, + -2, 0, 0, 2, + -2, 0, 0, 2, + -2, 0, 0, 2, + -2, 0, 0, 2, + -2, 0, 0, 2); + +static inline const __m256i excess_lut_pos2 = _mm256_setr_epi8( + -3, -1, -1, 1, + -1, 1, 1, 3, + -3, -1, -1, 1, + -1, 1, 1, 3, + -3, -1, -1, 1, + -1, 1, 1, 3, + -3, -1, -1, 1, + -1, 1, 1, 3); +// clang-format on #endif /** @@ -686,164 +689,77 @@ static inline void excess_positions_512(const uint64_t* s, } #ifdef PIXIE_AVX2_SUPPORT - static const __m256i masks = excess_bit_masks_16x(); - static const __m256i vzero = _mm256_setzero_si256(); - static const __m256i vallones = _mm256_cmpeq_epi16(vzero, vzero); - static const __m256i vminus1 = _mm256_set1_epi16(-1); - static const __m256i vtwo = _mm256_set1_epi16(2); - const __m256i vtarget = _mm256_set1_epi16((int16_t)target_x); - int cur = 0; - for (int k = 0; k < 32; ++k) { - const size_t bit_off = size_t(k) * 16; - const size_t word_idx = bit_off >> 6; - const size_t shift = bit_off & 63; - - uint16_t bits16 = (uint16_t)((s[word_idx] >> shift) & 0xFFFFull); - if (shift > 48 && word_idx + 1 < 8) { - bits16 |= (uint16_t)(s[word_idx + 1] << (64 - shift)); - } + const __m256i vdelta = excess_lut_delta; + const __m256i vpos0 = excess_lut_pos0; + const __m256i vpos1 = excess_lut_pos1; + const __m256i vpos2 = excess_lut_pos2; + const __m256i vmult = _mm256_set1_epi16(0x1001); + const __m256i vbit0 = _mm256_set1_epi8(1); + const __m256i vbit1 = _mm256_set1_epi8(2); + const __m256i vbit2 = _mm256_set1_epi8(4); + const __m256i vbit3 = _mm256_set1_epi8(8); + const __m128i vnibble_mask = _mm_set1_epi8(0x0F); - const __m256i vb = _mm256_set1_epi16((int16_t)bits16); - const __m256i m = _mm256_and_si256(vb, masks); - const __m256i is_zero = _mm256_cmpeq_epi16(m, vzero); - const __m256i is_set = _mm256_andnot_si256(is_zero, vallones); - const __m256i steps = - _mm256_add_epi16(vminus1, _mm256_and_si256(is_set, vtwo)); - - const __m256i pref_rel = excess_prefix_sum_16x_i16(steps); - const __m256i base = _mm256_set1_epi16((int16_t)cur); - const __m256i pref_abs = _mm256_add_epi16(pref_rel, base); - const __m256i cmp = _mm256_cmpeq_epi16(pref_abs, vtarget); - - const uint32_t m32 = (uint32_t)_mm256_movemask_epi8(cmp); - const uint16_t m16 = (uint16_t)_pext_u32(m32, 0xAAAAAAAAu); - - const size_t out_word = bit_off >> 6; - const size_t out_shift = bit_off & 63; - out[out_word] |= uint64_t(m16) << out_shift; - if (out_shift > 48 && out_word + 1 < 8) { - out[out_word + 1] |= uint64_t(m16) >> (64 - out_shift); - } + for (int k = 0; k < 4; ++k) { + int block_delta = + 2 * (std::popcount(s[2 * k]) + std::popcount(s[2 * k + 1])) - 128; - cur += (int)excess_last_prefix_16x_i16(pref_rel); - } -#else - int cur = 0; - for (size_t i = 0; i < 512; ++i) { - const uint64_t w = s[i >> 6]; - const int bit = int((w >> (i & 63)) & 1ull); - cur += bit ? +1 : -1; - if (cur == target_x) { - out[i >> 6] |= (uint64_t{1} << (i & 63)); + const int d = 2 * target_x - block_delta; + if (d < -128 || d > 128) { + target_x -= block_delta; + continue; } - } -#endif -} + __m128i word_vec = _mm_loadu_si128((const __m128i*)&s[2 * k]); + __m128i lo_nibbles = _mm_and_si128(word_vec, vnibble_mask); + __m128i hi_nibbles = + _mm_and_si128(_mm_srli_epi16(word_vec, 4), vnibble_mask); -#ifdef PIXIE_AVX2_SUPPORT -static inline __m128i excess_nibble_delta_lut() noexcept { - alignas(16) static const int8_t lut[16] = {-4, -2, -2, 0, -2, 0, 0, 2, - -2, 0, 0, 2, 0, 2, 2, 4}; - return _mm_load_si128((const __m128i*)lut); -} + __m128i unpack_lo = _mm_unpacklo_epi8(lo_nibbles, hi_nibbles); + __m128i unpack_hi = _mm_unpackhi_epi8(lo_nibbles, hi_nibbles); -static inline __m128i excess_nibble_pos0_lut() noexcept { - alignas(16) static const int8_t lut[16] = {-1, 1, -1, 1, -1, 1, -1, 1, - -1, 1, -1, 1, -1, 1, -1, 1}; - return _mm_load_si128((const __m128i*)lut); -} + __m256i nibbles = _mm256_inserti128_si256(_mm256_castsi128_si256(unpack_lo), + unpack_hi, 1); -static inline __m128i excess_nibble_pos1_lut() noexcept { - alignas(16) static const int8_t lut[16] = {-2, 0, 0, 2, -2, 0, 0, 2, - -2, 0, 0, 2, -2, 0, 0, 2}; - return _mm_load_si128((const __m128i*)lut); -} + __m256i ps = _mm256_shuffle_epi8(vdelta, nibbles); + ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 1)); + ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 2)); + ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 4)); + ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 8)); -static inline __m128i excess_nibble_pos2_lut() noexcept { - alignas(16) static const int8_t lut[16] = {-3, -1, -1, 1, -1, 1, 1, 3, - -3, -1, -1, 1, -1, 1, 1, 3}; - return _mm_load_si128((const __m128i*)lut); -} + __m128i ps_lo = _mm256_castsi256_si128(ps); + __m128i ps_hi = _mm256_extracti128_si256(ps, 1); + __m128i carry = _mm_set1_epi8((int8_t)_mm_extract_epi8(ps_lo, 15)); + ps_hi = _mm_add_epi8(ps_hi, carry); + ps = _mm256_inserti128_si256(_mm256_castsi128_si256(ps_lo), ps_hi, 1); -#endif + __m256i b = _mm256_permute2x128_si256(ps, ps, 0x08); + __m256i excl_ps = _mm256_alignr_epi8(ps, b, 15); -static inline void excess_positions_512_lut(const uint64_t* s, - int target_x, - uint64_t* out) noexcept { - out[0] = out[1] = out[2] = out[3] = 0; - out[4] = out[5] = out[6] = out[7] = 0; + __m256i vtgt = _mm256_set1_epi8((int8_t)target_x); + __m256i t = _mm256_sub_epi8(vtgt, excl_ps); - if (target_x < -512 || target_x > 512) { - return; - } + __m256i cmp0 = _mm256_cmpeq_epi8(_mm256_shuffle_epi8(vpos0, nibbles), t); + __m256i cmp1 = _mm256_cmpeq_epi8(_mm256_shuffle_epi8(vpos1, nibbles), t); + __m256i cmp2 = _mm256_cmpeq_epi8(_mm256_shuffle_epi8(vpos2, nibbles), t); + __m256i cmp3 = _mm256_cmpeq_epi8(ps, vtgt); -#ifdef PIXIE_AVX2_SUPPORT - const __m128i vdelta = excess_nibble_delta_lut(); - const __m128i vpos0 = excess_nibble_pos0_lut(); - const __m128i vpos1 = excess_nibble_pos1_lut(); - const __m128i vpos2 = excess_nibble_pos2_lut(); - const __m128i vnibble_mask = _mm_set1_epi8(0x0F); + __m256i bit0 = _mm256_and_si256(cmp0, vbit0); + __m256i bit1 = _mm256_and_si256(cmp1, vbit1); + __m256i bit2 = _mm256_and_si256(cmp2, vbit2); + __m256i bit3 = _mm256_and_si256(cmp3, vbit3); - int cur = 0; - for (int w = 0; w < 8; ++w) { - const uint64_t word = s[w]; - const int word_delta = 2 * static_cast(std::popcount(word)) - 64; - const int target_local = target_x - cur; - - const int d = 2 * target_local - word_delta; - if (d < -64 || d > 64) { - cur += word_delta; - continue; - } + __m256i total_match = _mm256_or_si256(_mm256_or_si256(bit0, bit1), + _mm256_or_si256(bit2, bit3)); + + __m256i res = _mm256_maddubs_epi16(total_match, vmult); + __m128i res_lo = _mm256_castsi256_si128(res); + __m128i res_hi = _mm256_extracti128_si256(res, 1); + __m128i packed = _mm_packus_epi16(res_lo, res_hi); - __m128i bytes = _mm_cvtsi64_si128(static_cast(word)); - __m128i lo = _mm_and_si128(bytes, vnibble_mask); - __m128i hi = _mm_and_si128(_mm_srli_epi16(bytes, 4), vnibble_mask); - __m128i nibbles = _mm_unpacklo_epi8(lo, hi); - - __m128i deltas = _mm_shuffle_epi8(vdelta, nibbles); - - __m128i ps = deltas; - ps = _mm_add_epi8(ps, _mm_slli_si128(ps, 1)); - ps = _mm_add_epi8(ps, _mm_slli_si128(ps, 2)); - ps = _mm_add_epi8(ps, _mm_slli_si128(ps, 4)); - ps = _mm_add_epi8(ps, _mm_slli_si128(ps, 8)); - - __m128i excl = _mm_slli_si128(ps, 1); - - __m128i vtarget_local = _mm_set1_epi8(static_cast(target_local)); - // Overflow safety: excl[i] ∈ [-60, 60] (exclusive prefix sum of up to - // 15 deltas each in [-4, +4]), target_local ∈ [-64, 64]. - // t = target_local - excl ∈ [-124, 124], fits perfectly in int8. - __m128i t = _mm_sub_epi8(vtarget_local, excl); - - __m128i cmp0 = _mm_cmpeq_epi8(_mm_shuffle_epi8(vpos0, nibbles), t); - uint16_t bits0 = static_cast(_mm_movemask_epi8(cmp0)); - - __m128i cmp1 = _mm_cmpeq_epi8(_mm_shuffle_epi8(vpos1, nibbles), t); - uint16_t bits1 = static_cast(_mm_movemask_epi8(cmp1)); - - __m128i cmp2 = _mm_cmpeq_epi8(_mm_shuffle_epi8(vpos2, nibbles), t); - uint16_t bits2 = static_cast(_mm_movemask_epi8(cmp2)); - - // cmp3 conceptually checks delta == t, i.e. delta == target_local - excl. - // Since excl + delta == ps (the inclusive prefix sum), this is simply - // ps == target_local. Saves one add and one shuffle. - __m128i cmp3 = _mm_cmpeq_epi8(ps, vtarget_local); - uint16_t bits3 = static_cast(_mm_movemask_epi8(cmp3)); - - // Note: We use movemask + pdep to interleave bits instead of pure AVX2 - // (e.g. maddubs + packus). While pdep is microcoded/slow on older AMD CPUs - // (Zen 2), it is hardware-accelerated and ~15% faster on modern - // architectures (Zen 3+, Intel) due to fewer vector operations and a - // shorter dependency chain. - out[w] = _pdep_u64(bits0, 0x1111111111111111ULL) | - _pdep_u64(bits1, 0x2222222222222222ULL) | - _pdep_u64(bits2, 0x4444444444444444ULL) | - _pdep_u64(bits3, 0x8888888888888888ULL); - - cur += word_delta; + _mm_storeu_si128((__m128i*)&out[2 * k], packed); + + target_x -= block_delta; } #else int cur = 0; @@ -858,7 +774,18 @@ static inline void excess_positions_512_lut(const uint64_t* s, #endif } -void rank_32x8(const uint8_t* x, uint8_t* result) { +/** + * @brief Calculates 32 bit ranks of every 8th bit, result is stored as 32 + + * * 8-bit integers. + * @details Prefix sums are computed modulo 256 (uint8_t + * wraparound). + * + * @param x Pointer to 32 input 8-bit integers + * @param + * result Pointer to store the resulting 32 8-bit integers + */ +static inline void rank_32x8(const uint8_t* x, uint8_t* result) { #ifdef PIXIE_AVX512_SUPPORT // Step 1: Calculate popcount of each byte popcount_32x8(x, result); diff --git a/include/pixie/experimental/excess.h b/include/pixie/experimental/excess.h new file mode 100644 index 0000000..de8de0f --- /dev/null +++ b/include/pixie/experimental/excess.h @@ -0,0 +1,650 @@ +#pragma once + +#include + +#include +#include +#include + +namespace pixie::experimental { + +#ifdef PIXIE_AVX2_SUPPORT +// clang-format off +static inline const __m256i excess_branch_lut_em4 = _mm256_setr_epi8( + 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00); + +static inline const __m256i excess_branch_lut_em3 = _mm256_setr_epi8( + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00); + +static inline const __m256i excess_branch_lut_em2 = _mm256_setr_epi8( + 0x02, 0x08, 0x08, 0x00, 0x0A, 0x00, 0x00, 0x00, + 0x0A, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x02, 0x08, 0x08, 0x00, 0x0A, 0x00, 0x00, 0x00, + 0x0A, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00); + +static inline const __m256i excess_branch_lut_em1 = _mm256_setr_epi8( + 0x01, 0x04, 0x05, 0x00, 0x05, 0x00, 0x01, 0x00, + 0x01, 0x04, 0x05, 0x00, 0x05, 0x00, 0x01, 0x00, + 0x01, 0x04, 0x05, 0x00, 0x05, 0x00, 0x01, 0x00, + 0x01, 0x04, 0x05, 0x00, 0x05, 0x00, 0x01, 0x00); + +static inline const __m256i excess_branch_lut_e0 = _mm256_setr_epi8( + 0x00, 0x02, 0x02, 0x08, 0x00, 0x0A, 0x0A, 0x00, + 0x00, 0x0A, 0x0A, 0x00, 0x08, 0x02, 0x02, 0x00, + 0x00, 0x02, 0x02, 0x08, 0x00, 0x0A, 0x0A, 0x00, + 0x00, 0x0A, 0x0A, 0x00, 0x08, 0x02, 0x02, 0x00); + +static inline const __m256i excess_branch_lut_e1 = _mm256_setr_epi8( + 0x00, 0x01, 0x00, 0x05, 0x00, 0x05, 0x04, 0x01, + 0x00, 0x01, 0x00, 0x05, 0x00, 0x05, 0x04, 0x01, + 0x00, 0x01, 0x00, 0x05, 0x00, 0x05, 0x04, 0x01, + 0x00, 0x01, 0x00, 0x05, 0x00, 0x05, 0x04, 0x01); + +static inline const __m256i excess_branch_lut_e2 = _mm256_setr_epi8( + 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x0A, + 0x00, 0x00, 0x00, 0x0A, 0x00, 0x08, 0x08, 0x02, + 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x0A, + 0x00, 0x00, 0x00, 0x0A, 0x00, 0x08, 0x08, 0x02); + +static inline const __m256i excess_branch_lut_e3 = _mm256_setr_epi8( + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04); + +static inline const __m256i excess_branch_lut_e4 = _mm256_setr_epi8( + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08); +// clang-format on + +static inline __m256i excess_bit_masks_16x() noexcept { + return _mm256_setr_epi16(0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, + 0x0040, 0x0080, 0x0100, 0x0200, 0x0400, 0x0800, + 0x1000, 0x2000, 0x4000, (int16_t)0x8000); +} + +static inline __m256i excess_prefix_sum_16x_i16(__m256i v) noexcept { + __m256i x = v; + __m256i t = _mm256_slli_si256(x, 2); + x = _mm256_add_epi16(x, t); + t = _mm256_slli_si256(x, 4); + x = _mm256_add_epi16(x, t); + t = _mm256_slli_si256(x, 8); + x = _mm256_add_epi16(x, t); + + __m128i lo = _mm256_extracti128_si256(x, 0); + __m128i hi = _mm256_extracti128_si256(x, 1); + const int16_t carry = (int16_t)_mm_extract_epi16(lo, 7); + hi = _mm_add_epi16(hi, _mm_set1_epi16(carry)); + + __m256i out = _mm256_castsi128_si256(lo); + out = _mm256_inserti128_si256(out, hi, 1); + return out; +} + +static inline int16_t excess_last_prefix_16x_i16(__m256i pref) noexcept { + __m128i hi = _mm256_extracti128_si256(pref, 1); + return (int16_t)_mm_extract_epi16(hi, 7); +} + +static inline __m256i excess_bit_masks_32x8() noexcept { + return _mm256_setr_epi8(0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (char)0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (char)0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (char)0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (char)0x80); +} + +static inline __m256i excess_byte_selectors_32x8() noexcept { + return _mm256_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, + 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3); +} + +static inline __m256i excess_prefix_sum_32x_i8(__m256i v) noexcept { + __m256i x = v; + __m256i t = _mm256_slli_si256(x, 1); + x = _mm256_add_epi8(x, t); + t = _mm256_slli_si256(x, 2); + x = _mm256_add_epi8(x, t); + t = _mm256_slli_si256(x, 4); + x = _mm256_add_epi8(x, t); + t = _mm256_slli_si256(x, 8); + x = _mm256_add_epi8(x, t); + + __m128i lo = _mm256_extracti128_si256(x, 0); + __m128i hi = _mm256_extracti128_si256(x, 1); + const int8_t carry = (int8_t)_mm_extract_epi8(lo, 15); + hi = _mm_add_epi8(hi, _mm_set1_epi8(carry)); + + __m256i out = _mm256_castsi128_si256(lo); + out = _mm256_inserti128_si256(out, hi, 1); + return out; +} + +static inline int8_t excess_last_prefix_32x_i8(__m256i pref) noexcept { + __m128i hi = _mm256_extracti128_si256(pref, 1); + return (int8_t)_mm_extract_epi8(hi, 15); +} + +static inline void excess_positions_512_branching_lut(const uint64_t* s, + int target_x, + uint64_t* out) noexcept { + out[0] = out[1] = out[2] = out[3] = 0; + out[4] = out[5] = out[6] = out[7] = 0; + + if (target_x < -512 || target_x > 512) { + return; + } + + int cur = 0; + const __m256i vdelta = + _mm256_setr_epi8(-4, -2, -2, 0, -2, 0, 0, 2, -2, 0, 0, 2, 0, 2, 2, 4, -4, + -2, -2, 0, -2, 0, 0, 2, -2, 0, 0, 2, 0, 2, 2, 4); + const __m256i vmult = _mm256_set1_epi16(0x1001); + const __m128i vnibble_mask = _mm_set1_epi8(0x0F); + + for (int k = 0; k < 4; ++k) { + __m128i word_vec = _mm_loadu_si128((const __m128i*)&s[2 * k]); + __m128i lo_nibbles = _mm_and_si128(word_vec, vnibble_mask); + __m128i hi_nibbles = + _mm_and_si128(_mm_srli_epi16(word_vec, 4), vnibble_mask); + + __m128i unpack_lo = _mm_unpacklo_epi8(lo_nibbles, hi_nibbles); + __m128i unpack_hi = _mm_unpackhi_epi8(lo_nibbles, hi_nibbles); + __m256i nibbles = _mm256_inserti128_si256(_mm256_castsi128_si256(unpack_lo), + unpack_hi, 1); + + __m256i ps = _mm256_shuffle_epi8(vdelta, nibbles); + ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 1)); + ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 2)); + ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 4)); + ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 8)); + + __m128i ps_lo = _mm256_castsi256_si128(ps); + __m128i ps_hi = _mm256_extracti128_si256(ps, 1); + __m128i carry = _mm_set1_epi8((int8_t)_mm_extract_epi8(ps_lo, 15)); + ps_hi = _mm_add_epi8(ps_hi, carry); + ps = _mm256_inserti128_si256(_mm256_castsi128_si256(ps_lo), ps_hi, 1); + + __m256i b = _mm256_permute2x128_si256(ps, ps, 0x08); + __m256i excl_ps = _mm256_alignr_epi8(ps, b, 15); + + int target_rel = target_x - cur; + int block_delta = + 2 * (std::popcount(s[2 * k]) + std::popcount(s[2 * k + 1])) - 128; + + const int d = 2 * target_rel - block_delta; + if (d < -128 || d > 128) { + cur += block_delta; + continue; + } + + if (target_rel == 128 || target_rel == -128) { + out[2 * k + 1] |= (uint64_t{1} << 63); + cur += block_delta; + continue; + } + + __m256i t = _mm256_sub_epi8(_mm256_set1_epi8((int8_t)target_rel), excl_ps); + __m256i total_match = _mm256_setzero_si256(); + __m256i t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(-4)); + total_match = _mm256_or_si256( + total_match, + _mm256_and_si256(t_eq, + _mm256_shuffle_epi8(excess_branch_lut_em4, nibbles))); + t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(-3)); + total_match = _mm256_or_si256( + total_match, + _mm256_and_si256(t_eq, + _mm256_shuffle_epi8(excess_branch_lut_em3, nibbles))); + t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(-2)); + total_match = _mm256_or_si256( + total_match, + _mm256_and_si256(t_eq, + _mm256_shuffle_epi8(excess_branch_lut_em2, nibbles))); + t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(-1)); + total_match = _mm256_or_si256( + total_match, + _mm256_and_si256(t_eq, + _mm256_shuffle_epi8(excess_branch_lut_em1, nibbles))); + t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(0)); + total_match = _mm256_or_si256( + total_match, + _mm256_and_si256(t_eq, + _mm256_shuffle_epi8(excess_branch_lut_e0, nibbles))); + t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(1)); + total_match = _mm256_or_si256( + total_match, + _mm256_and_si256(t_eq, + _mm256_shuffle_epi8(excess_branch_lut_e1, nibbles))); + t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(2)); + total_match = _mm256_or_si256( + total_match, + _mm256_and_si256(t_eq, + _mm256_shuffle_epi8(excess_branch_lut_e2, nibbles))); + t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(3)); + total_match = _mm256_or_si256( + total_match, + _mm256_and_si256(t_eq, + _mm256_shuffle_epi8(excess_branch_lut_e3, nibbles))); + t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(4)); + total_match = _mm256_or_si256( + total_match, + _mm256_and_si256(t_eq, + _mm256_shuffle_epi8(excess_branch_lut_e4, nibbles))); + + __m256i res = _mm256_maddubs_epi16(total_match, vmult); + __m128i packed = _mm_packus_epi16(_mm256_castsi256_si128(res), + _mm256_extracti128_si256(res, 1)); + _mm_storeu_si128((__m128i*)&out[2 * k], packed); + + cur += block_delta; + } +} +#else +static inline void excess_positions_512_branching_lut(const uint64_t* s, + int target_x, + uint64_t* out) noexcept { + excess_positions_512(s, target_x, out); +} +#endif + +#ifdef PIXIE_AVX512_SUPPORT +static inline __m512i excess_lut_delta_64x() noexcept { + return _mm512_broadcast_i32x4( + _mm_setr_epi8(-4, -2, -2, 0, -2, 0, 0, 2, -2, 0, 0, 2, 0, 2, 2, 4)); +} + +static inline __m512i excess_lut_pos0_64x() noexcept { + return _mm512_broadcast_i32x4( + _mm_setr_epi8(-1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1)); +} + +static inline __m512i excess_lut_pos1_64x() noexcept { + return _mm512_broadcast_i32x4( + _mm_setr_epi8(-2, 0, 0, 2, -2, 0, 0, 2, -2, 0, 0, 2, -2, 0, 0, 2)); +} + +static inline __m512i excess_lut_pos2_64x() noexcept { + return _mm512_broadcast_i32x4( + _mm_setr_epi8(-3, -1, -1, 1, -1, 1, 1, 3, -3, -1, -1, 1, -1, 1, 1, 3)); +} + +static inline __m512i excess_bit_masks_64x8() noexcept { + alignas(64) static constexpr int8_t masks[64] = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (int8_t)0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (int8_t)0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (int8_t)0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (int8_t)0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (int8_t)0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (int8_t)0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (int8_t)0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (int8_t)0x80}; + return _mm512_load_si512((const void*)masks); +} + +static inline __m512i excess_byte_selectors_64x8() noexcept { + alignas(64) static constexpr int8_t selectors[64] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, + 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, + 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7}; + return _mm512_load_si512((const void*)selectors); +} + +static inline __m512i excess_prefix_sum_64x_i8(__m512i v) noexcept { + __m512i x = v; + __m512i t = _mm512_bslli_epi128(x, 1); + x = _mm512_add_epi8(x, t); + t = _mm512_bslli_epi128(x, 2); + x = _mm512_add_epi8(x, t); + t = _mm512_bslli_epi128(x, 4); + x = _mm512_add_epi8(x, t); + t = _mm512_bslli_epi128(x, 8); + x = _mm512_add_epi8(x, t); + + const __m512i last_byte = _mm512_set1_epi8(15); + const __m512i lane_carry = _mm512_shuffle_epi8(x, last_byte); + const __m512i shift1_idx = + _mm512_setr_epi32(0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11); + const __m512i shift2_idx = + _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7); + const __m512i shift3_idx = + _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3); + + __m512i lane_base = + _mm512_maskz_permutexvar_epi32(0xFFF0, shift1_idx, lane_carry); + lane_base = _mm512_add_epi8(lane_base, _mm512_maskz_permutexvar_epi32( + 0xFF00, shift2_idx, lane_carry)); + lane_base = _mm512_add_epi8(lane_base, _mm512_maskz_permutexvar_epi32( + 0xF000, shift3_idx, lane_carry)); + return _mm512_add_epi8(x, lane_base); +} + +static inline __m512i excess_prefix_sum_2x32_i8(__m512i v) noexcept { + __m512i x = v; + __m512i t = _mm512_bslli_epi128(x, 1); + x = _mm512_add_epi8(x, t); + t = _mm512_bslli_epi128(x, 2); + x = _mm512_add_epi8(x, t); + t = _mm512_bslli_epi128(x, 4); + x = _mm512_add_epi8(x, t); + t = _mm512_bslli_epi128(x, 8); + x = _mm512_add_epi8(x, t); + + const __m512i last_byte = _mm512_set1_epi8(15); + const __m512i lane_carry = _mm512_shuffle_epi8(x, last_byte); + const __m512i prev_lane_idx = + _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8); + const __mmask16 carry_to_second_lane_of_each_half = 0xF0F0; + const __m512i lane_base = _mm512_maskz_permutexvar_epi32( + carry_to_second_lane_of_each_half, prev_lane_idx, lane_carry); + return _mm512_add_epi8(x, lane_base); +} + +static inline __m512i excess_nibbles_64x_from_256(__m256i words) noexcept { + const __m512i bytes16 = _mm512_cvtepu8_epi16(words); + const __m512i low = _mm512_and_si512(bytes16, _mm512_set1_epi16(0x000F)); + const __m512i high = _mm512_and_si512(_mm512_srli_epi16(bytes16, 4), + _mm512_set1_epi16(0x000F)); + return _mm512_or_si512(low, _mm512_slli_epi16(high, 8)); +} + +static inline __m512i excess_exclusive_prefix_2x32_i8(__m512i pref) noexcept { + const __m512i zero = _mm512_setzero_si512(); + __m512i out = _mm512_alignr_epi8(pref, zero, 15); + + const __m512i last_byte = _mm512_set1_epi8(15); + const __m512i lane_carry = _mm512_shuffle_epi8(pref, last_byte); + const __m512i prev_lane_idx = + _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8); + const __m512i carry_dwords = + _mm512_permutexvar_epi32(prev_lane_idx, lane_carry); + const __mmask64 first_byte_of_second_lane_in_each_half = + (uint64_t{1} << 16) | (uint64_t{1} << 48); + return _mm512_or_si512( + out, _mm512_maskz_mov_epi8(first_byte_of_second_lane_in_each_half, + carry_dwords)); +} + +static inline uint64_t excess_repeat_byte(int value) noexcept { + return uint64_t{0x0101010101010101} * + static_cast(static_cast(value)); +} + +static inline void excess_positions_512_lut_avx512(const uint64_t* s, + int target_x, + uint64_t* out) noexcept { + out[0] = out[1] = out[2] = out[3] = 0; + out[4] = out[5] = out[6] = out[7] = 0; + + if (target_x < -512 || target_x > 512) { + return; + } + + static const __m512i vdelta = excess_lut_delta_64x(); + static const __m512i vpos0 = excess_lut_pos0_64x(); + static const __m512i vpos1 = excess_lut_pos1_64x(); + static const __m512i vpos2 = excess_lut_pos2_64x(); + static const __m512i vbit0 = _mm512_set1_epi8(1); + static const __m512i vbit1 = _mm512_set1_epi8(2); + static const __m512i vbit2 = _mm512_set1_epi8(4); + static const __m512i vbit3 = _mm512_set1_epi8(8); + static const __m512i vmult = _mm512_set1_epi16(0x1001); + + for (int k = 0; k < 2; ++k) { + const int base_word = 4 * k; + const int delta0 = + 2 * (std::popcount(s[base_word]) + std::popcount(s[base_word + 1])) - + 128; + const int delta1 = 2 * (std::popcount(s[base_word + 2]) + + std::popcount(s[base_word + 3])) - + 128; + const int target0 = target_x; + const int target1 = target_x - delta0; + const bool reachable0 = [&] { + const int d = 2 * target0 - delta0; + return -128 <= d && d <= 128; + }(); + const bool reachable1 = [&] { + const int d = 2 * target1 - delta1; + return -128 <= d && d <= 128; + }(); + + if (!reachable0 && !reachable1) { + target_x -= delta0 + delta1; + continue; + } + + const __m256i words = + _mm256_loadu_si256(reinterpret_cast(&s[base_word])); + const __m512i nibbles = excess_nibbles_64x_from_256(words); + const __m512i ps = + excess_prefix_sum_2x32_i8(_mm512_shuffle_epi8(vdelta, nibbles)); + const __m512i excl_ps = excess_exclusive_prefix_2x32_i8(ps); + const uint64_t repeated0 = excess_repeat_byte(target0); + const uint64_t repeated1 = excess_repeat_byte(target1); + const __m512i vtgt = + _mm512_setr_epi64(repeated0, repeated0, repeated0, repeated0, repeated1, + repeated1, repeated1, repeated1); + const __m512i t = _mm512_sub_epi8(vtgt, excl_ps); + + const __mmask64 cmp0 = + _mm512_cmpeq_epi8_mask(_mm512_shuffle_epi8(vpos0, nibbles), t); + const __mmask64 cmp1 = + _mm512_cmpeq_epi8_mask(_mm512_shuffle_epi8(vpos1, nibbles), t); + const __mmask64 cmp2 = + _mm512_cmpeq_epi8_mask(_mm512_shuffle_epi8(vpos2, nibbles), t); + const __mmask64 cmp3 = _mm512_cmpeq_epi8_mask(ps, vtgt); + __m512i total_match = _mm512_maskz_mov_epi8(cmp0, vbit0); + total_match = + _mm512_or_si512(total_match, _mm512_maskz_mov_epi8(cmp1, vbit1)); + total_match = + _mm512_or_si512(total_match, _mm512_maskz_mov_epi8(cmp2, vbit2)); + total_match = + _mm512_or_si512(total_match, _mm512_maskz_mov_epi8(cmp3, vbit3)); + + const __mmask64 active = + (reachable0 ? __mmask64{0x00000000FFFFFFFFull} : __mmask64{0}) | + (reachable1 ? __mmask64{0xFFFFFFFF00000000ull} : __mmask64{0}); + total_match = _mm512_maskz_mov_epi8(active, total_match); + + const __m512i res = _mm512_maddubs_epi16(total_match, vmult); + const __m256i packed = _mm512_cvtepi16_epi8(res); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(&out[base_word]), packed); + + target_x -= delta0 + delta1; + } +} +#else +static inline void excess_positions_512_lut_avx512(const uint64_t* s, + int target_x, + uint64_t* out) noexcept { + excess_positions_512(s, target_x, out); +} +#endif + +static inline void excess_positions_512_expand(const uint64_t* s, + int target_x, + uint64_t* out) noexcept { + out[0] = out[1] = out[2] = out[3] = 0; + out[4] = out[5] = out[6] = out[7] = 0; + + if (target_x < -512 || target_x > 512) { + return; + } + +#ifdef PIXIE_AVX2_SUPPORT + static const __m256i masks = excess_bit_masks_16x(); + static const __m256i vzero = _mm256_setzero_si256(); + static const __m256i vallones = _mm256_cmpeq_epi16(vzero, vzero); + static const __m256i vminus1 = _mm256_set1_epi16(-1); + static const __m256i vtwo = _mm256_set1_epi16(2); + const __m256i vtarget = _mm256_set1_epi16((int16_t)target_x); + + int cur = 0; + for (int block = 0; block < 4; ++block) { + const int target_rel = target_x - cur; + if (target_rel <= -64 || target_rel >= 64) { + const int block_delta = + 2 * (std::popcount(s[2 * block]) + std::popcount(s[2 * block + 1])) - + 128; + const int reachability = 2 * target_rel - block_delta; + if (reachability < -128 || reachability > 128) { + cur += block_delta; + continue; + } + } + + for (int j = 0; j < 8; ++j) { + const int k = 8 * block + j; + const size_t word_idx = size_t(k) >> 2; + const size_t shift = size_t(k & 3) * 16; + const uint16_t bits16 = + static_cast((s[word_idx] >> shift) & 0xFFFFull); + + const __m256i vb = _mm256_set1_epi16((int16_t)bits16); + const __m256i m = _mm256_and_si256(vb, masks); + const __m256i is_zero = _mm256_cmpeq_epi16(m, vzero); + const __m256i is_set = _mm256_andnot_si256(is_zero, vallones); + const __m256i steps = + _mm256_add_epi16(vminus1, _mm256_and_si256(is_set, vtwo)); + + const __m256i pref_rel = excess_prefix_sum_16x_i16(steps); + const __m256i base = _mm256_set1_epi16((int16_t)cur); + const __m256i pref_abs = _mm256_add_epi16(pref_rel, base); + const __m256i cmp = _mm256_cmpeq_epi16(pref_abs, vtarget); + + const uint32_t m32 = (uint32_t)_mm256_movemask_epi8(cmp); + const uint16_t m16 = (uint16_t)_pext_u32(m32, 0xAAAAAAAAu); + + out[word_idx] |= uint64_t(m16) << shift; + cur += (int)excess_last_prefix_16x_i16(pref_rel); + } + } +#else + int cur = 0; + for (size_t i = 0; i < 512; ++i) { + const uint64_t w = s[i >> 6]; + const int bit = int((w >> (i & 63)) & 1ull); + cur += bit ? +1 : -1; + if (cur == target_x) { + out[i >> 6] |= (uint64_t{1} << (i & 63)); + } + } +#endif +} + +static inline void excess_positions_512_expand8(const uint64_t* s, + int target_x, + uint64_t* out) noexcept { + out[0] = out[1] = out[2] = out[3] = 0; + out[4] = out[5] = out[6] = out[7] = 0; + + if (target_x < -512 || target_x > 512) { + return; + } + +#ifdef PIXIE_AVX2_SUPPORT + static const __m256i byte_selectors = excess_byte_selectors_32x8(); + static const __m256i masks = excess_bit_masks_32x8(); + static const __m256i vzero = _mm256_setzero_si256(); + static const __m256i vallones = _mm256_cmpeq_epi8(vzero, vzero); + static const __m256i vminus1 = _mm256_set1_epi8(-1); + static const __m256i vtwo = _mm256_set1_epi8(2); + + int cur = 0; + for (int k = 0; k < 16; ++k) { + const size_t word_idx = size_t(k) >> 1; + const size_t shift = size_t(k & 1) * 32; + const uint32_t bits32 = + static_cast((s[word_idx] >> shift) & 0xFFFFFFFFull); + + const int target_rel = target_x - cur; + if (target_rel < -32 || target_rel > 32) { + cur += 2 * static_cast(std::popcount(bits32)) - 32; + continue; + } + + const __m256i src = _mm256_set1_epi32((int)bits32); + const __m256i bytes = _mm256_shuffle_epi8(src, byte_selectors); + const __m256i m = _mm256_and_si256(bytes, masks); + const __m256i is_zero = _mm256_cmpeq_epi8(m, vzero); + const __m256i is_set = _mm256_andnot_si256(is_zero, vallones); + const __m256i steps = + _mm256_add_epi8(vminus1, _mm256_and_si256(is_set, vtwo)); + + const __m256i pref_rel = excess_prefix_sum_32x_i8(steps); + const __m256i vtarget = _mm256_set1_epi8((int8_t)target_rel); + const __m256i cmp = _mm256_cmpeq_epi8(pref_rel, vtarget); + const uint32_t mask = static_cast(_mm256_movemask_epi8(cmp)); + + out[word_idx] |= uint64_t(mask) << shift; + cur += static_cast(excess_last_prefix_32x_i8(pref_rel)); + } +#else + int cur = 0; + for (size_t i = 0; i < 512; ++i) { + const uint64_t w = s[i >> 6]; + const int bit = int((w >> (i & 63)) & 1ull); + cur += bit ? +1 : -1; + if (cur == target_x) { + out[i >> 6] |= (uint64_t{1} << (i & 63)); + } + } +#endif +} + +static inline void excess_positions_512_expand_avx512(const uint64_t* s, + int target_x, + uint64_t* out) noexcept { + out[0] = out[1] = out[2] = out[3] = 0; + out[4] = out[5] = out[6] = out[7] = 0; + + if (target_x < -512 || target_x > 512) { + return; + } + +#ifdef PIXIE_AVX512_SUPPORT + static const __m512i byte_selectors = excess_byte_selectors_64x8(); + static const __m512i masks = excess_bit_masks_64x8(); + static const __m512i vzero = _mm512_setzero_si512(); + static const __m512i vallones = _mm512_set1_epi8(-1); + static const __m512i vminus1 = _mm512_set1_epi8(-1); + static const __m512i vtwo = _mm512_set1_epi8(2); + + int cur = 0; + for (int k = 0; k < 8; ++k) { + const uint64_t bits64 = s[k]; + const int target_rel = target_x - cur; + if (target_rel < -64 || target_rel > 64) { + cur += 2 * static_cast(std::popcount(bits64)) - 64; + continue; + } + + const __m512i src = _mm512_set1_epi64(static_cast(bits64)); + const __m512i bytes = _mm512_shuffle_epi8(src, byte_selectors); + const __m512i m = _mm512_and_si512(bytes, masks); + const __mmask64 is_zero = _mm512_cmpeq_epi8_mask(m, vzero); + const __m512i is_set = _mm512_maskz_mov_epi8(~is_zero, vallones); + const __m512i steps = + _mm512_add_epi8(vminus1, _mm512_and_si512(is_set, vtwo)); + + const __m512i pref_rel = excess_prefix_sum_64x_i8(steps); + const __mmask64 match = + _mm512_cmpeq_epi8_mask(pref_rel, _mm512_set1_epi8((int8_t)target_rel)); + out[k] = static_cast(match); + cur += 2 * static_cast(std::popcount(bits64)) - 64; + } +#else + excess_positions_512_expand8(s, target_x, out); +#endif +} + +} // namespace pixie::experimental diff --git a/src/benchmarks/excess_positions_benchmarks.cpp b/src/benchmarks/excess_positions_benchmarks.cpp index cf52f2b..62e0e52 100644 --- a/src/benchmarks/excess_positions_benchmarks.cpp +++ b/src/benchmarks/excess_positions_benchmarks.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -7,6 +8,12 @@ #include #include +using pixie::experimental::excess_positions_512_branching_lut; +using pixie::experimental::excess_positions_512_expand; +using pixie::experimental::excess_positions_512_expand8; +using pixie::experimental::excess_positions_512_expand_avx512; +using pixie::experimental::excess_positions_512_lut_avx512; + static std::vector> make_blocks( size_t num_blocks = 4096) { std::mt19937_64 rng(42); @@ -45,7 +52,7 @@ BENCHMARK(BM_ExcessPositions512) ->Args({8}) ->Args({64}); -static void BM_ExcessPositions512_Scalar(benchmark::State& state) { +static void BM_ExcessPositions512_BranchingLUT(benchmark::State& state) { const int target_x = state.range(0); const auto blocks = make_blocks(); const size_t num_blocks = blocks.size(); @@ -55,17 +62,7 @@ static void BM_ExcessPositions512_Scalar(benchmark::State& state) { for (auto _ : state) { const auto& s = blocks[idx % num_blocks]; - for (int w = 0; w < 8; ++w) { - out[w] = 0; - } - int cur = 0; - for (size_t i = 0; i < 512; ++i) { - const int bit = int((s[i >> 6] >> (i & 63)) & 1ull); - cur += bit ? +1 : -1; - if (cur == target_x) { - out[i >> 6] |= (uint64_t{1} << (i & 63)); - } - } + excess_positions_512_branching_lut(s.data(), target_x, out); benchmark::DoNotOptimize(out); ++idx; } @@ -73,7 +70,59 @@ static void BM_ExcessPositions512_Scalar(benchmark::State& state) { state.SetItemsProcessed(state.iterations()); } -BENCHMARK(BM_ExcessPositions512_Scalar) +BENCHMARK(BM_ExcessPositions512_BranchingLUT) + ->ArgNames({"X"}) + ->Args({-64}) + ->Args({-8}) + ->Args({0}) + ->Args({8}) + ->Args({64}); + +static void BM_ExcessPositions512_LUTAVX512(benchmark::State& state) { + const int target_x = state.range(0); + const auto blocks = make_blocks(); + const size_t num_blocks = blocks.size(); + + alignas(64) uint64_t out[8]; + size_t idx = 0; + + for (auto _ : state) { + const auto& s = blocks[idx % num_blocks]; + excess_positions_512_lut_avx512(s.data(), target_x, out); + benchmark::DoNotOptimize(out); + ++idx; + } + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_ExcessPositions512_LUTAVX512) + ->ArgNames({"X"}) + ->Args({-64}) + ->Args({-8}) + ->Args({0}) + ->Args({8}) + ->Args({64}); + +static void BM_ExcessPositions512_Expand(benchmark::State& state) { + const int target_x = state.range(0); + const auto blocks = make_blocks(); + const size_t num_blocks = blocks.size(); + + alignas(64) uint64_t out[8]; + size_t idx = 0; + + for (auto _ : state) { + const auto& s = blocks[idx % num_blocks]; + excess_positions_512_expand(s.data(), target_x, out); + benchmark::DoNotOptimize(out); + ++idx; + } + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_ExcessPositions512_Expand) ->ArgNames({"X"}) ->Args({-64}) ->Args({-8}) @@ -81,7 +130,7 @@ BENCHMARK(BM_ExcessPositions512_Scalar) ->Args({8}) ->Args({64}); -static void BM_ExcessPositions512_LUT(benchmark::State& state) { +static void BM_ExcessPositions512_Expand8(benchmark::State& state) { const int target_x = state.range(0); const auto blocks = make_blocks(); const size_t num_blocks = blocks.size(); @@ -91,7 +140,7 @@ static void BM_ExcessPositions512_LUT(benchmark::State& state) { for (auto _ : state) { const auto& s = blocks[idx % num_blocks]; - excess_positions_512_lut(s.data(), target_x, out); + excess_positions_512_expand8(s.data(), target_x, out); benchmark::DoNotOptimize(out); ++idx; } @@ -99,7 +148,69 @@ static void BM_ExcessPositions512_LUT(benchmark::State& state) { state.SetItemsProcessed(state.iterations()); } -BENCHMARK(BM_ExcessPositions512_LUT) +BENCHMARK(BM_ExcessPositions512_Expand8) + ->ArgNames({"X"}) + ->Args({-64}) + ->Args({-8}) + ->Args({0}) + ->Args({8}) + ->Args({64}); + +static void BM_ExcessPositions512_ExpandAVX512(benchmark::State& state) { + const int target_x = state.range(0); + const auto blocks = make_blocks(); + const size_t num_blocks = blocks.size(); + + alignas(64) uint64_t out[8]; + size_t idx = 0; + + for (auto _ : state) { + const auto& s = blocks[idx % num_blocks]; + excess_positions_512_expand_avx512(s.data(), target_x, out); + benchmark::DoNotOptimize(out); + ++idx; + } + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_ExcessPositions512_ExpandAVX512) + ->ArgNames({"X"}) + ->Args({-64}) + ->Args({-8}) + ->Args({0}) + ->Args({8}) + ->Args({64}); + +static void BM_ExcessPositions512_Scalar(benchmark::State& state) { + const int target_x = state.range(0); + const auto blocks = make_blocks(); + const size_t num_blocks = blocks.size(); + + alignas(64) uint64_t out[8]; + size_t idx = 0; + + for (auto _ : state) { + const auto& s = blocks[idx % num_blocks]; + for (int w = 0; w < 8; ++w) { + out[w] = 0; + } + int cur = 0; + for (size_t i = 0; i < 512; ++i) { + const int bit = int((s[i >> 6] >> (i & 63)) & 1ull); + cur += bit ? +1 : -1; + if (cur == target_x) { + out[i >> 6] |= (uint64_t{1} << (i & 63)); + } + } + benchmark::DoNotOptimize(out); + ++idx; + } + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_ExcessPositions512_Scalar) ->ArgNames({"X"}) ->Args({-64}) ->Args({-8}) diff --git a/src/tests/excess_positions_tests.cpp b/src/tests/excess_positions_tests.cpp index f9ed763..cb09bc3 100644 --- a/src/tests/excess_positions_tests.cpp +++ b/src/tests/excess_positions_tests.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -8,6 +9,12 @@ #include #include +using pixie::experimental::excess_positions_512_branching_lut; +using pixie::experimental::excess_positions_512_expand; +using pixie::experimental::excess_positions_512_expand8; +using pixie::experimental::excess_positions_512_expand_avx512; +using pixie::experimental::excess_positions_512_lut_avx512; + static void naive_excess_positions_512(const uint64_t* s, int target_x, uint64_t* out) { @@ -71,6 +78,16 @@ TEST(ExcessPositions512, AllZeros) { for (int w = 0; w < 8; ++w) { EXPECT_EQ(out[w], 0u); } + + for (int x = -8; x <= 1; ++x) { + check_matches_naive(excess_positions_512_expand, "expand", s, x); + check_matches_naive(excess_positions_512_expand8, "expand8", s, x); + check_matches_naive(excess_positions_512_expand_avx512, "expand_avx512", s, + x); + check_matches_naive(excess_positions_512_branching_lut, "branching_lut", s, + x); + check_matches_naive(excess_positions_512_lut_avx512, "lut_avx512", s, x); + } } TEST(ExcessPositions512, AllOnes) { @@ -91,6 +108,16 @@ TEST(ExcessPositions512, AllOnes) { for (int w = 0; w < 8; ++w) { EXPECT_EQ(out[w], 0u); } + + for (int x = -1; x <= 8; ++x) { + check_matches_naive(excess_positions_512_expand, "expand", s, x); + check_matches_naive(excess_positions_512_expand8, "expand8", s, x); + check_matches_naive(excess_positions_512_expand_avx512, "expand_avx512", s, + x); + check_matches_naive(excess_positions_512_branching_lut, "branching_lut", s, + x); + check_matches_naive(excess_positions_512_lut_avx512, "lut_avx512", s, x); + } } TEST(ExcessPositions512, Alternating) { @@ -107,6 +134,13 @@ TEST(ExcessPositions512, Alternating) { for (int w = 0; w < 8; ++w) { EXPECT_EQ(out[w], ref[w]) << "x=" << x << " word=" << w; } + check_matches_naive(excess_positions_512_expand, "expand", s, x); + check_matches_naive(excess_positions_512_expand8, "expand8", s, x); + check_matches_naive(excess_positions_512_expand_avx512, "expand_avx512", s, + x); + check_matches_naive(excess_positions_512_branching_lut, "branching_lut", s, + x); + check_matches_naive(excess_positions_512_lut_avx512, "lut_avx512", s, x); } } @@ -137,6 +171,16 @@ TEST(ExcessPositions512, ExhaustiveSmall16) { for (int x = -20; x <= 20; ++x) { excess_positions_512(s, x, out); naive_excess_positions_512(s, x, ref); + check_matches_naive(excess_positions_512_expand, "expand", s, x, + static_cast(pattern)); + check_matches_naive(excess_positions_512_expand8, "expand8", s, x, + static_cast(pattern)); + check_matches_naive(excess_positions_512_expand_avx512, "expand_avx512", + s, x, static_cast(pattern)); + check_matches_naive(excess_positions_512_branching_lut, "branching_lut", + s, x, static_cast(pattern)); + check_matches_naive(excess_positions_512_lut_avx512, "lut_avx512", s, x, + static_cast(pattern)); for (int w = 0; w < 8; ++w) { ASSERT_EQ(out[w], ref[w]) << "pattern=" << pattern << " x=" << x << " word=" << w; @@ -170,6 +214,13 @@ TEST(ExcessPositions512, Random) { excess_positions_512(s, x, out); naive_excess_positions_512(s, x, ref); + check_matches_naive(excess_positions_512_expand, "expand", s, x, t); + check_matches_naive(excess_positions_512_expand8, "expand8", s, x, t); + check_matches_naive(excess_positions_512_expand_avx512, "expand_avx512", s, + x, t); + check_matches_naive(excess_positions_512_branching_lut, "branching_lut", s, + x, t); + check_matches_naive(excess_positions_512_lut_avx512, "lut_avx512", s, x, t); for (int w = 0; w < 8; ++w) { ASSERT_EQ(out[w], ref[w]) << "case=" << t << " x=" << x << " word=" << w; @@ -194,166 +245,15 @@ TEST(ExcessPositions512, TargetZero) { } excess_positions_512(s, 0, out); naive_excess_positions_512(s, 0, ref); + check_matches_naive(excess_positions_512_expand, "expand", s, 0, t); + check_matches_naive(excess_positions_512_expand8, "expand8", s, 0, t); + check_matches_naive(excess_positions_512_expand_avx512, "expand_avx512", s, + 0, t); + check_matches_naive(excess_positions_512_branching_lut, "branching_lut", s, + 0, t); + check_matches_naive(excess_positions_512_lut_avx512, "lut_avx512", s, 0, t); for (int w = 0; w < 8; ++w) { ASSERT_EQ(out[w], ref[w]) << "case=" << t << " word=" << w; } } } - -TEST(ExcessPositions512LUT, AllZeros) { - alignas(64) uint64_t s[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - - for (int x = -8; x <= 0; ++x) { - check_matches_naive(excess_positions_512_lut, "lut", s, x); - } - - alignas(64) uint64_t out[8]; - excess_positions_512_lut(s, 1, out); - for (int w = 0; w < 8; ++w) { - EXPECT_EQ(out[w], 0u); - } -} - -TEST(ExcessPositions512LUT, AllOnes) { - alignas(64) uint64_t s[8] = {UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, - UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX}; - - for (int x = 0; x <= 8; ++x) { - check_matches_naive(excess_positions_512_lut, "lut", s, x); - } - - alignas(64) uint64_t out[8]; - excess_positions_512_lut(s, -1, out); - for (int w = 0; w < 8; ++w) { - EXPECT_EQ(out[w], 0u); - } -} - -TEST(ExcessPositions512LUT, Alternating) { - alignas(64) uint64_t s[8]; - for (int w = 0; w < 8; ++w) { - s[w] = 0xAAAAAAAAAAAAAAAAull; - } - - for (int x = -2; x <= 2; ++x) { - check_matches_naive(excess_positions_512_lut, "lut", s, x); - } -} - -TEST(ExcessPositions512LUT, OutOfRange) { - alignas(64) uint64_t s[8] = {UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX, - UINT64_MAX, UINT64_MAX, UINT64_MAX, UINT64_MAX}; - alignas(64) uint64_t out[8]; - excess_positions_512_lut(s, 513, out); - for (int w = 0; w < 8; ++w) { - EXPECT_EQ(out[w], 0u); - } - excess_positions_512_lut(s, -513, out); - for (int w = 0; w < 8; ++w) { - EXPECT_EQ(out[w], 0u); - } -} - -TEST(ExcessPositions512LUT, ExhaustiveSmall16) { - alignas(64) uint64_t s[8]; - - for (uint64_t pattern = 0; pattern < (1ull << 16); ++pattern) { - s[0] = pattern; - for (int w = 1; w < 8; ++w) { - s[w] = 0; - } - for (int x = -20; x <= 20; ++x) { - check_matches_naive(excess_positions_512_lut, "lut", s, x, - static_cast(pattern)); - } - } -} - -TEST(ExcessPositions512LUT, Random) { - const int cases = [] { - const char* env = std::getenv("EXCESS_POS_CASES"); - return env ? std::atoi(env) : 1000; - }(); - const uint64_t seed = [] { - const char* env = std::getenv("EXCESS_POS_SEED"); - return env ? std::stoull(env) : 42ull; - }(); - - std::mt19937_64 rng(static_cast(seed)); - std::uniform_int_distribution x_dist(-512, 512); - - alignas(64) uint64_t s[8]; - - for (int t = 0; t < cases; ++t) { - for (int w = 0; w < 8; ++w) { - s[w] = rng(); - } - const int x = x_dist(rng); - check_matches_naive(excess_positions_512_lut, "lut", s, x, t); - } -} - -TEST(ExcessPositions512LUT, TargetZero) { - const uint64_t seed = 12345; - std::mt19937_64 rng(seed); - - alignas(64) uint64_t s[8]; - - for (int t = 0; t < 500; ++t) { - for (int w = 0; w < 8; ++w) { - s[w] = rng(); - } - check_matches_naive(excess_positions_512_lut, "lut", s, 0, t); - } -} - -TEST(ExcessPositions512LUT, MatchesExpand) { - const int cases = 500; - std::mt19937_64 rng(99999); - std::uniform_int_distribution x_dist(-512, 512); - - alignas(64) uint64_t s[8]; - alignas(64) uint64_t out_expand[8]; - alignas(64) uint64_t out_lut[8]; - - for (int t = 0; t < cases; ++t) { - for (int w = 0; w < 8; ++w) { - s[w] = rng(); - } - const int x = x_dist(rng); - - excess_positions_512(s, x, out_expand); - excess_positions_512_lut(s, x, out_lut); - - for (int w = 0; w < 8; ++w) { - ASSERT_EQ(out_expand[w], out_lut[w]) - << "case=" << t << " x=" << x << " word=" << w; - } - } -} - -TEST(ExcessPositions512LUT, OverflowBoundary) { - alignas(64) uint64_t s[8]; - alignas(64) uint64_t out_expand[8]; - alignas(64) uint64_t out_lut[8]; - - for (int x = -64; x <= 64; ++x) { - for (uint64_t hi_pattern = 0; hi_pattern < 256; ++hi_pattern) { - for (int fill = 0; fill <= 7; ++fill) { - for (int w = 0; w < 8; ++w) { - s[w] = (w < fill) ? UINT64_MAX : 0; - } - s[fill] = hi_pattern; - - excess_positions_512(s, x, out_expand); - excess_positions_512_lut(s, x, out_lut); - - for (int w = 0; w < 8; ++w) { - ASSERT_EQ(out_expand[w], out_lut[w]) - << "x=" << x << " fill=" << fill << " hi=" << hi_pattern - << " word=" << w; - } - } - } - } -}