From 57efa1bb7f2bceec7fba63f264aa9ca8f366a2db Mon Sep 17 00:00:00 2001 From: Brady Date: Tue, 13 Feb 2024 13:04:07 -0600 Subject: [PATCH] Hints soon TM --- include/libhat/Scanner.hpp | 41 ++++++++++++++++++++++++++++++------- include/libhat/Scanner.inl | 22 ++++++++++++++++++++ src/Scanner.cpp | 42 +++++++------------------------------- src/arch/x86/AVX2.cpp | 25 +++++++++++++---------- src/arch/x86/AVX512.cpp | 24 ++++++++++++---------- src/arch/x86/SSE.cpp | 25 +++++++++++++---------- 6 files changed, 104 insertions(+), 75 deletions(-) create mode 100644 include/libhat/Scanner.inl diff --git a/include/libhat/Scanner.hpp b/include/libhat/Scanner.hpp index 64cfc1b..33ded4b 100644 --- a/include/libhat/Scanner.hpp +++ b/include/libhat/Scanner.hpp @@ -56,8 +56,30 @@ namespace hat { X16 = 16 }; + enum class scan_hint : uint64_t { + none = 0, // no hints + x86_64 = 1 << 0, // The data being scanned is x86_64 machine code + }; + + constexpr scan_hint operator|(scan_hint lhs, scan_hint rhs) { + using U = std::underlying_type_t; + return static_cast(static_cast(lhs) | static_cast(rhs)); + } + + constexpr scan_hint operator&(scan_hint lhs, scan_hint rhs) { + using U = std::underlying_type_t; + return static_cast(static_cast(lhs) & static_cast(rhs)); + } + namespace detail { + struct scan_context { + const std::byte* begin{}; + const std::byte* end{}; + signature_view signature{}; + scan_hint hints{}; + }; + enum class scan_mode { FastFirst, // std::find + std::equal SSE, // x86 SSE 4.1 @@ -100,13 +122,14 @@ namespace hat { } template - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature); + scan_result find_pattern(const scan_context&); template - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature); + scan_result find_pattern(const scan_context&); template<> - inline constexpr scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + inline constexpr scan_result find_pattern(const scan_context& context) { + auto [begin, end, signature, _] = context; const auto firstByte = *signature[0]; const auto scanEnd = end - signature.size() + 1; @@ -132,7 +155,8 @@ namespace hat { } template<> - inline scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + inline scan_result find_pattern(const scan_context& context) { + auto [begin, end, signature, _] = context; const auto firstByte = *signature[0]; const auto scanBegin = next_boundary_align(begin); @@ -176,7 +200,8 @@ namespace hat { constexpr scan_result find_pattern( Iter beginIt, Iter endIt, - signature_view signature + signature_view signature, + scan_hint hints = scan_hint::none ) { // Truncate the leading wildcards from the signature size_t offset = 0; @@ -196,9 +221,9 @@ namespace hat { hat::scan_result result; if LIBHAT_IF_CONSTEVAL { - result = detail::find_pattern(begin, end, signature); + result = detail::find_pattern({begin, end, signature, hints}); } else { - result = detail::find_pattern(begin, end, signature); + result = detail::find_pattern({begin, end, signature, hints}); } return result.has_result() ? result.get() - offset : nullptr; } @@ -218,3 +243,5 @@ namespace hat::experimental { process::module_t mod = process::get_process_module() ); } + +#include "Scanner.inl" diff --git a/include/libhat/Scanner.inl b/include/libhat/Scanner.inl new file mode 100644 index 0000000..d284bc3 --- /dev/null +++ b/include/libhat/Scanner.inl @@ -0,0 +1,22 @@ +#pragma once + +namespace hat { + + template + scan_result find_pattern(const signature_view signature, const hat::process::module_t mod) { + const auto data = hat::process::get_module_data(mod); + if (data.empty()) { + return nullptr; + } + return find_pattern(data.begin(), data.end(), signature); + } + + template + scan_result find_pattern(const signature_view signature, const std::string_view section, const hat::process::module_t mod) { + const auto data = hat::process::get_section_data(mod, section); + if (data.empty()) { + return nullptr; + } + return find_pattern(data.begin(), data.end(), signature); + } +} diff --git a/src/Scanner.cpp b/src/Scanner.cpp index e298d49..5f09501 100644 --- a/src/Scanner.cpp +++ b/src/Scanner.cpp @@ -3,60 +3,32 @@ #include #include "System.hpp" -namespace hat { - - using namespace hat::process; - - template - scan_result find_pattern(signature_view signature, module_t mod) { - const auto data = get_module_data(mod); - if (data.empty()) { - return nullptr; - } - return find_pattern(data.begin(), data.end(), signature); - } - - template - scan_result find_pattern(signature_view signature, std::string_view section, module_t mod) { - const auto data = get_section_data(mod, section); - if (data.empty()) { - return nullptr; - } - return find_pattern(data.begin(), data.end(), signature); - } - - template scan_result find_pattern(signature_view, module_t); - template scan_result find_pattern(signature_view, std::string_view, module_t); - template scan_result find_pattern(signature_view, module_t); - template scan_result find_pattern(signature_view, std::string_view, module_t); -} - namespace hat::detail { template - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + scan_result find_pattern(const scan_context& context) { #if defined(LIBHAT_X86) const auto& ext = get_system().extensions; if (ext.bmi1) { #if !defined(LIBHAT_DISABLE_AVX512) if (ext.avx512) { - return find_pattern(begin, end, signature); + return find_pattern(context); } #endif if (ext.avx2) { - return find_pattern(begin, end, signature); + return find_pattern(context); } } #if !defined(LIBHAT_DISABLE_SSE) if (ext.sse41) { - return find_pattern(begin, end, signature); + return find_pattern(context); } #endif #endif // If none of the vectorized implementations are available/supported, then fallback to scanning per-byte - return find_pattern(begin, end, signature); + return find_pattern(context); } - template scan_result find_pattern(const std::byte*, const std::byte*, signature_view); - template scan_result find_pattern(const std::byte*, const std::byte*, signature_view); + template scan_result find_pattern(const scan_context& context); + template scan_result find_pattern(const scan_context& context); } diff --git a/src/arch/x86/AVX2.cpp b/src/arch/x86/AVX2.cpp index 5579f1a..cda9190 100644 --- a/src/arch/x86/AVX2.cpp +++ b/src/arch/x86/AVX2.cpp @@ -26,7 +26,9 @@ namespace hat::detail { } template - scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, signature_view signature) { + scan_result find_pattern_avx2(const scan_context& context) { + auto [begin, end, signature, hints] = context; + // 256 bit vector containing first signature byte repeated const auto firstByte = _mm256_set1_epi8(static_cast(*signature[0])); @@ -88,33 +90,34 @@ namespace hat::detail { // Look in remaining bytes that couldn't be grouped into 256 bits begin = reinterpret_cast(vec); - return find_pattern(begin, end, signature); + return find_pattern({begin, end, signature, hints}); } template - scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, signature_view signature) { + scan_result find_pattern_avx2(const scan_context& context) { + auto& signature = context.signature; const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value(); const bool veccmp = signature.size() <= 33; if (cmpeq2 && veccmp) { - return find_pattern_avx2(begin, end, signature); + return find_pattern_avx2(context); } else if (cmpeq2) { - return find_pattern_avx2(begin, end, signature); + return find_pattern_avx2(context); } else if (veccmp) { - return find_pattern_avx2(begin, end, signature); + return find_pattern_avx2(context); } else { - return find_pattern_avx2(begin, end, signature); + return find_pattern_avx2(context); } } template<> - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { - return find_pattern_avx2(begin, end, signature); + scan_result find_pattern(const scan_context& context) { + return find_pattern_avx2(context); } template<> - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { - return find_pattern_avx2(begin, end, signature); + scan_result find_pattern(const scan_context& context) { + return find_pattern_avx2(context); } } #endif diff --git a/src/arch/x86/AVX512.cpp b/src/arch/x86/AVX512.cpp index 046e43f..8d4d5a1 100644 --- a/src/arch/x86/AVX512.cpp +++ b/src/arch/x86/AVX512.cpp @@ -26,7 +26,8 @@ namespace hat::detail { } template - scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, signature_view signature) { + scan_result find_pattern_avx512(const scan_context& context) { + auto [begin, end, signature, hints] = context; // 512 bit vector containing first signature byte repeated const auto firstByte = _mm512_set1_epi8(static_cast(*signature[0])); @@ -84,33 +85,34 @@ namespace hat::detail { // Look in remaining bytes that couldn't be grouped into 512 bits begin = reinterpret_cast(vec); - return find_pattern(begin, end, signature); + return find_pattern({begin, end, signature, hints}); } template - scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, signature_view signature) { + scan_result find_pattern_avx512(const scan_context& context) { + auto& signature = context.signature; const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value(); const bool veccmp = signature.size() <= 65; if (cmpeq2 && veccmp) { - return find_pattern_avx512(begin, end, signature); + return find_pattern_avx512(context); } else if (cmpeq2) { - return find_pattern_avx512(begin, end, signature); + return find_pattern_avx512(context); } else if (veccmp) { - return find_pattern_avx512(begin, end, signature); + return find_pattern_avx512(context); } else { - return find_pattern_avx512(begin, end, signature); + return find_pattern_avx512(context); } } template<> - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { - return find_pattern_avx512(begin, end, signature); + scan_result find_pattern(const scan_context& context) { + return find_pattern_avx512(context); } template<> - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { - return find_pattern_avx512(begin, end, signature); + scan_result find_pattern(const scan_context& context) { + return find_pattern_avx512(context); } } #endif diff --git a/src/arch/x86/SSE.cpp b/src/arch/x86/SSE.cpp index e16e4bb..3eee222 100644 --- a/src/arch/x86/SSE.cpp +++ b/src/arch/x86/SSE.cpp @@ -26,7 +26,9 @@ namespace hat::detail { } template - scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, signature_view signature) { + scan_result find_pattern_sse(const scan_context& context) { + auto [begin, end, signature, hints] = context; + // 256 bit vector containing first signature byte repeated const auto firstByte = _mm_set1_epi8(static_cast(*signature[0])); @@ -86,33 +88,34 @@ namespace hat::detail { // Look in remaining bytes that couldn't be grouped into 128 bits begin = reinterpret_cast(vec); - return find_pattern(begin, end, signature); + return find_pattern({begin, end, signature, hints}); } template - scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, signature_view signature) { + scan_result find_pattern_sse(const scan_context& context) { + auto& signature = context.signature; const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value(); const bool veccmp = signature.size() <= 17; if (cmpeq2 && veccmp) { - return find_pattern_sse(begin, end, signature); + return find_pattern_sse(context); } else if (cmpeq2) { - return find_pattern_sse(begin, end, signature); + return find_pattern_sse(context); } else if (veccmp) { - return find_pattern_sse(begin, end, signature); + return find_pattern_sse(context); } else { - return find_pattern_sse(begin, end, signature); + return find_pattern_sse(context); } } template<> - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { - return find_pattern_sse(begin, end, signature); + scan_result find_pattern(const scan_context& context) { + return find_pattern_sse(context); } template<> - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { - return find_pattern_sse(begin, end, signature); + scan_result find_pattern(const scan_context& context) { + return find_pattern_sse(context); } } #endif