Skip to content

Commit

Permalink
Improve find_pattern performance for sigs starting with 2 bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
ZeroMemes committed Oct 18, 2023
1 parent 63127ef commit 70f32a8
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 3 deletions.
26 changes: 25 additions & 1 deletion src/arch/x86/AVX2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,16 @@ namespace hat::detail {
);
}

template<scan_alignment alignment>
template<scan_alignment alignment, bool cmpeq2>
scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, signature_view signature) {
// 256 bit vector containing first signature byte repeated
const auto firstByte = _mm256_set1_epi8(static_cast<int8_t>(*signature[0]));

__m256i secondByte;
if constexpr (cmpeq2) {
secondByte = _mm256_set1_epi8(static_cast<int8_t>(*signature[1]));
}

alignas(__m256i) const auto [signatureBytes, signatureMask] = load_signature_256(signature);

begin = next_boundary_align<alignment>(begin);
Expand All @@ -43,6 +49,15 @@ namespace hat::detail {
for (; vec != e; vec++) {
const auto cmp = _mm256_cmpeq_epi8(firstByte, _mm256_loadu_si256(vec));
auto mask = static_cast<uint32_t>(_mm256_movemask_epi8(cmp));

if constexpr (cmpeq2) {
const auto cmp2 = _mm256_cmpeq_epi8(secondByte, _mm256_loadu_si256(vec));
auto mask2 = static_cast<uint32_t>(_mm256_movemask_epi8(cmp2));
// avoid loading unaligned memory by letting a match of the first signature byte in the last
// position imply that the second byte also matched
mask &= (mask2 >> 1) | (0b1u << 31);
}

mask &= create_alignment_mask<uint32_t, alignment>();
while (mask) {
const auto offset = _tzcnt_u32(mask);
Expand All @@ -62,6 +77,15 @@ namespace hat::detail {
return find_pattern<scan_mode::Single, alignment>(begin, end, signature);
}

// Runtime dispatcher for the AVX2 scanner.
//
// Selects the compile-time `cmpeq2` variant of find_pattern_avx2: the `true`
// variant is used only when the signature has more than one element and its
// second element holds a value; in every other case the `false` variant runs.
// All three arguments are forwarded unchanged to the selected instantiation.
template<scan_alignment alignment>
scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, signature_view signature) {
    // Check the length first so signature[1] is never evaluated for a single-element signature
    const bool hasSecondByte = signature.size() >= 2 && signature[1].has_value();
    if (!hasSecondByte) {
        return find_pattern_avx2<alignment, false>(begin, end, signature);
    }
    return find_pattern_avx2<alignment, true>(begin, end, signature);
}

template<>
scan_result find_pattern<scan_mode::AVX2, scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature) {
return find_pattern_avx2<scan_alignment::X1>(begin, end, signature);
Expand Down
23 changes: 22 additions & 1 deletion src/arch/x86/AVX512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,16 @@ namespace hat::detail {
);
}

template<scan_alignment alignment>
template<scan_alignment alignment, bool cmpeq2>
scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, signature_view signature) {
// 512 bit vector containing first signature byte repeated
const auto firstByte = _mm512_set1_epi8(static_cast<int8_t>(*signature[0]));

__m512i secondByte;
if constexpr (cmpeq2) {
secondByte = _mm512_set1_epi8(static_cast<int8_t>(*signature[1]));
}

alignas(__m512i) const auto [signatureBytes, signatureMask] = load_signature_512(signature);

begin = next_boundary_align<alignment>(begin);
Expand All @@ -42,6 +48,12 @@ namespace hat::detail {

for (; vec != e; vec++) {
auto mask = _mm512_cmpeq_epi8_mask(firstByte, _mm512_loadu_si512(vec));

if constexpr (cmpeq2) {
const auto mask2 = _mm512_cmpeq_epi8_mask(secondByte, _mm512_loadu_si512(vec));
mask &= (mask2 >> 1) | (0b1ull << 63);
}

mask &= create_alignment_mask<uint64_t, alignment>();
while (mask) {
const auto offset = LIBHAT_TZCNT64(mask);
Expand All @@ -60,6 +72,15 @@ namespace hat::detail {
return find_pattern<scan_mode::Single, scan_alignment::X1>(begin, end, signature);
}

// Runtime dispatcher for the AVX512 scanner.
//
// Selects the compile-time `cmpeq2` variant of find_pattern_avx512: the `true`
// variant is used only when the signature has more than one element and its
// second element holds a value; in every other case the `false` variant runs.
// All three arguments are forwarded unchanged to the selected instantiation.
template<scan_alignment alignment>
scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, signature_view signature) {
    // Check the length first so signature[1] is never evaluated for a single-element signature
    const bool hasSecondByte = signature.size() >= 2 && signature[1].has_value();
    if (!hasSecondByte) {
        return find_pattern_avx512<alignment, false>(begin, end, signature);
    }
    return find_pattern_avx512<alignment, true>(begin, end, signature);
}

template<>
scan_result find_pattern<scan_mode::AVX512, scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature) {
return find_pattern_avx512<scan_alignment::X1>(begin, end, signature);
Expand Down
24 changes: 23 additions & 1 deletion src/arch/x86/SSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,16 @@ namespace hat::detail {
);
}

template<scan_alignment alignment>
template<scan_alignment alignment, bool cmpeq2>
scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, signature_view signature) {
// 128 bit vector containing first signature byte repeated
const auto firstByte = _mm_set1_epi8(static_cast<int8_t>(*signature[0]));

__m128i secondByte;
if constexpr (cmpeq2) {
secondByte = _mm_set1_epi8(static_cast<int8_t>(*signature[1]));
}

alignas(__m128i) const auto [signatureBytes, signatureMask] = load_signature_128(signature);

begin = next_boundary_align<alignment>(begin);
Expand All @@ -43,6 +49,13 @@ namespace hat::detail {
for (; vec != e; vec++) {
const auto cmp = _mm_cmpeq_epi8(firstByte, _mm_loadu_si128(vec));
auto mask = static_cast<uint16_t>(_mm_movemask_epi8(cmp));

if constexpr (cmpeq2) {
const auto cmp2 = _mm_cmpeq_epi8(secondByte, _mm_loadu_si128(vec));
auto mask2 = static_cast<uint16_t>(_mm_movemask_epi8(cmp2));
mask &= (mask2 >> 1) | (0b1u << 15);
}

mask &= create_alignment_mask<uint16_t, alignment>();
while (mask) {
const auto offset = LIBHAT_BSF32(mask);
Expand All @@ -62,6 +75,15 @@ namespace hat::detail {
return find_pattern<scan_mode::Single, alignment>(begin, end, signature);
}

// Runtime dispatcher for the SSE scanner.
//
// Selects the compile-time `cmpeq2` variant of find_pattern_sse: the `true`
// variant is used only when the signature has more than one element and its
// second element holds a value; in every other case the `false` variant runs.
// All three arguments are forwarded unchanged to the selected instantiation.
template<scan_alignment alignment>
scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, signature_view signature) {
    // Check the length first so signature[1] is never evaluated for a single-element signature
    const bool hasSecondByte = signature.size() >= 2 && signature[1].has_value();
    if (!hasSecondByte) {
        return find_pattern_sse<alignment, false>(begin, end, signature);
    }
    return find_pattern_sse<alignment, true>(begin, end, signature);
}

template<>
scan_result find_pattern<scan_mode::SSE, scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature) {
return find_pattern_sse<scan_alignment::X1>(begin, end, signature);
Expand Down

0 comments on commit 70f32a8

Please sign in to comment.