From 65d0d0ba34dd6a16f9a61ce787e973a7c089cb88 Mon Sep 17 00:00:00 2001 From: Brady Date: Mon, 29 Jul 2024 17:09:18 -0700 Subject: [PATCH] Match iterator constness for `find_pattern` return value --- include/libhat/Scanner.hpp | 45 ++++++++++++++++++++++++-------------- src/Scanner.cpp | 19 +++++++++++++--- src/arch/x86/AVX2.cpp | 8 +++---- src/arch/x86/AVX512.cpp | 8 +++---- src/arch/x86/SSE.cpp | 8 +++---- 5 files changed, 56 insertions(+), 32 deletions(-) diff --git a/include/libhat/Scanner.hpp b/include/libhat/Scanner.hpp index 8568b88..cd70832 100644 --- a/include/libhat/Scanner.hpp +++ b/include/libhat/Scanner.hpp @@ -12,12 +12,15 @@ namespace hat { - class scan_result { + template requires (std::is_pointer_v && sizeof(std::remove_pointer_t) == 1) + class scan_result_base { using rel_t = int32_t; public: - constexpr scan_result() : result(nullptr) {} - constexpr scan_result(std::nullptr_t) : result(nullptr) {} // NOLINT(google-explicit-constructor) - constexpr scan_result(const std::byte* result) : result(result) {} // NOLINT(google-explicit-constructor) + using underlying_type = T; + + constexpr scan_result_base() : result(nullptr) {} + constexpr scan_result_base(std::nullptr_t) : result(nullptr) {} // NOLINT(google-explicit-constructor) + constexpr scan_result_base(T result) : result(result) {} // NOLINT(google-explicit-constructor) /// Reads an integer of the specified type located at an offset from the signature result template @@ -32,7 +35,7 @@ namespace hat { } /// Resolve the relative address located at an offset from the signature result - [[nodiscard]] constexpr const std::byte* rel(size_t offset) const { + [[nodiscard]] constexpr T rel(size_t offset) const { return this->has_result() ? this->result + this->read(offset) + offset + sizeof(rel_t) : nullptr; } @@ -40,17 +43,20 @@ namespace hat { return this->result != nullptr; } - [[nodiscard]] constexpr const std::byte* operator*() const { + [[nodiscard]] constexpr T operator*() const { return this->result; } - [[nodiscard]] constexpr const std::byte* get() const { + [[nodiscard]] constexpr T get() const { return this->result; } private: - const std::byte* result; + T result; }; + using scan_result = scan_result_base; + using const_scan_result = scan_result_base; + enum class scan_alignment { X1 = 1, X16 = 16 @@ -122,13 +128,13 @@ namespace hat { } template - scan_result find_pattern(const scan_context&); + const_scan_result find_pattern(const scan_context&); template - scan_result find_pattern(const scan_context&); + const_scan_result find_pattern(const scan_context&); template<> - inline constexpr scan_result find_pattern(const scan_context& context) { + inline constexpr const_scan_result find_pattern(const scan_context& context) { auto [begin, end, signature, _] = context; const auto firstByte = *signature[0]; const auto scanEnd = end - signature.size() + 1; @@ -155,14 +161,14 @@ namespace hat { } template<> - inline scan_result find_pattern(const scan_context& context) { + inline constexpr const_scan_result find_pattern(const scan_context& context) { auto [begin, end, signature, _] = context; const auto firstByte = *signature[0]; const auto scanBegin = next_boundary_align(begin); const auto scanEnd = prev_boundary_align(end - signature.size() + 1); if (scanBegin >= scanEnd) { - return {}; + return nullptr; } for (auto i = scanBegin; i != scanEnd; i += 16) { @@ -198,7 +204,7 @@ namespace hat { /// Root implementation of find_pattern template - constexpr scan_result find_pattern( + constexpr auto find_pattern( Iter beginIt, Iter endIt, signature_view signature, @@ -216,17 +222,22 @@ namespace hat { const auto begin = std::to_address(beginIt) + offset; const auto end = std::to_address(endIt); + + using result_t = std::conditional_t>, const_scan_result, scan_result>; + if (begin >= end || signature.size() > static_cast(std::distance(begin, end))) { - return nullptr; + return result_t{nullptr}; } - hat::scan_result result; + const_scan_result result; if LIBHAT_IF_CONSTEVAL { result = detail::find_pattern({begin, end, signature, hints}); } else { result = detail::find_pattern({begin, end, signature, hints}); } - return result.has_result() ? result.get() - offset : nullptr; + return result.has_result() + ? const_cast(result.get() - offset) + : result_t{nullptr}; } } diff --git a/src/Scanner.cpp b/src/Scanner.cpp index d724c5d..ae6aaf5 100644 --- a/src/Scanner.cpp +++ b/src/Scanner.cpp @@ -6,7 +6,7 @@ namespace hat::detail { template - scan_result find_pattern(const scan_context& context) { + const_scan_result find_pattern(const scan_context& context) { #if defined(LIBHAT_X86) const auto& ext = get_system().extensions; if (ext.bmi) { @@ -29,6 +29,19 @@ namespace hat::detail { return find_pattern(context); } - template scan_result find_pattern(const scan_context& context); - template scan_result find_pattern(const scan_context& context); + template const_scan_result find_pattern(const scan_context& context); + template const_scan_result find_pattern(const scan_context& context); +} + +// Validate return value const-ness for the root find_pattern impl +namespace hat { + static_assert(std::is_same_v(), + std::declval(), + std::declval()))>); + + static_assert(std::is_same_v(), + std::declval(), + std::declval()))>); } diff --git a/src/arch/x86/AVX2.cpp b/src/arch/x86/AVX2.cpp index cda9190..91bd0dc 100644 --- a/src/arch/x86/AVX2.cpp +++ b/src/arch/x86/AVX2.cpp @@ -26,7 +26,7 @@ namespace hat::detail { } template - scan_result find_pattern_avx2(const scan_context& context) { + const_scan_result find_pattern_avx2(const scan_context& context) { auto [begin, end, signature, hints] = context; // 256 bit vector containing first signature byte repeated @@ -94,7 +94,7 @@ namespace hat::detail { } template - scan_result find_pattern_avx2(const scan_context& context) { + const_scan_result find_pattern_avx2(const scan_context& context) { auto& signature = context.signature; const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value(); const bool veccmp = signature.size() <= 33; @@ -111,12 +111,12 @@ namespace hat::detail { } template<> - scan_result find_pattern(const scan_context& context) { + const_scan_result find_pattern(const scan_context& context) { return find_pattern_avx2(context); } template<> - scan_result find_pattern(const scan_context& context) { + const_scan_result find_pattern(const scan_context& context) { return find_pattern_avx2(context); } } diff --git a/src/arch/x86/AVX512.cpp b/src/arch/x86/AVX512.cpp index 8d4d5a1..4224b6b 100644 --- a/src/arch/x86/AVX512.cpp +++ b/src/arch/x86/AVX512.cpp @@ -26,7 +26,7 @@ namespace hat::detail { } template - scan_result find_pattern_avx512(const scan_context& context) { + const_scan_result find_pattern_avx512(const scan_context& context) { auto [begin, end, signature, hints] = context; // 512 bit vector containing first signature byte repeated const auto firstByte = _mm512_set1_epi8(static_cast(*signature[0])); @@ -89,7 +89,7 @@ namespace hat::detail { } template - scan_result find_pattern_avx512(const scan_context& context) { + const_scan_result find_pattern_avx512(const scan_context& context) { auto& signature = context.signature; const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value(); const bool veccmp = signature.size() <= 65; @@ -106,12 +106,12 @@ namespace hat::detail { } template<> - scan_result find_pattern(const scan_context& context) { + const_scan_result find_pattern(const scan_context& context) { return find_pattern_avx512(context); } template<> - scan_result find_pattern(const scan_context& context) { + const_scan_result find_pattern(const scan_context& context) { return find_pattern_avx512(context); } } diff --git a/src/arch/x86/SSE.cpp b/src/arch/x86/SSE.cpp index 3eee222..cb10956 100644 --- a/src/arch/x86/SSE.cpp +++ b/src/arch/x86/SSE.cpp @@ -26,7 +26,7 @@ namespace hat::detail { } template - scan_result find_pattern_sse(const scan_context& context) { + const_scan_result find_pattern_sse(const scan_context& context) { auto [begin, end, signature, hints] = context; // 256 bit vector containing first signature byte repeated @@ -92,7 +92,7 @@ namespace hat::detail { } template - scan_result find_pattern_sse(const scan_context& context) { + const_scan_result find_pattern_sse(const scan_context& context) { auto& signature = context.signature; const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value(); const bool veccmp = signature.size() <= 17; @@ -109,12 +109,12 @@ namespace hat::detail { } template<> - scan_result find_pattern(const scan_context& context) { + const_scan_result find_pattern(const scan_context& context) { return find_pattern_sse(context); } template<> - scan_result find_pattern(const scan_context& context) { + const_scan_result find_pattern(const scan_context& context) { return find_pattern_sse(context); } }