Skip to content

Commit

Permalink
16-byte alignment scanning support
Browse files Browse the repository at this point in the history
  • Loading branch information
ZeroMemes committed Aug 21, 2023
1 parent 5752d6b commit 7aae038
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 14 deletions.
60 changes: 58 additions & 2 deletions include/libhat/Scanner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ namespace hat {
};

enum class scan_alignment {
X1
X1 = 1,
X16 = 16
};

namespace detail {
Expand All @@ -67,14 +68,45 @@ namespace hat {
Single = FastFirst
};

template<scan_alignment alignment>
inline constexpr auto alignment_stride = static_cast<std::underlying_type_t<scan_alignment>>(alignment);

template<std::integral type, scan_alignment alignment, auto stride = alignment_stride<alignment>>
inline consteval auto create_alignment_mask() {
type mask{};
for (size_t i = 0; i < sizeof(type) * 8; i += stride) {
mask |= (type(1) << i);
}
return mask;
}

template<scan_alignment alignment, auto stride = alignment_stride<alignment>>
inline const std::byte* next_boundary_align(const std::byte* ptr) {
if constexpr (stride == 1) {
return ptr;
}
uintptr_t mod = reinterpret_cast<uintptr_t>(ptr) % stride;
ptr += mod ? stride - mod : 0;
return ptr;
}

template<scan_alignment alignment, auto stride = alignment_stride<alignment>>
inline const std::byte* prev_boundary_align(const std::byte* ptr) {
if constexpr (stride == 1) {
return ptr;
}
uintptr_t mod = reinterpret_cast<uintptr_t>(ptr) % stride;
return ptr - mod;
}

template<scan_mode, scan_alignment>
scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature);

template<scan_alignment alignment>
scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature);

template<>
constexpr scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature) {
inline constexpr scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature) {
const auto firstByte = *signature[0];
const auto scanEnd = end - signature.size() + 1;

Expand All @@ -98,6 +130,30 @@ namespace hat {
}
return nullptr;
}

template<>
inline scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X16>(const std::byte* begin, const std::byte* end, signature_view signature) {
const auto firstByte = *signature[0];

const auto scanBegin = next_boundary_align<scan_alignment::X16>(begin);
const auto scanEnd = prev_boundary_align<scan_alignment::X16>(end - signature.size() + 1);
if (scanBegin >= scanEnd) {
return {};
}

for (auto i = scanBegin; i != scanEnd; i += 16) {
if (*i == firstByte) {
// Compare everything after the first byte
auto match = std::equal(signature.begin() + 1, signature.end(), i + 1, [](auto opt, auto byte) {
return !opt.has_value() || *opt == byte;
});
if (match) {
return i;
}
}
}
return nullptr;
}
}

enum class compiler_type {
Expand Down
9 changes: 6 additions & 3 deletions src/Scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ namespace hat {
return find_pattern<alignment>(data.begin(), data.end(), signature);
}

template scan_result find_pattern<scan_alignment::X1>(signature_view signature, module_t mod);
template scan_result find_pattern<scan_alignment::X1>(signature_view signature, std::string_view section, module_t mod);
template scan_result find_pattern<scan_alignment::X1>(signature_view, module_t);
template scan_result find_pattern<scan_alignment::X1>(signature_view, std::string_view, module_t);
template scan_result find_pattern<scan_alignment::X16>(signature_view, module_t);
template scan_result find_pattern<scan_alignment::X16>(signature_view, std::string_view, module_t);
}

namespace hat::detail {
Expand All @@ -51,5 +53,6 @@ namespace hat::detail {
return find_pattern<scan_mode::Single, alignment>(begin, end, signature);
}

template scan_result find_pattern<scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature);
template scan_result find_pattern<scan_alignment::X1>(const std::byte*, const std::byte*, signature_view);
template scan_result find_pattern<scan_alignment::X16>(const std::byte*, const std::byte*, signature_view);
}
22 changes: 19 additions & 3 deletions src/arch/x86/AVX2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,25 @@ namespace hat::detail {
);
}

template<>
scan_result find_pattern<scan_mode::AVX2, scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature) {
template<scan_alignment alignment>
scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, signature_view signature) {
// 256 bit vector containing first signature byte repeated
const auto firstByte = _mm256_set1_epi8(static_cast<int8_t>(*signature[0]));
const auto [signatureBytes, signatureMask] = load_signature_256(signature);

begin = next_boundary_align<alignment>(begin);
if (begin >= end) {
return {};
}

auto vec = reinterpret_cast<const __m256i*>(begin);
const auto n = static_cast<size_t>(end - signature.size() - begin) / sizeof(__m256i);
const auto e = vec + n;

for (; vec != e; vec++) {
const auto cmp = _mm256_cmpeq_epi8(firstByte, _mm256_loadu_si256(vec));
auto mask = static_cast<uint32_t>(_mm256_movemask_epi8(cmp));
mask &= create_alignment_mask<uint32_t, alignment>();
while (mask) {
const auto offset = _tzcnt_u32(mask);
const auto i = reinterpret_cast<const std::byte*>(vec) + offset;
Expand All @@ -53,7 +59,17 @@ namespace hat::detail {

// Look in remaining bytes that couldn't be grouped into 256 bits
begin = reinterpret_cast<const std::byte*>(vec);
return find_pattern<scan_mode::Single, scan_alignment::X1>(begin, end, signature);
return find_pattern<scan_mode::Single, alignment>(begin, end, signature);
}

template<>
scan_result find_pattern<scan_mode::AVX2, scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature) {
return find_pattern_avx2<scan_alignment::X1>(begin, end, signature);
}

template<>
scan_result find_pattern<scan_mode::AVX2, scan_alignment::X16>(const std::byte* begin, const std::byte* end, signature_view signature) {
return find_pattern_avx2<scan_alignment::X16>(begin, end, signature);
}
}
#endif
20 changes: 18 additions & 2 deletions src/arch/x86/AVX512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,24 @@ namespace hat::detail {
);
}

template<>
scan_result find_pattern<scan_mode::AVX512, scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature) {
template<scan_alignment alignment>
scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, signature_view signature) {
// 512 bit vector containing first signature byte repeated
const auto firstByte = _mm512_set1_epi8(static_cast<int8_t>(*signature[0]));
const auto [signatureBytes, signatureMask] = load_signature_512(signature);

begin = next_boundary_align<alignment>(begin);
if (begin >= end) {
return {};
}

auto vec = reinterpret_cast<const __m512i*>(begin);
const auto n = static_cast<size_t>(end - signature.size() - begin) / sizeof(__m512i);
const auto e = vec + n;

for (; vec != e; vec++) {
auto mask = _mm512_cmpeq_epi8_mask(firstByte, _mm512_loadu_si512(vec));
mask &= create_alignment_mask<uint64_t, alignment>();
while (mask) {
const auto offset = LIBHAT_TZCNT64(mask);
const auto i = reinterpret_cast<const std::byte*>(vec) + offset;
Expand All @@ -53,5 +59,15 @@ namespace hat::detail {
begin = reinterpret_cast<const std::byte*>(vec);
return find_pattern<scan_mode::Single, scan_alignment::X1>(begin, end, signature);
}

template<>
scan_result find_pattern<scan_mode::AVX512, scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature) {
return find_pattern_avx512<scan_alignment::X1>(begin, end, signature);
}

template<>
scan_result find_pattern<scan_mode::AVX512, scan_alignment::X16>(const std::byte* begin, const std::byte* end, signature_view signature) {
return find_pattern_avx512<scan_alignment::X16>(begin, end, signature);
}
}
#endif
24 changes: 20 additions & 4 deletions src/arch/x86/SSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,25 @@ namespace hat::detail {
);
}

template<>
scan_result find_pattern<scan_mode::SSE, scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature) {
template<scan_alignment alignment>
scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, signature_view signature) {
// 256 bit vector containing first signature byte repeated
const auto firstByte = _mm_set1_epi8(static_cast<int8_t>(*signature[0]));
const auto [signatureBytes, signatureMask] = load_signature_128(signature);

begin = next_boundary_align<alignment>(begin);
if (begin >= end) {
return {};
}

auto vec = reinterpret_cast<const __m128i*>(begin);
const auto n = static_cast<size_t>(end - signature.size() - begin) / sizeof(__m128i);
const auto e = vec + n;

for (; vec != e; vec++) {
const auto cmp = _mm_cmpeq_epi8(firstByte, _mm_loadu_si128(vec));
auto mask = static_cast<uint32_t>(_mm_movemask_epi8(cmp));
auto mask = static_cast<uint16_t>(_mm_movemask_epi8(cmp));
mask &= create_alignment_mask<uint16_t, alignment>();
while (mask) {
const auto offset = LIBHAT_BSF32(mask);
const auto i = reinterpret_cast<const std::byte*>(vec) + offset;
Expand All @@ -53,7 +59,17 @@ namespace hat::detail {

// Look in remaining bytes that couldn't be grouped into 128 bits
begin = reinterpret_cast<const std::byte*>(vec);
return find_pattern<scan_mode::Single, scan_alignment::X1>(begin, end, signature);
return find_pattern<scan_mode::Single, alignment>(begin, end, signature);
}

template<>
scan_result find_pattern<scan_mode::SSE, scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature) {
return find_pattern_sse<scan_alignment::X1>(begin, end, signature);
}

template<>
scan_result find_pattern<scan_mode::SSE, scan_alignment::X16>(const std::byte* begin, const std::byte* end, signature_view signature) {
return find_pattern_sse<scan_alignment::X16>(begin, end, signature);
}
}
#endif

0 comments on commit 7aae038

Please sign in to comment.