diff --git a/src/Scanner.cpp b/src/Scanner.cpp index 7ccbdc9..4eff148 100644 --- a/src/Scanner.cpp +++ b/src/Scanner.cpp @@ -35,17 +35,16 @@ namespace hat::detail { template scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { - const auto size = signature.size(); #if defined(LIBHAT_X86) const auto& ext = get_system().extensions; if (ext.bmi1) { - if (size <= 65 && ext.avx512) { + if (ext.avx512) { return find_pattern(begin, end, signature); - } else if (size <= 33 && ext.avx2) { + } else if (ext.avx2) { return find_pattern(begin, end, signature); } } - if (size <= 17 && ext.sse41) { + if (ext.sse41) { return find_pattern(begin, end, signature); } #endif diff --git a/src/arch/x86/AVX2.cpp b/src/arch/x86/AVX2.cpp index 9f51963..bf1082e 100644 --- a/src/arch/x86/AVX2.cpp +++ b/src/arch/x86/AVX2.cpp @@ -25,7 +25,7 @@ namespace hat::detail { ); } - template + template scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, signature_view signature) { // 256 bit vector containing first signature byte repeated const auto firstByte = _mm256_set1_epi8(static_cast(*signature[0])); @@ -62,11 +62,20 @@ namespace hat::detail { while (mask) { const auto offset = _tzcnt_u32(mask); const auto i = reinterpret_cast(vec) + offset; - const auto data = _mm256_loadu_si256(reinterpret_cast(i + 1)); - const auto cmpToSig = _mm256_cmpeq_epi8(signatureBytes, data); - const auto matched = _mm256_testc_si256(cmpToSig, signatureMask); - if (matched) { - return i; + if constexpr (veccmp) { + const auto data = _mm256_loadu_si256(reinterpret_cast(i + 1)); + const auto cmpToSig = _mm256_cmpeq_epi8(signatureBytes, data); + const auto matched = _mm256_testc_si256(cmpToSig, signatureMask); + if (matched) { + return i; + } + } else { + auto match = std::equal(signature.begin() + 1, signature.end(), i + 1, [](auto opt, auto byte) { + return !opt.has_value() || *opt == byte; + }); + if (match) { + return i; + } } mask = _blsr_u32(mask); } @@ -79,10 +88,17 @@ namespace hat::detail { template scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, signature_view signature) { - if (signature.size() > 1 && signature[1].has_value()) { - return find_pattern_avx2(begin, end, signature); + const bool cmpeq2 = signature.size() > 1 && signature[1].has_value(); + const bool veccmp = signature.size() <= 33; + + if (cmpeq2 && veccmp) { + return find_pattern_avx2(begin, end, signature); + } else if (cmpeq2) { + return find_pattern_avx2(begin, end, signature); + } else if (veccmp) { + return find_pattern_avx2(begin, end, signature); } else { - return find_pattern_avx2(begin, end, signature); + return find_pattern_avx2(begin, end, signature); } } diff --git a/src/arch/x86/AVX512.cpp b/src/arch/x86/AVX512.cpp index ea7092a..c253043 100644 --- a/src/arch/x86/AVX512.cpp +++ b/src/arch/x86/AVX512.cpp @@ -25,7 +25,7 @@ namespace hat::detail { ); } - template + template scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, signature_view signature) { // 512 bit vector containing first signature byte repeated const auto firstByte = _mm512_set1_epi8(static_cast(*signature[0])); @@ -58,10 +58,19 @@ namespace hat::detail { while (mask) { const auto offset = LIBHAT_TZCNT64(mask); const auto i = reinterpret_cast(vec) + offset; - const auto data = _mm512_loadu_si512(i + 1); - const auto invalid = _mm512_mask_cmpneq_epi8_mask(signatureMask, signatureBytes, data); - if (!invalid) { - return i; + if constexpr (veccmp) { + const auto data = _mm512_loadu_si512(i + 1); + const auto invalid = _mm512_mask_cmpneq_epi8_mask(signatureMask, signatureBytes, data); + if (!invalid) { + return i; + } + } else { + auto match = std::equal(signature.begin() + 1, signature.end(), i + 1, [](auto opt, auto byte) { + return !opt.has_value() || *opt == byte; + }); + if (match) { + return i; + } } mask = LIBHAT_BLSR64(mask); } @@ -74,10 +83,17 @@ namespace hat::detail { template scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, signature_view signature) { - if (signature.size() > 1 && signature[1].has_value()) { - return find_pattern_avx512(begin, end, signature); + const bool cmpeq2 = signature.size() > 1 && signature[1].has_value(); + const bool veccmp = signature.size() <= 65; + + if (cmpeq2 && veccmp) { + return find_pattern_avx512(begin, end, signature); + } else if (cmpeq2) { + return find_pattern_avx512(begin, end, signature); + } else if (veccmp) { + return find_pattern_avx512(begin, end, signature); } else { - return find_pattern_avx512(begin, end, signature); + return find_pattern_avx512(begin, end, signature); } } diff --git a/src/arch/x86/SSE.cpp b/src/arch/x86/SSE.cpp index ba63e79..060c0f1 100644 --- a/src/arch/x86/SSE.cpp +++ b/src/arch/x86/SSE.cpp @@ -25,7 +25,7 @@ namespace hat::detail { ); } - template + template scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, signature_view signature) { // 256 bit vector containing first signature byte repeated const auto firstByte = _mm_set1_epi8(static_cast(*signature[0])); @@ -60,11 +60,20 @@ namespace hat::detail { while (mask) { const auto offset = LIBHAT_BSF32(mask); const auto i = reinterpret_cast(vec) + offset; - const auto data = _mm_loadu_si128(reinterpret_cast(i + 1)); - const auto cmpToSig = _mm_cmpeq_epi8(signatureBytes, data); - const auto matched = _mm_testc_si128(cmpToSig, signatureMask); - if (matched) { - return i; + if constexpr (veccmp) { + const auto data = _mm_loadu_si128(reinterpret_cast(i + 1)); + const auto cmpToSig = _mm_cmpeq_epi8(signatureBytes, data); + const auto matched = _mm_testc_si128(cmpToSig, signatureMask); + if (matched) { + return i; + } + } else { + auto match = std::equal(signature.begin() + 1, signature.end(), i + 1, [](auto opt, auto byte) { + return !opt.has_value() || *opt == byte; + }); + if (match) { + return i; + } } mask &= (mask - 1); } @@ -77,10 +86,17 @@ namespace hat::detail { template scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, signature_view signature) { - if (signature.size() > 1 && signature[1].has_value()) { - return find_pattern_sse(begin, end, signature); + const bool cmpeq2 = signature.size() > 1 && signature[1].has_value(); + const bool veccmp = signature.size() <= 17; + + if (cmpeq2 && veccmp) { + return find_pattern_sse(begin, end, signature); + } else if (cmpeq2) { + return find_pattern_sse(begin, end, signature); + } else if (veccmp) { + return find_pattern_sse(begin, end, signature); } else { - return find_pattern_sse(begin, end, signature); + return find_pattern_sse(begin, end, signature); } }