From 0caead9f769db21bef7431ee78e0a10a1da71a5b Mon Sep 17 00:00:00 2001 From: Brady Date: Wed, 31 Jul 2024 17:19:21 -0500 Subject: [PATCH] Separate `begin, end` from `scan_context` --- include/libhat/Scanner.hpp | 45 ++++++++++++++++++++++++++------------ src/Scanner.cpp | 16 ++++++++------ src/arch/x86/AVX2.cpp | 24 ++++++++++---------- src/arch/x86/AVX512.cpp | 24 ++++++++++---------- src/arch/x86/SSE.cpp | 24 ++++++++++---------- 5 files changed, 76 insertions(+), 57 deletions(-) diff --git a/include/libhat/Scanner.hpp b/include/libhat/Scanner.hpp index 7a5c2e1..aac2a48 100644 --- a/include/libhat/Scanner.hpp +++ b/include/libhat/Scanner.hpp @@ -82,10 +82,21 @@ namespace hat { namespace detail { struct scan_context { - const std::byte* begin{}; - const std::byte* end{}; signature_view signature{}; scan_hint hints{}; + + static constexpr scan_context create(const signature_view signature, const scan_hint hints) { + scan_context ctx{}; + ctx.signature = signature; + ctx.hints = hints; + if LIBHAT_IF_CONSTEVAL {} else { + ctx.apply_hints(); + } + return ctx; + } + private: + scan_context() = default; + void apply_hints(); }; enum class scan_mode { @@ -130,14 +141,14 @@ namespace hat { } template - const_scan_result find_pattern(const scan_context&); + const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context&); template - const_scan_result find_pattern(const scan_context&); + const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context&); template<> - inline constexpr const_scan_result find_pattern(const scan_context& context) { - auto [begin, end, signature, _] = context; + inline constexpr const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { + const auto signature = context.signature; const auto firstByte = *signature[0]; const auto scanEnd = end - signature.size() + 1; @@ -163,8 +174,8 @@ namespace hat { } template<> - inline const_scan_result find_pattern(const scan_context& context) { - auto [begin, end, signature, _] = context; + inline const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { + const auto signature = context.signature; const auto firstByte = *signature[0]; const auto scanBegin = next_boundary_align(begin); @@ -236,11 +247,13 @@ namespace hat { return {nullptr}; } + const auto context = detail::scan_context::create(trunc, hints); + const_scan_result result; if LIBHAT_IF_CONSTEVAL { - result = detail::find_pattern({begin, end, trunc, hints}); + result = detail::find_pattern(begin, end, context); } else { - result = detail::find_pattern({begin, end, trunc, hints}); + result = detail::find_pattern(begin, end, context); } return result.has_result() ? const_cast::underlying_type>(result.get() - offset) @@ -267,12 +280,14 @@ namespace hat { auto i = begin; auto out = beginOut; + const auto context = detail::scan_context::create(trunc, hints); + while (i < end && out != endOut && trunc.size() <= static_cast(std::distance(i, end))) { const_scan_result result; if LIBHAT_IF_CONSTEVAL { - result = detail::find_pattern({i, end, trunc, hints}); + result = detail::find_pattern(i, end, context); } else { - result = detail::find_pattern({i, end, trunc, hints}); + result = detail::find_pattern(i, end, context); } if (!result.has_result()) { i = end; @@ -303,12 +318,14 @@ namespace hat { auto out = outIn; size_t matches{}; + const auto context = detail::scan_context::create(trunc, hints); + while (begin < end && trunc.size() <= static_cast(std::distance(i, end))) { const_scan_result result; if LIBHAT_IF_CONSTEVAL { - result = detail::find_pattern({i, end, trunc, hints}); + result = detail::find_pattern(i, end, context); } else { - result = detail::find_pattern({i, end, trunc, hints}); + result = detail::find_pattern(i, end, context); } if (!result.has_result()) { break; diff --git a/src/Scanner.cpp b/src/Scanner.cpp index ae6aaf5..915095a 100644 --- a/src/Scanner.cpp +++ b/src/Scanner.cpp @@ -5,32 +5,34 @@ namespace hat::detail { + void scan_context::apply_hints() {} + template - const_scan_result find_pattern(const scan_context& context) { + const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { #if defined(LIBHAT_X86) const auto& ext = get_system().extensions; if (ext.bmi) { #if !defined(LIBHAT_DISABLE_AVX512) if (ext.avx512f && ext.avx512bw) { - return find_pattern(context); + return find_pattern(begin, end, context); } #endif if (ext.avx2) { - return find_pattern(context); + return find_pattern(begin, end, context); } } #if !defined(LIBHAT_DISABLE_SSE) if (ext.sse41) { - return find_pattern(context); + return find_pattern(begin, end, context); } #endif #endif // If none of the vectorized implementations are available/supported, then fallback to scanning per-byte - return find_pattern(context); + return find_pattern(begin, end, context); } - template const_scan_result find_pattern(const scan_context& context); - template const_scan_result find_pattern(const scan_context& context); + template const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context); + template const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context); } // Validate return value const-ness for the root find_pattern impl diff --git a/src/arch/x86/AVX2.cpp b/src/arch/x86/AVX2.cpp index 91bd0dc..18dc403 100644 --- a/src/arch/x86/AVX2.cpp +++ b/src/arch/x86/AVX2.cpp @@ -26,8 +26,8 @@ namespace hat::detail { } template - const_scan_result find_pattern_avx2(const scan_context& context) { - auto [begin, end, signature, hints] = context; + const_scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, const scan_context& context) { + const auto signature = context.signature; // 256 bit vector containing first signature byte repeated const auto firstByte = _mm256_set1_epi8(static_cast(*signature[0])); @@ -90,34 +90,34 @@ namespace hat::detail { // Look in remaining bytes that couldn't be grouped into 256 bits begin = reinterpret_cast(vec); - return find_pattern({begin, end, signature, hints}); + return find_pattern(begin, end, context); } template - const_scan_result find_pattern_avx2(const scan_context& context) { + const_scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, const scan_context& context) { auto& signature = context.signature; const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value(); const bool veccmp = signature.size() <= 33; if (cmpeq2 && veccmp) { - return find_pattern_avx2(context); + return find_pattern_avx2(begin, end, context); } else if (cmpeq2) { - return find_pattern_avx2(context); + return find_pattern_avx2(begin, end, context); } else if (veccmp) { - return find_pattern_avx2(context); + return find_pattern_avx2(begin, end, context); } else { - return find_pattern_avx2(context); + return find_pattern_avx2(begin, end, context); } } template<> - const_scan_result find_pattern(const scan_context& context) { - return find_pattern_avx2(context); + const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { + return find_pattern_avx2(begin, end, context); } template<> - const_scan_result find_pattern(const scan_context& context) { - return find_pattern_avx2(context); + const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { + return find_pattern_avx2(begin, end, context); } } #endif diff --git a/src/arch/x86/AVX512.cpp b/src/arch/x86/AVX512.cpp index 4224b6b..1ef28cf 100644 --- a/src/arch/x86/AVX512.cpp +++ b/src/arch/x86/AVX512.cpp @@ -26,8 +26,8 @@ namespace hat::detail { } template - const_scan_result find_pattern_avx512(const scan_context& context) { - auto [begin, end, signature, hints] = context; + const_scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, const scan_context& context) { + const auto signature = context.signature; // 512 bit vector containing first signature byte repeated const auto firstByte = _mm512_set1_epi8(static_cast(*signature[0])); @@ -85,34 +85,34 @@ namespace hat::detail { // Look in remaining bytes that couldn't be grouped into 512 bits begin = reinterpret_cast(vec); - return find_pattern({begin, end, signature, hints}); + return find_pattern(begin, end, context); } template - const_scan_result find_pattern_avx512(const scan_context& context) { + const_scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, const scan_context& context) { auto& signature = context.signature; const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value(); const bool veccmp = signature.size() <= 65; if (cmpeq2 && veccmp) { - return find_pattern_avx512(context); + return find_pattern_avx512(begin, end, context); } else if (cmpeq2) { - return find_pattern_avx512(context); + return find_pattern_avx512(begin, end, context); } else if (veccmp) { - return find_pattern_avx512(context); + return find_pattern_avx512(begin, end, context); } else { - return find_pattern_avx512(context); + return find_pattern_avx512(begin, end, context); } } template<> - const_scan_result find_pattern(const scan_context& context) { - return find_pattern_avx512(context); + const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { + return find_pattern_avx512(begin, end, context); } template<> - const_scan_result find_pattern(const scan_context& context) { - return find_pattern_avx512(context); + const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { + return find_pattern_avx512(begin, end, context); } } #endif diff --git a/src/arch/x86/SSE.cpp b/src/arch/x86/SSE.cpp index cb10956..b88c41b 100644 --- a/src/arch/x86/SSE.cpp +++ b/src/arch/x86/SSE.cpp @@ -26,8 +26,8 @@ namespace hat::detail { } template - const_scan_result find_pattern_sse(const scan_context& context) { - auto [begin, end, signature, hints] = context; + const_scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, const scan_context& context) { + const auto signature = context.signature; // 256 bit vector containing first signature byte repeated const auto firstByte = _mm_set1_epi8(static_cast(*signature[0])); @@ -88,34 +88,34 @@ namespace hat::detail { // Look in remaining bytes that couldn't be grouped into 128 bits begin = reinterpret_cast(vec); - return find_pattern({begin, end, signature, hints}); + return find_pattern(begin, end, context); } template - const_scan_result find_pattern_sse(const scan_context& context) { + const_scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, const scan_context& context) { auto& signature = context.signature; const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value(); const bool veccmp = signature.size() <= 17; if (cmpeq2 && veccmp) { - return find_pattern_sse(context); + return find_pattern_sse(begin, end, context); } else if (cmpeq2) { - return find_pattern_sse(context); + return find_pattern_sse(begin, end, context); } else if (veccmp) { - return find_pattern_sse(context); + return find_pattern_sse(begin, end, context); } else { - return find_pattern_sse(context); + return find_pattern_sse(begin, end, context); } } template<> - const_scan_result find_pattern(const scan_context& context) { - return find_pattern_sse(context); + const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { + return find_pattern_sse(begin, end, context); } template<> - const_scan_result find_pattern(const scan_context& context) { - return find_pattern_sse(context); + const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) { + return find_pattern_sse(begin, end, context); } } #endif