From a4495f372366b3acbada8f7fe074a4032433b1bd Mon Sep 17 00:00:00 2001 From: Brady Date: Tue, 15 Aug 2023 22:55:45 -0500 Subject: [PATCH] Add `scan_alignment` parameter to `find_pattern` --- CMakeLists.txt | 1 - include/libhat/Scanner.hpp | 28 ++++++++++++++++------------ src/Scanner.cpp | 35 ++++++++++------------------------- src/arch/arm/Neon.cpp | 28 ---------------------------- src/arch/x86/AVX2.cpp | 4 ++-- src/arch/x86/AVX512.cpp | 4 ++-- src/arch/x86/SSE.cpp | 4 ++-- 7 files changed, 32 insertions(+), 72 deletions(-) delete mode 100644 src/arch/arm/Neon.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b144148..262bd31 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,7 +28,6 @@ set(LIBHAT_SRC src/arch/x86/AVX512.cpp src/arch/x86/System.cpp - src/arch/arm/Neon.cpp src/arch/arm/System.cpp) add_library(libhat STATIC ${LIBHAT_SRC}) diff --git a/include/libhat/Scanner.hpp b/include/libhat/Scanner.hpp index 1bcb48b..9731c18 100644 --- a/include/libhat/Scanner.hpp +++ b/include/libhat/Scanner.hpp @@ -16,7 +16,8 @@ namespace hat { using rel_t = int32_t; public: constexpr scan_result() : result(nullptr) {} - constexpr scan_result(const std::byte* result) : result(result) {} // NOLINT(google-explicit-constructor) + constexpr scan_result(std::nullptr_t) : result(nullptr) {} // NOLINT(google-explicit-constructor) + constexpr scan_result(const std::byte* result) : result(result) {} // NOLINT(google-explicit-constructor) /// Reads an integer of the specified type located at an offset from the signature result template @@ -50,29 +51,30 @@ namespace hat { const std::byte* result; }; + enum class scan_alignment { + X1 + }; + namespace detail { enum class scan_mode { - Auto, // Automatically choose the mode to use - Search, // std::search FastFirst, // std::find + std::equal SSE, // x86 SSE 4.1 AVX2, // x86 AVX2 AVX512, // x86 AVX512 - Neon, // ARM Neon // Fallback mode to use for SIMD remaining bytes Single = FastFirst }; - template + template scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature); - template<> - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature); + template + scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature); template<> - constexpr scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + constexpr scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { const auto firstByte = *signature[0]; const auto scanEnd = end - signature.size() + 1; @@ -111,29 +113,31 @@ namespace hat { ); /// Perform a signature scan on the entirety of the process module or a specified module + template scan_result find_pattern( signature_view signature, process::module_t mod = process::get_process_module() ); /// Perform a signature scan on a specific section of the process module or a specified module + template scan_result find_pattern( signature_view signature, std::string_view section, process::module_t mod = process::get_process_module() ); - /// Root implementation of FindPattern - template + /// Root implementation of find_pattern + template constexpr scan_result find_pattern( Iter begin, Iter end, signature_view signature ) { if LIBHAT_IF_CONSTEVAL { - return detail::find_pattern(std::to_address(begin), std::to_address(end), signature); + return detail::find_pattern(std::to_address(begin), std::to_address(end), signature); } else { - return detail::find_pattern(std::to_address(begin), std::to_address(end), signature); + return detail::find_pattern(std::to_address(begin), std::to_address(end), signature); } } } diff --git a/src/Scanner.cpp b/src/Scanner.cpp index db0829a..c006c22 100644 --- a/src/Scanner.cpp +++ b/src/Scanner.cpp @@ -1,7 +1,5 @@ #include -#include - #include #include @@ -9,57 +7,44 @@ namespace hat { using namespace hat::process; + template scan_result find_pattern(signature_view signature, module_t mod) { const auto data = get_module_data(mod); if (data.empty()) { return nullptr; } - return find_pattern(data.begin(), data.end(), signature); + return find_pattern(data.begin(), data.end(), signature); } + template scan_result find_pattern(signature_view signature, std::string_view section, module_t mod) { const auto data = get_section_data(mod, section); if (data.empty()) { return nullptr; } - return find_pattern(data.begin(), data.end(), signature); + return find_pattern(data.begin(), data.end(), signature); } } namespace hat::detail { - template<> - [[deprecated]] scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { - auto it = std::search( - begin, end, - signature.begin(), signature.end(), - [](auto byte, auto opt) { - return !opt.has_value() || *opt == byte; - }); - return it != end ? it : nullptr; - } - - template<> - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + template + scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { const auto size = signature.size(); #if defined(LIBHAT_X86) const auto& ext = get_system().extensions; if (ext.bmi1) { if (size <= 65 && ext.avx512) { - return find_pattern(begin, end, signature); + return find_pattern(begin, end, signature); } else if (size <= 33 && ext.avx2) { - return find_pattern(begin, end, signature); + return find_pattern(begin, end, signature); } } if (size <= 17 && ext.sse41) { - return find_pattern(begin, end, signature); - } -#elif defined(LIBHAT_ARM) - if (size <= 17) { - return find_pattern(begin, end, signature); + return find_pattern(begin, end, signature); } #endif // If none of the vectorized implementations are available/supported, then fallback to scanning per-byte - return find_pattern(begin, end, signature); + return find_pattern(begin, end, signature); } } diff --git a/src/arch/arm/Neon.cpp b/src/arch/arm/Neon.cpp deleted file mode 100644 index f176387..0000000 --- a/src/arch/arm/Neon.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include -#ifdef LIBHAT_ARM - -#include - -namespace hat::detail { - - template<> - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { - const auto firstByte = vld1q_dup_u8(reinterpret_cast(*signature[0])); - - auto vec = reinterpret_cast(begin); - const auto n = static_cast(end - signature.size() - begin) / sizeof(uint8x16_t); - const auto e = vec + n; - - for (; vec != e; vec++) { - const auto cmp = vceqq_u8(firstByte, *vec); - uint64_t first = vgetq_lane_u64(vreinterpretq_u64_u8(cmp), 0); - uint64_t second = vgetq_lane_u64(vreinterpretq_u64_u8(cmp), 1); - if (first || second) { - // TODO: Extract Mask - } - } - - return find_pattern(begin, end, signature); - } -} -#endif diff --git a/src/arch/x86/AVX2.cpp b/src/arch/x86/AVX2.cpp index 9db11f0..3d994cf 100644 --- a/src/arch/x86/AVX2.cpp +++ b/src/arch/x86/AVX2.cpp @@ -26,7 +26,7 @@ namespace hat::detail { } template<> - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { // 256 bit vector containing first signature byte repeated const auto firstByte = _mm256_set1_epi8(static_cast(*signature[0])); const auto [signatureBytes, signatureMask] = load_signature_256(signature); @@ -53,7 +53,7 @@ namespace hat::detail { // Look in remaining bytes that couldn't be grouped into 256 bits begin = reinterpret_cast(vec); - return find_pattern(begin, end, signature); + return find_pattern(begin, end, signature); } } #endif diff --git a/src/arch/x86/AVX512.cpp b/src/arch/x86/AVX512.cpp index 40cc100..00e20ae 100644 --- a/src/arch/x86/AVX512.cpp +++ b/src/arch/x86/AVX512.cpp @@ -26,7 +26,7 @@ namespace hat::detail { } template<> - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { // 512 bit vector containing first signature byte repeated const auto firstByte = _mm512_set1_epi8(static_cast(*signature[0])); const auto [signatureBytes, signatureMask] = load_signature_512(signature); @@ -51,7 +51,7 @@ namespace hat::detail { // Look in remaining bytes that couldn't be grouped into 512 bits begin = reinterpret_cast(vec); - return find_pattern(begin, end, signature); + return find_pattern(begin, end, signature); } } #endif diff --git a/src/arch/x86/SSE.cpp b/src/arch/x86/SSE.cpp index 89f32e9..583f77a 100644 --- a/src/arch/x86/SSE.cpp +++ b/src/arch/x86/SSE.cpp @@ -26,7 +26,7 @@ namespace hat::detail { } template<> - scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { + scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { // 256 bit vector containing first signature byte repeated const auto firstByte = _mm_set1_epi8(static_cast(*signature[0])); const auto [signatureBytes, signatureMask] = load_signature_128(signature); @@ -53,7 +53,7 @@ namespace hat::detail { // Look in remaining bytes that couldn't be grouped into 128 bits begin = reinterpret_cast(vec); - return find_pattern(begin, end, signature); + return find_pattern(begin, end, signature); } } #endif