From cbb522a60f16a2e5a839cb0f786d437a551e641e Mon Sep 17 00:00:00 2001 From: Brady Date: Thu, 1 Aug 2024 19:31:55 -0500 Subject: [PATCH] Implement scan hint optimizations --- include/libhat/Defines.hpp | 16 ++++++--- include/libhat/Scanner.hpp | 69 +++++++++++++++++++++++++++++--------- src/Scanner.cpp | 66 ++++++++++++++++++++++++++++++++---- src/arch/x86/AVX2.cpp | 52 ++++++++++++++++------------ src/arch/x86/AVX512.cpp | 53 +++++++++++++++++------------ src/arch/x86/Frequency.hpp | 33 ++++++++++++++++++ src/arch/x86/SSE.cpp | 54 +++++++++++++++++------------ 7 files changed, 254 insertions(+), 89 deletions(-) create mode 100644 src/arch/x86/Frequency.hpp diff --git a/include/libhat/Defines.hpp b/include/libhat/Defines.hpp index 3573769..d8c14e9 100644 --- a/include/libhat/Defines.hpp +++ b/include/libhat/Defines.hpp @@ -77,10 +77,10 @@ #if __cpp_lib_unreachable >= 202202L #include #define LIBHAT_UNREACHABLE() std::unreachable() +#elif defined(__GNUC__) || defined(__clang__) + #define LIBHAT_UNREACHABLE() __builtin_unreachable() #elif defined(_MSC_VER) #define LIBHAT_UNREACHABLE() __assume(false) -#elif defined(__GNUC__) - #define LIBHAT_UNREACHABLE() __builtin_unreachable() #else #include namespace hat::detail { @@ -93,10 +93,10 @@ #if __has_cpp_attribute(assume) #define LIBHAT_ASSUME(...) [[assume(__VA_ARGS__)]] -#elif defined(_MSC_VER) - #define LIBHAT_ASSUME(...) __assume(__VA_ARGS__) #elif defined(__clang__) #define LIBHAT_ASSUME(...) __builtin_assume(__VA_ARGS__) +#elif defined(_MSC_VER) + #define LIBHAT_ASSUME(...) __assume(__VA_ARGS__) #else #define LIBHAT_ASSUME(...) \ do { \ @@ -105,3 +105,11 @@ } \ } while (0) #endif + +#if defined(__GNUC__) || defined(__clang__) + #define LIBHAT_FORCEINLINE inline __attribute__((always_inline)) +#elif defined(_MSC_VER) + #define LIBHAT_FORCEINLINE __forceinline +#else + #define LIBHAT_FORCEINLINE inline +#endif diff --git a/include/libhat/Scanner.hpp b/include/libhat/Scanner.hpp index bfd1aa2..15bbabc 100644 --- a/include/libhat/Scanner.hpp +++ b/include/libhat/Scanner.hpp @@ -65,8 +65,9 @@ namespace hat { }; enum class scan_hint : uint64_t { - none = 0, // no hints - x86_64 = 1 << 0, // The data being scanned is x86_64 machine code + none = 0, // no hints + x86_64 = 1 << 0, // The data being scanned is x86_64 machine code + pair0 = 1 << 1, // Only utilize byte pair based scanning if the signature starts with a byte pair }; constexpr scan_hint operator|(scan_hint lhs, scan_hint rhs) { @@ -85,27 +86,30 @@ namespace hat { using scan_function_t = const_scan_result(*)(const std::byte* begin, const std::byte* end, const scan_context& context); + struct scanner_context { + size_t vectorSize{}; + }; + class scan_context { public: signature_view signature{}; scan_function_t scanner{}; scan_alignment alignment{}; - size_t vectorSize{}; scan_hint hints{}; + std::optional pairIndex{}; [[nodiscard]] constexpr const_scan_result scan(const std::byte* begin, const std::byte* end) const { return this->scanner(begin, end, *this); } - void apply_hints(); + void auto_resolve_scanner(); + void apply_hints(const scanner_context&); static constexpr scan_context create(signature_view signature, scan_alignment alignment, scan_hint hints); private: scan_context() = default; }; - [[nodiscard]] std::pair resolve_scanner(const scan_context&); - enum class scan_mode { Single, // std::find + std::equal SSE, // x86 SSE 4.1 @@ -117,7 +121,7 @@ namespace hat { inline constexpr auto alignment_stride = static_cast>(alignment); template - inline consteval auto create_alignment_mask() { + LIBHAT_FORCEINLINE consteval auto create_alignment_mask() { type mask{}; for (size_t i = 0; i < sizeof(type) * 8; i += alignment_stride) { mask |= (type(1) << i); @@ -126,7 +130,7 @@ namespace hat { } template - inline const std::byte* next_boundary_align(const std::byte* ptr) { + LIBHAT_FORCEINLINE const std::byte* next_boundary_align(const std::byte* ptr) { constexpr auto stride = alignment_stride; if constexpr (stride == 1) { return ptr; @@ -137,17 +141,53 @@ namespace hat { } template - inline const std::byte* prev_boundary_align(const std::byte* ptr) { + LIBHAT_FORCEINLINE const std::byte* prev_boundary_align(const std::byte* ptr) { constexpr auto stride = alignment_stride; if constexpr (stride == 1) { return ptr; } - uintptr_t mod = reinterpret_cast(ptr) % stride; + const uintptr_t mod = reinterpret_cast(ptr) % stride; return std::assume_aligned(ptr - mod); } + template + LIBHAT_FORCEINLINE const std::byte* align_pointer_as(const std::byte* ptr) { + constexpr size_t alignment = alignof(Type); + const uintptr_t mod = reinterpret_cast(ptr) % alignment; + ptr += mod ? alignment - mod : 0; + return std::assume_aligned(ptr); + } + + template + LIBHAT_FORCEINLINE auto segment_scan( + const std::byte* begin, + const std::byte* end, + const size_t signatureSize, + const size_t cmpOffset + ) -> std::tuple, std::span, std::span> { + const auto preBegin = begin; + const auto vecBegin = reinterpret_cast(align_pointer_as(preBegin + cmpOffset)); + const auto vecEnd = vecBegin + (static_cast(end - reinterpret_cast(vecBegin)) - signatureSize) / sizeof(Vector); + const auto preEnd = reinterpret_cast(vecBegin) - cmpOffset + signatureSize; + const auto postBegin = reinterpret_cast(vecEnd); + const auto postEnd = end; + + auto validateRange = [signatureSize](const std::byte* b, const std::byte* e) -> std::span { + if (b <= e && static_cast(e - b) >= signatureSize) { + return {b, e}; + } + return {}; + }; + + return { + validateRange(preBegin, preEnd), + std::span{vecBegin, vecEnd}, + validateRange(postBegin, postEnd) + }; + } + template - scan_function_t get_scanner(const scan_context&); + scan_function_t resolve_scanner(scan_context&); template const_scan_result find_pattern_single(const std::byte* begin, const std::byte* end, const scan_context&); @@ -205,7 +245,7 @@ namespace hat { } template<> - constexpr scan_function_t get_scanner(const scan_context& context) { + constexpr scan_function_t resolve_scanner(scan_context& context) { switch (context.alignment) { case scan_alignment::X1: return &find_pattern_single; case scan_alignment::X16: return &find_pattern_single; @@ -235,10 +275,9 @@ namespace hat { ctx.alignment = alignment; ctx.hints = hints; if LIBHAT_IF_CONSTEVAL { - ctx.scanner = get_scanner(ctx); + ctx.scanner = resolve_scanner(ctx); } else { - std::tie(ctx.scanner, ctx.vectorSize) = resolve_scanner(ctx); - ctx.apply_hints(); + ctx.auto_resolve_scanner(); } return ctx; } diff --git a/src/Scanner.cpp b/src/Scanner.cpp index 5847fa1..bf8f32a 100644 --- a/src/Scanner.cpp +++ b/src/Scanner.cpp @@ -3,31 +3,85 @@ #include #include +#include "arch/x86/Frequency.hpp" + namespace hat::detail { - void scan_context::apply_hints() {} + void scan_context::apply_hints(const scanner_context& scanner) { + const bool x86_64 = static_cast(this->hints & scan_hint::x86_64); + const bool pair0 = static_cast(this->hints & scan_hint::pair0); + + if (x86_64 && !pair0 && scanner.vectorSize && this->alignment == hat::scan_alignment::X1) { + const auto get_score = [this](const std::byte a, const std::byte b) { + constexpr auto& pairs = hat::detail::x86_64::pairs_x1; + const auto it = std::ranges::find(pairs, std::pair{a, b}); + return it == pairs.end() ? pairs.size() : pairs.size() - static_cast(it - pairs.begin()) - 1; + }; + + const auto score_pair = [&](auto&& tup) { + auto [a, b] = std::get<1>(tup); + return std::make_tuple(std::get<0>(tup), get_score(a.value(), b.value())); + }; + + static constexpr auto is_complete_pair = [](auto&& tup) { + auto [a, b] = std::get<1>(tup); + return a.has_value() && b.has_value(); + }; + + auto valid_pairs = this->signature + | std::views::take(scanner.vectorSize) + | std::views::adjacent<2> + | std::views::enumerate + | std::views::filter(is_complete_pair) + | std::views::transform(score_pair); + + if (!valid_pairs.empty()) { + this->pairIndex = std::get<0>(std::ranges::max(valid_pairs, std::ranges::less{}, [](auto&& tup) { + return std::get<1>(tup); + })); + } + } + + // If no "optimal" pair was found, find the first byte pair in the signature + if (!this->pairIndex.has_value()) { + size_t i{}; + for (auto&& [a, b] : this->signature | std::views::adjacent<2>) { + if (a.has_value() && b.has_value()) { + this->pairIndex = i; + break; + } + if (i == 0 && pair0) { + break; + } + i++; + } + } + } - std::pair resolve_scanner(const scan_context& context) { + void scan_context::auto_resolve_scanner() { #if defined(LIBHAT_X86) const auto& ext = get_system().extensions; if (ext.bmi) { #if !defined(LIBHAT_DISABLE_AVX512) if (ext.avx512f && ext.avx512bw) { - return {get_scanner(context), 64}; + this->scanner = resolve_scanner(*this); + return; } #endif if (ext.avx2) { - return {get_scanner(context), 32}; + this->scanner = resolve_scanner(*this); + return; } } #if !defined(LIBHAT_DISABLE_SSE) if (ext.sse41) { - return {get_scanner(context), 16}; + this->scanner = resolve_scanner(*this); + return; } #endif #endif // If none of the vectorized implementations are available/supported, then fallback to scanning per-byte - return {get_scanner(context), 0}; + this->scanner = resolve_scanner(*this); } } diff --git a/src/arch/x86/AVX2.cpp b/src/arch/x86/AVX2.cpp index 3a55ea7..e4b642b 100644 --- a/src/arch/x86/AVX2.cpp +++ b/src/arch/x86/AVX2.cpp @@ -9,14 +9,14 @@ namespace hat::detail { - inline auto load_signature_256(signature_view signature) { + inline auto load_signature_256(const signature_view signature) { std::byte byteBuffer[32]{}; // The remaining signature bytes std::byte maskBuffer[32]{}; // A bitmask for the signature bytes we care about - for (size_t i = 1; i < signature.size(); i++) { + for (size_t i = 0; i < signature.size(); i++) { auto e = signature[i]; if (e.has_value()) { - byteBuffer[i - 1] = *e; - maskBuffer[i - 1] = std::byte{0xFFu}; + byteBuffer[i] = *e; + maskBuffer[i] = std::byte{0xFFu}; } } return std::make_tuple( @@ -28,13 +28,15 @@ namespace hat::detail { template const_scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, const scan_context& context) { const auto signature = context.signature; + const auto cmpIndex = cmpeq2 ? *context.pairIndex : 0; + LIBHAT_ASSUME(cmpIndex < 32); // 256 bit vector containing first signature byte repeated - const auto firstByte = _mm256_set1_epi8(static_cast(*signature[0])); + const auto firstByte = _mm256_set1_epi8(static_cast(*signature[cmpIndex])); __m256i secondByte; if constexpr (cmpeq2) { - secondByte = _mm256_set1_epi8(static_cast(*signature[1])); + secondByte = _mm256_set1_epi8(static_cast(*signature[cmpIndex + 1])); } __m256i signatureBytes, signatureMask; @@ -47,19 +49,24 @@ namespace hat::detail { return {}; } - auto vec = reinterpret_cast(begin); - const auto n = static_cast(end - signature.size() - begin) / sizeof(__m256i); - const auto e = vec + n; + auto [pre, vec, post] = segment_scan<__m256i>(begin, end, signature.size(), cmpIndex); - for (; vec != e; vec++) { - const auto cmp = _mm256_cmpeq_epi8(firstByte, _mm256_loadu_si256(vec)); + if (!pre.empty()) { + const auto result = find_pattern_single(pre.data(), pre.data() + pre.size(), context); + if (result.has_result()) { + return result; + } + } + + for (auto& it : vec) { + const auto cmp = _mm256_cmpeq_epi8(firstByte, _mm256_loadu_si256(&it)); auto mask = static_cast(_mm256_movemask_epi8(cmp)); if constexpr (alignment != scan_alignment::X1) { mask &= create_alignment_mask(); if (!mask) continue; } else if constexpr (cmpeq2) { - const auto cmp2 = _mm256_cmpeq_epi8(secondByte, _mm256_loadu_si256(vec)); + const auto cmp2 = _mm256_cmpeq_epi8(secondByte, _mm256_loadu_si256(&it)); auto mask2 = static_cast(_mm256_movemask_epi8(cmp2)); // avoid loading unaligned memory by letting a match of the first signature byte in the last // position imply that the second byte also matched @@ -68,16 +75,16 @@ namespace hat::detail { while (mask) { const auto offset = _tzcnt_u32(mask); - const auto i = reinterpret_cast(vec) + offset; + const auto i = reinterpret_cast(&it) + offset - cmpIndex; if constexpr (veccmp) { - const auto data = _mm256_loadu_si256(reinterpret_cast(i + 1)); + const auto data = _mm256_loadu_si256(reinterpret_cast(i)); const auto cmpToSig = _mm256_cmpeq_epi8(signatureBytes, data); const auto matched = _mm256_testc_si256(cmpToSig, signatureMask); if (matched) LIBHAT_UNLIKELY { return i; } } else { - auto match = std::equal(signature.begin() + 1, signature.end(), i + 1, [](auto opt, auto byte) { + auto match = std::equal(signature.begin(), signature.end(), i, [](auto opt, auto byte) { return !opt.has_value() || *opt == byte; }); if (match) LIBHAT_UNLIKELY { @@ -88,19 +95,22 @@ namespace hat::detail { } } - // Look in remaining bytes that couldn't be grouped into 256 bits - begin = reinterpret_cast(vec); - return find_pattern_single(begin, end, context); + if (!post.empty()) { + return find_pattern_single(post.data(), post.data() + post.size(), context); + } + return {}; } template<> - scan_function_t get_scanner(const scan_context& context) { + scan_function_t resolve_scanner(scan_context& context) { + context.apply_hints({.vectorSize = 32}); + const auto alignment = context.alignment; const auto signature = context.signature; - const bool veccmp = signature.size() <= 33; + const bool veccmp = signature.size() <= 32; if (alignment == scan_alignment::X1) { - const bool cmpeq2 = signature.size() > 1 && signature[1].has_value(); + const bool cmpeq2 = context.pairIndex.has_value(); if (cmpeq2 && veccmp) { return &find_pattern_avx2; } else if (cmpeq2) { diff --git a/src/arch/x86/AVX512.cpp b/src/arch/x86/AVX512.cpp index 150a05b..4fcf403 100644 --- a/src/arch/x86/AVX512.cpp +++ b/src/arch/x86/AVX512.cpp @@ -9,14 +9,14 @@ namespace hat::detail { - inline auto load_signature_512(signature_view signature) { + inline auto load_signature_512(const signature_view signature) { std::byte byteBuffer[64]{}; // The remaining signature bytes uint64_t maskBuffer{}; // A bitmask for the signature bytes we care about - for (size_t i = 1; i < signature.size(); i++) { + for (size_t i = 0; i < signature.size(); i++) { auto e = signature[i]; if (e.has_value()) { - byteBuffer[i - 1] = *e; - maskBuffer |= (1ull << (i - 1)); + byteBuffer[i] = *e; + maskBuffer |= (1ull << i); } } return std::make_tuple( @@ -28,12 +28,15 @@ namespace hat::detail { template const_scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, const scan_context& context) { const auto signature = context.signature; + const auto cmpIndex = cmpeq2 ? *context.pairIndex : 0; + LIBHAT_ASSUME(cmpIndex < 64); + // 512 bit vector containing first signature byte repeated - const auto firstByte = _mm512_set1_epi8(static_cast(*signature[0])); + const auto firstByte = _mm512_set1_epi8(static_cast(*signature[cmpIndex])); __m512i secondByte; if constexpr (cmpeq2) { - secondByte = _mm512_set1_epi8(static_cast(*signature[1])); + secondByte = _mm512_set1_epi8(static_cast(*signature[cmpIndex + 1])); } __m512i signatureBytes; @@ -47,32 +50,37 @@ namespace hat::detail { return {}; } - auto vec = reinterpret_cast(begin); - const auto n = static_cast(end - signature.size() - begin) / sizeof(__m512i); - const auto e = vec + n; + auto [pre, vec, post] = segment_scan<__m512i>(begin, end, signature.size(), cmpIndex); + + if (!pre.empty()) { + const auto result = find_pattern_single(pre.data(), pre.data() + pre.size(), context); + if (result.has_result()) { + return result; + } + } - for (; vec != e; vec++) { - auto mask = _mm512_cmpeq_epi8_mask(firstByte, _mm512_loadu_si512(vec)); + for (auto& it : vec) { + auto mask = _mm512_cmpeq_epi8_mask(firstByte, _mm512_loadu_si512(&it)); if constexpr (alignment != scan_alignment::X1) { mask &= create_alignment_mask(); if (!mask) continue; } else if constexpr (cmpeq2) { - const auto mask2 = _mm512_cmpeq_epi8_mask(secondByte, _mm512_loadu_si512(vec)); + const auto mask2 = _mm512_cmpeq_epi8_mask(secondByte, _mm512_loadu_si512(&it)); mask &= (mask2 >> 1) | (0b1ull << 63); } while (mask) { const auto offset = LIBHAT_TZCNT64(mask); - const auto i = reinterpret_cast(vec) + offset; + const auto i = reinterpret_cast(&it) + offset - cmpIndex; if constexpr (veccmp) { - const auto data = _mm512_loadu_si512(i + 1); + const auto data = _mm512_loadu_si512(i); const auto invalid = _mm512_mask_cmpneq_epi8_mask(signatureMask, signatureBytes, data); if (!invalid) LIBHAT_UNLIKELY { return i; } } else { - auto match = std::equal(signature.begin() + 1, signature.end(), i + 1, [](auto opt, auto byte) { + auto match = std::equal(signature.begin(), signature.end(), i, [](auto opt, auto byte) { return !opt.has_value() || *opt == byte; }); if (match) LIBHAT_UNLIKELY { @@ -83,19 +91,22 @@ namespace hat::detail { } } - // Look in remaining bytes that couldn't be grouped into 512 bits - begin = reinterpret_cast(vec); - return find_pattern_single(begin, end, context); + if (!post.empty()) { + return find_pattern_single(post.data(), post.data() + post.size(), context); + } + return {}; } template<> - scan_function_t get_scanner(const scan_context& context) { + scan_function_t resolve_scanner(scan_context& context) { + context.apply_hints({.vectorSize = 64}); + const auto alignment = context.alignment; const auto signature = context.signature; - const bool veccmp = signature.size() <= 65; + const bool veccmp = signature.size() <= 64; if (alignment == scan_alignment::X1) { - const bool cmpeq2 = signature.size() > 1 && signature[1].has_value(); + const bool cmpeq2 = context.pairIndex.has_value(); if (cmpeq2 && veccmp) { return &find_pattern_avx512; } else if (cmpeq2) { diff --git a/src/arch/x86/Frequency.hpp b/src/arch/x86/Frequency.hpp new file mode 100644 index 0000000..0f0dd21 --- /dev/null +++ b/src/arch/x86/Frequency.hpp @@ -0,0 +1,33 @@ +#pragma once + +namespace hat::detail::x86_64 { + + static constexpr auto p(uint8_t a, uint8_t b) { + return std::pair{std::byte{a}, std::byte{b}}; + } + + // Top 100 byte pair occurrences on 1 byte alignment + // Accounts for ~39.7% of all pairs + static constexpr inline std::array pairs_x1{ + p(0x00, 0x00), p(0x48, 0x8B), p(0xCC, 0xCC), p(0x48, 0x8D), p(0x48, 0x89), + p(0x00, 0x48), p(0x48, 0x83), p(0x44, 0x24), p(0x01, 0x00), p(0x49, 0x8B), + p(0x48, 0x85), p(0x4C, 0x24), p(0xFF, 0xFF), p(0x0F, 0x11), p(0x4C, 0x8B), + p(0x08, 0x48), p(0x24, 0x20), p(0x5C, 0x24), p(0x01, 0x48), p(0xFF, 0x48), + p(0x4C, 0x89), p(0x4C, 0x8D), p(0xCC, 0x48), p(0xFF, 0x15), p(0x10, 0x48), + p(0x24, 0x30), p(0x03, 0x48), p(0x89, 0x44), p(0x00, 0xE8), p(0x90, 0x48), + p(0x8D, 0x05), p(0x83, 0xC4), p(0xC3, 0xCC), p(0x20, 0x48), p(0x0F, 0x57), + p(0x30, 0x48), p(0x02, 0x00), p(0xF3, 0x0F), p(0x00, 0x0F), p(0x54, 0x24), + p(0x85, 0xC9), p(0xC0, 0x0F), p(0x48, 0xC7), p(0x48, 0x81), p(0x85, 0xC0), + p(0x74, 0x24), p(0x02, 0x48), p(0x89, 0x5C), p(0x0F, 0x10), p(0x83, 0xEC), + p(0xC9, 0x74), p(0x8D, 0x4D), p(0x24, 0x40), p(0x57, 0xC0), p(0x24, 0x28), + p(0x8D, 0x4C), p(0x24, 0x38), p(0x00, 0x4C), p(0x8B, 0xCB), p(0x38, 0x48), + p(0x48, 0x3B), p(0xF8, 0x48), p(0x8D, 0x0D), p(0xC0, 0x48), p(0x04, 0x48), + p(0x0F, 0x84), p(0x03, 0x00), p(0x00, 0x49), p(0xC3, 0x48), p(0x8B, 0xCF), + p(0xC0, 0x74), p(0x89, 0x45), p(0x57, 0x48), p(0x40, 0x48), p(0x48, 0x33), + p(0x24, 0x48), p(0x24, 0x50), p(0x0F, 0xB6), p(0x8D, 0x15), p(0x18, 0x48), + p(0x28, 0x48), p(0x0F, 0x7F), p(0x7C, 0x24), p(0x8D, 0x54), p(0x8B, 0x40), + p(0x8B, 0xC8), p(0x8B, 0x01), p(0x8D, 0x8D), p(0xC1, 0x48), p(0x8B, 0x5C), + p(0xFE, 0x48), p(0x89, 0x74), p(0xC7, 0x44), p(0x66, 0x0F), p(0x83, 0xF8), + p(0xCB, 0xE8), p(0x24, 0x60), p(0xCC, 0xE8), p(0xC4, 0x20), p(0x8B, 0x4D), + }; +} diff --git a/src/arch/x86/SSE.cpp b/src/arch/x86/SSE.cpp index bd737ca..0853b8f 100644 --- a/src/arch/x86/SSE.cpp +++ b/src/arch/x86/SSE.cpp @@ -9,14 +9,14 @@ namespace hat::detail { - inline auto load_signature_128(signature_view signature) { + inline auto load_signature_128(const signature_view signature) { std::byte byteBuffer[16]{}; // The remaining signature bytes std::byte maskBuffer[16]{}; // A bitmask for the signature bytes we care about - for (size_t i = 1; i < signature.size(); i++) { + for (size_t i = 0; i < signature.size(); i++) { auto e = signature[i]; if (e.has_value()) { - byteBuffer[i - 1] = *e; - maskBuffer[i - 1] = std::byte{0xFFu}; + byteBuffer[i] = *e; + maskBuffer[i] = std::byte{0xFFu}; } } return std::make_tuple( @@ -28,13 +28,15 @@ namespace hat::detail { template const_scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, const scan_context& context) { const auto signature = context.signature; + const auto cmpIndex = cmpeq2 ? *context.pairIndex : 0; + LIBHAT_ASSUME(cmpIndex < 16); - // 256 bit vector containing first signature byte repeated - const auto firstByte = _mm_set1_epi8(static_cast(*signature[0])); + // 128 bit vector containing first signature byte repeated + const auto firstByte = _mm_set1_epi8(static_cast(*signature[cmpIndex])); __m128i secondByte; if constexpr (cmpeq2) { - secondByte = _mm_set1_epi8(static_cast(*signature[1])); + secondByte = _mm_set1_epi8(static_cast(*signature[cmpIndex + 1])); } __m128i signatureBytes, signatureMask; @@ -47,35 +49,40 @@ namespace hat::detail { return {}; } - auto vec = reinterpret_cast(begin); - const auto n = static_cast(end - signature.size() - begin) / sizeof(__m128i); - const auto e = vec + n; + auto [pre, vec, post] = segment_scan<__m128i>(begin, end, signature.size(), cmpIndex); - for (; vec != e; vec++) { - const auto cmp = _mm_cmpeq_epi8(firstByte, _mm_loadu_si128(vec)); + if (!pre.empty()) { + const auto result = find_pattern_single(pre.data(), pre.data() + pre.size(), context); + if (result.has_result()) { + return result; + } + } + + for (auto& it : vec) { + const auto cmp = _mm_cmpeq_epi8(firstByte, _mm_loadu_si128(&it)); auto mask = static_cast(_mm_movemask_epi8(cmp)); if constexpr (alignment != scan_alignment::X1) { mask &= create_alignment_mask(); if (!mask) continue; } else if constexpr (cmpeq2) { - const auto cmp2 = _mm_cmpeq_epi8(secondByte, _mm_loadu_si128(vec)); + const auto cmp2 = _mm_cmpeq_epi8(secondByte, _mm_loadu_si128(&it)); auto mask2 = static_cast(_mm_movemask_epi8(cmp2)); mask &= (mask2 >> 1) | (0b1u << 15); } while (mask) { const auto offset = LIBHAT_BSF32(mask); - const auto i = reinterpret_cast(vec) + offset; + const auto i = reinterpret_cast(&it) + offset - cmpIndex; if constexpr (veccmp) { - const auto data = _mm_loadu_si128(reinterpret_cast(i + 1)); + const auto data = _mm_loadu_si128(reinterpret_cast(i)); const auto cmpToSig = _mm_cmpeq_epi8(signatureBytes, data); const auto matched = _mm_testc_si128(cmpToSig, signatureMask); if (matched) LIBHAT_UNLIKELY { return i; } } else { - auto match = std::equal(signature.begin() + 1, signature.end(), i + 1, [](auto opt, auto byte) { + auto match = std::equal(signature.begin(), signature.end(), i, [](auto opt, auto byte) { return !opt.has_value() || *opt == byte; }); if (match) LIBHAT_UNLIKELY { @@ -86,19 +93,22 @@ namespace hat::detail { } } - // Look in remaining bytes that couldn't be grouped into 128 bits - begin = reinterpret_cast(vec); - return find_pattern_single(begin, end, context); + if (!post.empty()) { + return find_pattern_single(post.data(), post.data() + post.size(), context); + } + return {}; } template<> - scan_function_t get_scanner(const scan_context& context) { + scan_function_t resolve_scanner(scan_context& context) { + context.apply_hints({.vectorSize = 16}); + const auto alignment = context.alignment; const auto signature = context.signature; - const bool veccmp = signature.size() <= 17; + const bool veccmp = signature.size() <= 16; if (alignment == scan_alignment::X1) { - const bool cmpeq2 = signature.size() > 1 && signature[1].has_value(); + const bool cmpeq2 = context.pairIndex.has_value(); if (cmpeq2 && veccmp) { return &find_pattern_sse; } else if (cmpeq2) {