diff --git a/src/arch/x86/AVX2.cpp b/src/arch/x86/AVX2.cpp index 43baba7..9f51963 100644 --- a/src/arch/x86/AVX2.cpp +++ b/src/arch/x86/AVX2.cpp @@ -25,10 +25,16 @@ namespace hat::detail { ); } - template + template scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, signature_view signature) { // 256 bit vector containing first signature byte repeated const auto firstByte = _mm256_set1_epi8(static_cast(*signature[0])); + + __m256i secondByte; + if constexpr (cmpeq2) { + secondByte = _mm256_set1_epi8(static_cast(*signature[1])); + } + alignas(__m256i) const auto [signatureBytes, signatureMask] = load_signature_256(signature); begin = next_boundary_align(begin); @@ -43,6 +49,15 @@ namespace hat::detail { for (; vec != e; vec++) { const auto cmp = _mm256_cmpeq_epi8(firstByte, _mm256_loadu_si256(vec)); auto mask = static_cast(_mm256_movemask_epi8(cmp)); + + if constexpr (cmpeq2) { + const auto cmp2 = _mm256_cmpeq_epi8(secondByte, _mm256_loadu_si256(vec)); + auto mask2 = static_cast(_mm256_movemask_epi8(cmp2)); + // avoid loading unaligned memory by letting a match of the first signature byte in the last + // position imply that the second byte also matched + mask &= (mask2 >> 1) | (0b1u << 31); + } + mask &= create_alignment_mask(); while (mask) { const auto offset = _tzcnt_u32(mask); @@ -62,6 +77,15 @@ namespace hat::detail { return find_pattern(begin, end, signature); } + template + scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, signature_view signature) { + if (signature.size() > 1 && signature[1].has_value()) { + return find_pattern_avx2(begin, end, signature); + } else { + return find_pattern_avx2(begin, end, signature); + } + } + template<> scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { return find_pattern_avx2(begin, end, signature); diff --git a/src/arch/x86/AVX512.cpp b/src/arch/x86/AVX512.cpp index 7fefd92..ea7092a 100644 --- a/src/arch/x86/AVX512.cpp +++ b/src/arch/x86/AVX512.cpp @@ -25,10 +25,16 @@ namespace hat::detail { ); } - template + template scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, signature_view signature) { // 512 bit vector containing first signature byte repeated const auto firstByte = _mm512_set1_epi8(static_cast(*signature[0])); + + __m512i secondByte; + if constexpr (cmpeq2) { + secondByte = _mm512_set1_epi8(static_cast(*signature[1])); + } + alignas(__m512i) const auto [signatureBytes, signatureMask] = load_signature_512(signature); begin = next_boundary_align(begin); @@ -42,6 +48,12 @@ namespace hat::detail { for (; vec != e; vec++) { auto mask = _mm512_cmpeq_epi8_mask(firstByte, _mm512_loadu_si512(vec)); + + if constexpr (cmpeq2) { + const auto mask2 = _mm512_cmpeq_epi8_mask(secondByte, _mm512_loadu_si512(vec)); + mask &= (mask2 >> 1) | (0b1ull << 63); + } + mask &= create_alignment_mask(); while (mask) { const auto offset = LIBHAT_TZCNT64(mask); @@ -60,6 +72,15 @@ namespace hat::detail { return find_pattern(begin, end, signature); } + template + scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, signature_view signature) { + if (signature.size() > 1 && signature[1].has_value()) { + return find_pattern_avx512(begin, end, signature); + } else { + return find_pattern_avx512(begin, end, signature); + } + } + template<> scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { return find_pattern_avx512(begin, end, signature); diff --git a/src/arch/x86/SSE.cpp b/src/arch/x86/SSE.cpp index 230bc96..ba63e79 100644 --- a/src/arch/x86/SSE.cpp +++ b/src/arch/x86/SSE.cpp @@ -25,10 +25,16 @@ namespace hat::detail { ); } - template + template scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, signature_view signature) { // 256 bit vector containing first signature byte repeated const auto firstByte = _mm_set1_epi8(static_cast(*signature[0])); + + __m128i secondByte; + if constexpr (cmpeq2) { + secondByte = _mm_set1_epi8(static_cast(*signature[1])); + } + alignas(__m128i) const auto [signatureBytes, signatureMask] = load_signature_128(signature); begin = next_boundary_align(begin); @@ -43,6 +49,13 @@ namespace hat::detail { for (; vec != e; vec++) { const auto cmp = _mm_cmpeq_epi8(firstByte, _mm_loadu_si128(vec)); auto mask = static_cast(_mm_movemask_epi8(cmp)); + + if constexpr (cmpeq2) { + const auto cmp2 = _mm_cmpeq_epi8(secondByte, _mm_loadu_si128(vec)); + auto mask2 = static_cast(_mm_movemask_epi8(cmp2)); + mask &= (mask2 >> 1) | (0b1u << 15); + } + mask &= create_alignment_mask(); while (mask) { const auto offset = LIBHAT_BSF32(mask); @@ -62,6 +75,15 @@ namespace hat::detail { return find_pattern(begin, end, signature); } + template + scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, signature_view signature) { + if (signature.size() > 1 && signature[1].has_value()) { + return find_pattern_sse(begin, end, signature); + } else { + return find_pattern_sse(begin, end, signature); + } + } + template<> scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) { return find_pattern_sse(begin, end, signature);