Skip to content

Commit

Permalink
Improve find_pattern performance for sigs that exceed register size
Browse files Browse the repository at this point in the history
  • Loading branch information
ZeroMemes committed Oct 18, 2023
1 parent 70f32a8 commit 9af066b
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 30 deletions.
7 changes: 3 additions & 4 deletions src/Scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,16 @@ namespace hat::detail {

template<scan_alignment alignment>
scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) {
const auto size = signature.size();
#if defined(LIBHAT_X86)
const auto& ext = get_system().extensions;
if (ext.bmi1) {
if (size <= 65 && ext.avx512) {
if (ext.avx512) {
return find_pattern<scan_mode::AVX512, alignment>(begin, end, signature);
} else if (size <= 33 && ext.avx2) {
} else if (ext.avx2) {
return find_pattern<scan_mode::AVX2, alignment>(begin, end, signature);
}
}
if (size <= 17 && ext.sse41) {
if (ext.sse41) {
return find_pattern<scan_mode::SSE, alignment>(begin, end, signature);
}
#endif
Expand Down
34 changes: 25 additions & 9 deletions src/arch/x86/AVX2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace hat::detail {
);
}

template<scan_alignment alignment, bool cmpeq2>
template<scan_alignment alignment, bool cmpeq2, bool veccmp>
scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, signature_view signature) {
// 256 bit vector containing first signature byte repeated
const auto firstByte = _mm256_set1_epi8(static_cast<int8_t>(*signature[0]));
Expand Down Expand Up @@ -62,11 +62,20 @@ namespace hat::detail {
while (mask) {
const auto offset = _tzcnt_u32(mask);
const auto i = reinterpret_cast<const std::byte*>(vec) + offset;
const auto data = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(i + 1));
const auto cmpToSig = _mm256_cmpeq_epi8(signatureBytes, data);
const auto matched = _mm256_testc_si256(cmpToSig, signatureMask);
if (matched) {
return i;
if constexpr (veccmp) {
const auto data = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(i + 1));
const auto cmpToSig = _mm256_cmpeq_epi8(signatureBytes, data);
const auto matched = _mm256_testc_si256(cmpToSig, signatureMask);
if (matched) {
return i;
}
} else {
auto match = std::equal(signature.begin() + 1, signature.end(), i + 1, [](auto opt, auto byte) {
return !opt.has_value() || *opt == byte;
});
if (match) {
return i;
}
}
mask = _blsr_u32(mask);
}
Expand All @@ -79,10 +88,17 @@ namespace hat::detail {

template<scan_alignment alignment>
scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, signature_view signature) {
if (signature.size() > 1 && signature[1].has_value()) {
return find_pattern_avx2<alignment, true>(begin, end, signature);
const bool cmpeq2 = signature.size() > 1 && signature[1].has_value();
const bool veccmp = signature.size() <= 33;

if (cmpeq2 && veccmp) {
return find_pattern_avx2<alignment, true, true>(begin, end, signature);
} else if (cmpeq2) {
return find_pattern_avx2<alignment, true, false>(begin, end, signature);
} else if (veccmp) {
return find_pattern_avx2<alignment, false, true>(begin, end, signature);
} else {
return find_pattern_avx2<alignment, false>(begin, end, signature);
return find_pattern_avx2<alignment, false, false>(begin, end, signature);
}
}

Expand Down
32 changes: 24 additions & 8 deletions src/arch/x86/AVX512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace hat::detail {
);
}

template<scan_alignment alignment, bool cmpeq2>
template<scan_alignment alignment, bool cmpeq2, bool veccmp>
scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, signature_view signature) {
// 512 bit vector containing first signature byte repeated
const auto firstByte = _mm512_set1_epi8(static_cast<int8_t>(*signature[0]));
Expand Down Expand Up @@ -58,10 +58,19 @@ namespace hat::detail {
while (mask) {
const auto offset = LIBHAT_TZCNT64(mask);
const auto i = reinterpret_cast<const std::byte*>(vec) + offset;
const auto data = _mm512_loadu_si512(i + 1);
const auto invalid = _mm512_mask_cmpneq_epi8_mask(signatureMask, signatureBytes, data);
if (!invalid) {
return i;
if constexpr (veccmp) {
const auto data = _mm512_loadu_si512(i + 1);
const auto invalid = _mm512_mask_cmpneq_epi8_mask(signatureMask, signatureBytes, data);
if (!invalid) {
return i;
}
} else {
auto match = std::equal(signature.begin() + 1, signature.end(), i + 1, [](auto opt, auto byte) {
return !opt.has_value() || *opt == byte;
});
if (match) {
return i;
}
}
mask = LIBHAT_BLSR64(mask);
}
Expand All @@ -74,10 +83,17 @@ namespace hat::detail {

template<scan_alignment alignment>
scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, signature_view signature) {
if (signature.size() > 1 && signature[1].has_value()) {
return find_pattern_avx512<alignment, true>(begin, end, signature);
const bool cmpeq2 = signature.size() > 1 && signature[1].has_value();
const bool veccmp = signature.size() <= 65;

if (cmpeq2 && veccmp) {
return find_pattern_avx512<alignment, true, true>(begin, end, signature);
} else if (cmpeq2) {
return find_pattern_avx512<alignment, true, false>(begin, end, signature);
} else if (veccmp) {
return find_pattern_avx512<alignment, false, true>(begin, end, signature);
} else {
return find_pattern_avx512<alignment, false>(begin, end, signature);
return find_pattern_avx512<alignment, false, false>(begin, end, signature);
}
}

Expand Down
34 changes: 25 additions & 9 deletions src/arch/x86/SSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace hat::detail {
);
}

template<scan_alignment alignment, bool cmpeq2>
template<scan_alignment alignment, bool cmpeq2, bool veccmp>
scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, signature_view signature) {
// 256 bit vector containing first signature byte repeated
const auto firstByte = _mm_set1_epi8(static_cast<int8_t>(*signature[0]));
Expand Down Expand Up @@ -60,11 +60,20 @@ namespace hat::detail {
while (mask) {
const auto offset = LIBHAT_BSF32(mask);
const auto i = reinterpret_cast<const std::byte*>(vec) + offset;
const auto data = _mm_loadu_si128(reinterpret_cast<const __m128i*>(i + 1));
const auto cmpToSig = _mm_cmpeq_epi8(signatureBytes, data);
const auto matched = _mm_testc_si128(cmpToSig, signatureMask);
if (matched) {
return i;
if constexpr (veccmp) {
const auto data = _mm_loadu_si128(reinterpret_cast<const __m128i*>(i + 1));
const auto cmpToSig = _mm_cmpeq_epi8(signatureBytes, data);
const auto matched = _mm_testc_si128(cmpToSig, signatureMask);
if (matched) {
return i;
}
} else {
auto match = std::equal(signature.begin() + 1, signature.end(), i + 1, [](auto opt, auto byte) {
return !opt.has_value() || *opt == byte;
});
if (match) {
return i;
}
}
mask &= (mask - 1);
}
Expand All @@ -77,10 +86,17 @@ namespace hat::detail {

template<scan_alignment alignment>
scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, signature_view signature) {
if (signature.size() > 1 && signature[1].has_value()) {
return find_pattern_sse<alignment, true>(begin, end, signature);
const bool cmpeq2 = signature.size() > 1 && signature[1].has_value();
const bool veccmp = signature.size() <= 17;

if (cmpeq2 && veccmp) {
return find_pattern_sse<alignment, true, true>(begin, end, signature);
} else if (cmpeq2) {
return find_pattern_sse<alignment, true, false>(begin, end, signature);
} else if (veccmp) {
return find_pattern_sse<alignment, false, true>(begin, end, signature);
} else {
return find_pattern_sse<alignment, false>(begin, end, signature);
return find_pattern_sse<alignment, false, false>(begin, end, signature);
}
}

Expand Down

0 comments on commit 9af066b

Please sign in to comment.