From 3ee98fd0634dfd864e223353495b862f38555886 Mon Sep 17 00:00:00 2001 From: Brady Date: Thu, 1 Aug 2024 13:42:19 -0500 Subject: [PATCH] Fetch actual scanner impl --- include/libhat/Scanner.hpp | 46 ++++++++++++++++++++++---------------- src/Scanner.cpp | 14 +++++------- src/arch/x86/AVX2.cpp | 46 +++++++++++++++++++------------------- src/arch/x86/AVX512.cpp | 46 +++++++++++++++++++------------------- src/arch/x86/SSE.cpp | 46 +++++++++++++++++++------------------- 5 files changed, 101 insertions(+), 97 deletions(-) diff --git a/include/libhat/Scanner.hpp b/include/libhat/Scanner.hpp index 2beb3b3..8f3d883 100644 --- a/include/libhat/Scanner.hpp +++ b/include/libhat/Scanner.hpp @@ -105,19 +105,13 @@ namespace hat { scan_context() = default; }; - template - [[nodiscard]] std::pair resolve_scanner(); - - void apply_hints(scan_context& context); + [[nodiscard]] std::pair resolve_scanner(const scan_context&); enum class scan_mode { - FastFirst, // std::find + std::equal - SSE, // x86 SSE 4.1 - AVX2, // x86 AVX2 - AVX512, // x86 AVX512 - - // Fallback mode to use for SIMD remaining bytes - Single = FastFirst + Single, // std::find + std::equal + SSE, // x86 SSE 4.1 + AVX2, // x86 AVX2 + AVX512, // x86 AVX512 }; template @@ -132,8 +126,9 @@ namespace hat { return mask; } - template> + template inline const std::byte* next_boundary_align(const std::byte* ptr) { + constexpr auto stride = alignment_stride; if constexpr (stride == 1) { return ptr; } @@ -142,8 +137,9 @@ namespace hat { return ptr; } - template> + template inline const std::byte* prev_boundary_align(const std::byte* ptr) { + constexpr auto stride = alignment_stride; if constexpr (stride == 1) { return ptr; } @@ -151,11 +147,14 @@ namespace hat { return ptr - mod; } - template - const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context&); + template + scan_function_t get_scanner(const scan_context&); + + template + const_scan_result find_pattern_single(const std::byte* begin, const std::byte* end, const scan_context&); template<> - inline constexpr const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { + constexpr const_scan_result find_pattern_single(const std::byte* begin, const std::byte* end, const scan_context& context) { const auto signature = context.signature; const auto firstByte = *signature[0]; const auto scanEnd = end - signature.size() + 1; @@ -182,7 +181,7 @@ namespace hat { } template<> - inline const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { + inline const_scan_result find_pattern_single(const std::byte* begin, const std::byte* end, const scan_context& context) { const auto signature = context.signature; const auto firstByte = *signature[0]; @@ -206,6 +205,15 @@ namespace hat { return nullptr; } + template<> + constexpr scan_function_t get_scanner(const scan_context& context) { + switch (context.alignment) { + case scan_alignment::X1: return &find_pattern_single; + case scan_alignment::X16: return &find_pattern_single; + } + std::unreachable(); + } + [[nodiscard]] constexpr auto truncate(const signature_view signature) noexcept { // Truncate the leading wildcards from the signature size_t offset = 0; @@ -229,9 +237,9 @@ namespace hat { ctx.hints = hints; ctx.alignment = alignment; if LIBHAT_IF_CONSTEVAL { - ctx.scanner = &find_pattern; + ctx.scanner = get_scanner(ctx); } else { - std::tie(ctx.scanner, ctx.vectorSize) = resolve_scanner(); + std::tie(ctx.scanner, ctx.vectorSize) = resolve_scanner(ctx); ctx.apply_hints(); } return ctx; diff --git a/src/Scanner.cpp b/src/Scanner.cpp index 6d9a7b2..5847fa1 100644 --- a/src/Scanner.cpp +++ b/src/Scanner.cpp @@ -7,32 +7,28 @@ namespace hat::detail { void scan_context::apply_hints() {} - template - std::pair resolve_scanner() { + std::pair resolve_scanner(const scan_context& context) { #if defined(LIBHAT_X86) const auto& ext = get_system().extensions; if (ext.bmi) { #if !defined(LIBHAT_DISABLE_AVX512) if (ext.avx512f && ext.avx512bw) { - return {&find_pattern, 64}; + return {get_scanner(context), 64}; } #endif if (ext.avx2) { - return {&find_pattern, 32}; + return {get_scanner(context), 32}; } } #if !defined(LIBHAT_DISABLE_SSE) if (ext.sse41) { - return {&find_pattern, 16}; + return {get_scanner(context), 16}; } #endif #endif // If none of the vectorized implementations are available/supported, then fallback to scanning per-byte - return {&find_pattern, 0}; + return {get_scanner(context), 0}; } - - template std::pair resolve_scanner(); - template std::pair resolve_scanner(); } // Validate return value const-ness for the root find_pattern impl diff --git a/src/arch/x86/AVX2.cpp b/src/arch/x86/AVX2.cpp index 18dc403..17fb134 100644 --- a/src/arch/x86/AVX2.cpp +++ b/src/arch/x86/AVX2.cpp @@ -90,34 +90,34 @@ namespace hat::detail { // Look in remaining bytes that couldn't be grouped into 256 bits begin = reinterpret_cast(vec); - return find_pattern(begin, end, context); + return find_pattern_single(begin, end, context); } - template - const_scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, const scan_context& context) { - auto& signature = context.signature; - const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value(); + template<> + scan_function_t get_scanner(const scan_context& context) { + const auto alignment = context.alignment; + const auto signature = context.signature; const bool veccmp = signature.size() <= 33; - if (cmpeq2 && veccmp) { - return find_pattern_avx2(begin, end, context); - } else if (cmpeq2) { - return find_pattern_avx2(begin, end, context); - } else if (veccmp) { - return find_pattern_avx2(begin, end, context); - } else { - return find_pattern_avx2(begin, end, context); + if (alignment == scan_alignment::X1) { + const bool cmpeq2 = signature.size() > 1 && signature[1].has_value(); + if (cmpeq2 && veccmp) { + return &find_pattern_avx2; + } else if (cmpeq2) { + return &find_pattern_avx2; + } else if (veccmp) { + return &find_pattern_avx2; + } else { + return &find_pattern_avx2; + } + } else if (alignment == scan_alignment::X16) { + if (veccmp) { + return &find_pattern_avx2; + } else { + return &find_pattern_avx2; + } } - } - - template<> - const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { - return find_pattern_avx2(begin, end, context); - } - - template<> - const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { - return find_pattern_avx2(begin, end, context); + std::unreachable(); } } #endif diff --git a/src/arch/x86/AVX512.cpp b/src/arch/x86/AVX512.cpp index 1ef28cf..6c7a7f1 100644 --- a/src/arch/x86/AVX512.cpp +++ b/src/arch/x86/AVX512.cpp @@ -85,34 +85,34 @@ namespace hat::detail { // Look in remaining bytes that couldn't be grouped into 512 bits begin = reinterpret_cast(vec); - return find_pattern(begin, end, context); + return find_pattern_single(begin, end, context); } - template - const_scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, const scan_context& context) { - auto& signature = context.signature; - const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value(); + template<> + scan_function_t get_scanner(const scan_context& context) { + const auto alignment = context.alignment; + const auto signature = context.signature; const bool veccmp = signature.size() <= 65; - if (cmpeq2 && veccmp) { - return find_pattern_avx512(begin, end, context); - } else if (cmpeq2) { - return find_pattern_avx512(begin, end, context); - } else if (veccmp) { - return find_pattern_avx512(begin, end, context); - } else { - return find_pattern_avx512(begin, end, context); + if (alignment == scan_alignment::X1) { + const bool cmpeq2 = signature.size() > 1 && signature[1].has_value(); + if (cmpeq2 && veccmp) { + return &find_pattern_avx512; + } else if (cmpeq2) { + return &find_pattern_avx512; + } else if (veccmp) { + return &find_pattern_avx512; + } else { + return &find_pattern_avx512; + } + } else if (alignment == scan_alignment::X16) { + if (veccmp) { + return &find_pattern_avx512; + } else { + return &find_pattern_avx512; + } } - } - - template<> - const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { - return find_pattern_avx512(begin, end, context); - } - - template<> - const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { - return find_pattern_avx512(begin, end, context); + std::unreachable(); } } #endif diff --git a/src/arch/x86/SSE.cpp b/src/arch/x86/SSE.cpp index b88c41b..7540707 100644 --- a/src/arch/x86/SSE.cpp +++ b/src/arch/x86/SSE.cpp @@ -88,34 +88,34 @@ namespace hat::detail { // Look in remaining bytes that couldn't be grouped into 128 bits begin = reinterpret_cast(vec); - return find_pattern(begin, end, context); + return find_pattern_single(begin, end, context); } - template - const_scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, const scan_context& context) { - auto& signature = context.signature; - const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value(); + template<> + scan_function_t get_scanner(const scan_context& context) { + const auto alignment = context.alignment; + const auto signature = context.signature; const bool veccmp = signature.size() <= 17; - if (cmpeq2 && veccmp) { - return find_pattern_sse(begin, end, context); - } else if (cmpeq2) { - return find_pattern_sse(begin, end, context); - } else if (veccmp) { - return find_pattern_sse(begin, end, context); - } else { - return find_pattern_sse(begin, end, context); + if (alignment == scan_alignment::X1) { + const bool cmpeq2 = signature.size() > 1 && signature[1].has_value(); + if (cmpeq2 && veccmp) { + return &find_pattern_sse; + } else if (cmpeq2) { + return &find_pattern_sse; + } else if (veccmp) { + return &find_pattern_sse; + } else { + return &find_pattern_sse; + } + } else if (alignment == scan_alignment::X16) { + if (veccmp) { + return &find_pattern_sse; + } else { + return &find_pattern_sse; + } } - } - - template<> - const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { - return find_pattern_sse(begin, end, context); - } - - template<> - const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { - return find_pattern_sse(begin, end, context); + std::unreachable(); } } #endif