From 7aae038757cfbd50f7eee94fb6ece98edd04d314 Mon Sep 17 00:00:00 2001 From: Brady Date: Mon, 21 Aug 2023 16:38:29 -0500 Subject: [PATCH] 16-byte alignment scanning support --- include/libhat/Scanner.hpp | 60 ++++++++++++++++++++++++++++++++++++-- src/Scanner.cpp | 9 ++++-- src/arch/x86/AVX2.cpp | 22 ++++++++++++-- src/arch/x86/AVX512.cpp | 20 +++++++++++-- src/arch/x86/SSE.cpp | 24 ++++++++++++--- 5 files changed, 121 insertions(+), 14 deletions(-) diff --git a/include/libhat/Scanner.hpp b/include/libhat/Scanner.hpp index cb34312..e319025 100644 --- a/include/libhat/Scanner.hpp +++ b/include/libhat/Scanner.hpp @@ -52,7 +52,8 @@ namespace hat { }; enum class scan_alignment { - X1 + X1 = 1, + X16 = 16 }; namespace detail { @@ -67,6 +68,37 @@ namespace hat { Single = FastFirst }; + template + inline constexpr auto alignment_stride = static_cast>(alignment); + + template> + inline consteval auto create_alignment_mask() { + type mask{}; + for (size_t i = 0; i < sizeof(type) * 8; i += stride) { + mask |= (type(1) << i); + } + return mask; + } + + template> + inline const std::byte* next_boundary_align(const std::byte* ptr) { + if constexpr (stride == 1) { + return ptr; + } + uintptr_t mod = reinterpret_cast(ptr) % stride; + ptr += mod ? stride - mod : 0; + return ptr; + } + + template> + inline const std::byte* prev_boundary_align(const std::byte* ptr) { + if constexpr (stride == 1) { + return ptr; + } + uintptr_t mod = reinterpret_cast(ptr) % stride; + return ptr - mod; + } + template scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature); @@ -74,7 +106,7 @@ namespace hat { scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature); template<> - constexpr scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + inline constexpr scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { const auto firstByte = *signature[0]; const auto scanEnd = end - signature.size() + 1; @@ -98,6 +130,30 @@ namespace hat { } return nullptr; } + + template<> + inline scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + const auto firstByte = *signature[0]; + + const auto scanBegin = next_boundary_align(begin); + const auto scanEnd = prev_boundary_align(end - signature.size() + 1); + if (scanBegin >= scanEnd) { + return {}; + } + + for (auto i = scanBegin; i != scanEnd; i += 16) { + if (*i == firstByte) { + // Compare everything after the first byte + auto match = std::equal(signature.begin() + 1, signature.end(), i + 1, [](auto opt, auto byte) { + return !opt.has_value() || *opt == byte; + }); + if (match) { + return i; + } + } + } + return nullptr; + } } enum class compiler_type { diff --git a/src/Scanner.cpp b/src/Scanner.cpp index fc49cc2..7ccbdc9 100644 --- a/src/Scanner.cpp +++ b/src/Scanner.cpp @@ -25,8 +25,10 @@ namespace hat { return find_pattern(data.begin(), data.end(), signature); } - template scan_result find_pattern(signature_view signature, module_t mod); - template scan_result find_pattern(signature_view signature, std::string_view section, module_t mod); + template scan_result find_pattern(signature_view, module_t); + template scan_result find_pattern(signature_view, std::string_view, module_t); + template scan_result find_pattern(signature_view, module_t); + template scan_result find_pattern(signature_view, std::string_view, module_t); } namespace hat::detail { @@ -51,5 +53,6 @@ namespace hat::detail { return find_pattern(begin, end, signature); } - template scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature); + template scan_result find_pattern(const std::byte*, const std::byte*, signature_view); + template scan_result find_pattern(const std::byte*, const std::byte*, signature_view); } diff --git a/src/arch/x86/AVX2.cpp b/src/arch/x86/AVX2.cpp index 3d994cf..720da42 100644 --- a/src/arch/x86/AVX2.cpp +++ b/src/arch/x86/AVX2.cpp @@ -25,12 +25,17 @@ namespace hat::detail { ); } - template<> - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + template + scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, signature_view signature) { // 256 bit vector containing first signature byte repeated const auto firstByte = _mm256_set1_epi8(static_cast(*signature[0])); const auto [signatureBytes, signatureMask] = load_signature_256(signature); + begin = next_boundary_align(begin); + if (begin >= end) { + return {}; + } + auto vec = reinterpret_cast(begin); const auto n = static_cast(end - signature.size() - begin) / sizeof(__m256i); const auto e = vec + n; @@ -38,6 +43,7 @@ namespace hat::detail { for (; vec != e; vec++) { const auto cmp = _mm256_cmpeq_epi8(firstByte, _mm256_loadu_si256(vec)); auto mask = static_cast(_mm256_movemask_epi8(cmp)); + mask &= create_alignment_mask(); while (mask) { const auto offset = _tzcnt_u32(mask); const auto i = reinterpret_cast(vec) + offset; @@ -53,7 +59,17 @@ namespace hat::detail { // Look in remaining bytes that couldn't be grouped into 256 bits begin = reinterpret_cast(vec); - return find_pattern(begin, end, signature); + return find_pattern(begin, end, signature); + } + + template<> + scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + return find_pattern_avx2(begin, end, signature); + } + + template<> + scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + return find_pattern_avx2(begin, end, signature); } } #endif diff --git a/src/arch/x86/AVX512.cpp b/src/arch/x86/AVX512.cpp index 00e20ae..e340985 100644 --- a/src/arch/x86/AVX512.cpp +++ b/src/arch/x86/AVX512.cpp @@ -25,18 +25,24 @@ namespace hat::detail { ); } - template<> - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + template + scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, signature_view signature) { // 512 bit vector containing first signature byte repeated const auto firstByte = _mm512_set1_epi8(static_cast(*signature[0])); const auto [signatureBytes, signatureMask] = load_signature_512(signature); + begin = next_boundary_align(begin); + if (begin >= end) { + return {}; + } + auto vec = reinterpret_cast(begin); const auto n = static_cast(end - signature.size() - begin) / sizeof(__m512i); const auto e = vec + n; for (; vec != e; vec++) { auto mask = _mm512_cmpeq_epi8_mask(firstByte, _mm512_loadu_si512(vec)); + mask &= create_alignment_mask(); while (mask) { const auto offset = LIBHAT_TZCNT64(mask); const auto i = reinterpret_cast(vec) + offset; @@ -53,5 +59,15 @@ namespace hat::detail { begin = reinterpret_cast(vec); return find_pattern(begin, end, signature); } + + template<> + scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + return find_pattern_avx512(begin, end, signature); + } + + template<> + scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + return find_pattern_avx512(begin, end, signature); + } } #endif diff --git a/src/arch/x86/SSE.cpp b/src/arch/x86/SSE.cpp index 583f77a..c592502 100644 --- a/src/arch/x86/SSE.cpp +++ b/src/arch/x86/SSE.cpp @@ -25,19 +25,25 @@ namespace hat::detail { ); } - template<> - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + template + scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, signature_view signature) { // 256 bit vector containing first signature byte repeated const auto firstByte = _mm_set1_epi8(static_cast(*signature[0])); const auto [signatureBytes, signatureMask] = load_signature_128(signature); + begin = next_boundary_align(begin); + if (begin >= end) { + return {}; + } + auto vec = reinterpret_cast(begin); const auto n = static_cast(end - signature.size() - begin) / sizeof(__m128i); const auto e = vec + n; for (; vec != e; vec++) { const auto cmp = _mm_cmpeq_epi8(firstByte, _mm_loadu_si128(vec)); - auto mask = static_cast(_mm_movemask_epi8(cmp)); + auto mask = static_cast(_mm_movemask_epi8(cmp)); + mask &= create_alignment_mask(); while (mask) { const auto offset = LIBHAT_BSF32(mask); const auto i = reinterpret_cast(vec) + offset; @@ -53,7 +59,17 @@ namespace hat::detail { // Look in remaining bytes that couldn't be grouped into 128 bits begin = reinterpret_cast(vec); - return find_pattern(begin, end, signature); + return find_pattern(begin, end, signature); + } + + template<> + scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + return find_pattern_sse(begin, end, signature); + } + + template<> + scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + return find_pattern_sse(begin, end, signature); } } #endif