Skip to content

Commit

Permalink
Hints soon TM
Browse files Browse the repository at this point in the history
  • Loading branch information
ZeroMemes committed Feb 13, 2024
1 parent 82de2e6 commit 57efa1b
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 75 deletions.
41 changes: 34 additions & 7 deletions include/libhat/Scanner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,30 @@ namespace hat {
X16 = 16
};

/// Optional hints describing the data handed to a scan. The enumerators are
/// bit flags: combine them with operator| and intersect them with operator&.
enum class scan_hint : uint64_t {
    none = 0, // no hints
    // The flag is produced by shifting a uint64_t rather than an int, so hints
    // added later at bit 31 and above remain well-defined in the 64-bit
    // underlying type instead of overflowing int.
    x86_64 = uint64_t{1} << 0, // The data being scanned is x86_64 machine code
};

/// Returns the union of two hint sets (bitwise OR of the underlying values).
constexpr scan_hint operator|(scan_hint lhs, scan_hint rhs) {
    using U = std::underlying_type_t<scan_hint>;
    return static_cast<scan_hint>(static_cast<U>(lhs) | static_cast<U>(rhs));
}

/// Returns the hints present in both operands (bitwise AND of the underlying
/// values). The result equals scan_hint::none when they share no flag.
constexpr scan_hint operator&(scan_hint lhs, scan_hint rhs) {
    using U = std::underlying_type_t<scan_hint>;
    return static_cast<scan_hint>(static_cast<U>(lhs) & static_cast<U>(rhs));
}

namespace detail {

// Bundles the inputs of one scan so they can be passed around as a single argument.
// Member order is significant: the scanners unpack this struct positionally with
// structured bindings and build it with aggregate initialization, both in the
// order (begin, end, signature, hints).
struct scan_context {
// First byte of the range to scan
const std::byte* begin{};
// End of the range to scan
const std::byte* end{};
// Pattern to search for within the range
signature_view signature{};
// Hint flags for the scan; value-initialized, i.e. scan_hint::none
scan_hint hints{};
};

enum class scan_mode {
FastFirst, // std::find + std::equal
SSE, // x86 SSE 4.1
Expand Down Expand Up @@ -100,13 +122,14 @@ namespace hat {
}

template<scan_mode, scan_alignment>
scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature);
scan_result find_pattern(const scan_context&);

template<scan_alignment alignment>
scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature);
scan_result find_pattern(const scan_context&);

template<>
inline constexpr scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature) {
inline constexpr scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X1>(const scan_context& context) {
auto [begin, end, signature, _] = context;
const auto firstByte = *signature[0];
const auto scanEnd = end - signature.size() + 1;

Expand All @@ -132,7 +155,8 @@ namespace hat {
}

template<>
inline scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X16>(const std::byte* begin, const std::byte* end, signature_view signature) {
inline scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X16>(const scan_context& context) {
auto [begin, end, signature, _] = context;
const auto firstByte = *signature[0];

const auto scanBegin = next_boundary_align<scan_alignment::X16>(begin);
Expand Down Expand Up @@ -176,7 +200,8 @@ namespace hat {
constexpr scan_result find_pattern(
Iter beginIt,
Iter endIt,
signature_view signature
signature_view signature,
scan_hint hints = scan_hint::none
) {
// Truncate the leading wildcards from the signature
size_t offset = 0;
Expand All @@ -196,9 +221,9 @@ namespace hat {

hat::scan_result result;
if LIBHAT_IF_CONSTEVAL {
result = detail::find_pattern<detail::scan_mode::Single, alignment>(begin, end, signature);
result = detail::find_pattern<detail::scan_mode::Single, alignment>({begin, end, signature, hints});
} else {
result = detail::find_pattern<alignment>(begin, end, signature);
result = detail::find_pattern<alignment>({begin, end, signature, hints});
}
return result.has_result() ? result.get() - offset : nullptr;
}
Expand All @@ -218,3 +243,5 @@ namespace hat::experimental {
process::module_t mod = process::get_process_module()
);
}

#include "Scanner.inl"
22 changes: 22 additions & 0 deletions include/libhat/Scanner.inl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once

namespace hat {

/// Scans the bytes reported by hat::process::get_module_data for `mod`.
///
/// @tparam alignment Alignment forwarded to the iterator-based find_pattern.
/// @param signature  Pattern to search for.
/// @param mod        Module whose data is scanned.
/// @return The result of the iterator-based find_pattern over the module data,
///         or a result constructed from nullptr when that data is empty.
template<scan_alignment alignment>
scan_result find_pattern(const signature_view signature, const hat::process::module_t mod) {
    if (const auto moduleBytes = hat::process::get_module_data(mod); !moduleBytes.empty()) {
        return find_pattern<alignment>(moduleBytes.begin(), moduleBytes.end(), signature);
    }
    // No data available for this module, so there is nothing to scan
    return nullptr;
}

/// Scans the bytes reported by hat::process::get_section_data for the named
/// section of `mod`.
///
/// @tparam alignment Alignment forwarded to the iterator-based find_pattern.
/// @param signature  Pattern to search for.
/// @param section    Name of the section passed to get_section_data.
/// @param mod        Module that the section is looked up in.
/// @return The result of the iterator-based find_pattern over the section data,
///         or a result constructed from nullptr when that data is empty.
template<scan_alignment alignment>
scan_result find_pattern(const signature_view signature, const std::string_view section, const hat::process::module_t mod) {
    if (const auto sectionBytes = hat::process::get_section_data(mod, section); !sectionBytes.empty()) {
        return find_pattern<alignment>(sectionBytes.begin(), sectionBytes.end(), signature);
    }
    // No data available for this section, so there is nothing to scan
    return nullptr;
}
}
42 changes: 7 additions & 35 deletions src/Scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,60 +3,32 @@
#include <libhat/Defines.hpp>
#include "System.hpp"

namespace hat {

using namespace hat::process;

// Scans the bytes returned by get_module_data(mod) for the signature.
// Yields a result constructed from nullptr when the module data is empty.
template<scan_alignment alignment>
scan_result find_pattern(signature_view signature, module_t mod) {
    if (const auto moduleBytes = get_module_data(mod); !moduleBytes.empty()) {
        return find_pattern<alignment>(moduleBytes.begin(), moduleBytes.end(), signature);
    }
    // Empty module data: nothing to scan
    return nullptr;
}

// Scans the bytes returned by get_section_data(mod, section) for the signature.
// Yields a result constructed from nullptr when the section data is empty.
template<scan_alignment alignment>
scan_result find_pattern(signature_view signature, std::string_view section, module_t mod) {
    if (const auto sectionBytes = get_section_data(mod, section); !sectionBytes.empty()) {
        return find_pattern<alignment>(sectionBytes.begin(), sectionBytes.end(), signature);
    }
    // Empty section data: nothing to scan
    return nullptr;
}

// Explicit instantiations of both module-based find_pattern overloads
// (whole module, and named section) for the X1 and X16 alignments.
template scan_result find_pattern<scan_alignment::X1>(signature_view, module_t);
template scan_result find_pattern<scan_alignment::X1>(signature_view, std::string_view, module_t);
template scan_result find_pattern<scan_alignment::X16>(signature_view, module_t);
template scan_result find_pattern<scan_alignment::X16>(signature_view, std::string_view, module_t);
}

namespace hat::detail {

template<scan_alignment alignment>
scan_result find_pattern(const std::byte* begin, const std::byte* end, signature_view signature) {
scan_result find_pattern(const scan_context& context) {
#if defined(LIBHAT_X86)
const auto& ext = get_system().extensions;
if (ext.bmi1) {
#if !defined(LIBHAT_DISABLE_AVX512)
if (ext.avx512) {
return find_pattern<scan_mode::AVX512, alignment>(begin, end, signature);
return find_pattern<scan_mode::AVX512, alignment>(context);
}
#endif
if (ext.avx2) {
return find_pattern<scan_mode::AVX2, alignment>(begin, end, signature);
return find_pattern<scan_mode::AVX2, alignment>(context);
}
}
#if !defined(LIBHAT_DISABLE_SSE)
if (ext.sse41) {
return find_pattern<scan_mode::SSE, alignment>(begin, end, signature);
return find_pattern<scan_mode::SSE, alignment>(context);
}
#endif
#endif
// If none of the vectorized implementations are available/supported, then fallback to scanning per-byte
return find_pattern<scan_mode::Single, alignment>(begin, end, signature);
return find_pattern<scan_mode::Single, alignment>(context);
}

template scan_result find_pattern<scan_alignment::X1>(const std::byte*, const std::byte*, signature_view);
template scan_result find_pattern<scan_alignment::X16>(const std::byte*, const std::byte*, signature_view);
template scan_result find_pattern<scan_alignment::X1>(const scan_context& context);
template scan_result find_pattern<scan_alignment::X16>(const scan_context& context);
}
25 changes: 14 additions & 11 deletions src/arch/x86/AVX2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ namespace hat::detail {
}

template<scan_alignment alignment, bool cmpeq2, bool veccmp>
scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, signature_view signature) {
scan_result find_pattern_avx2(const scan_context& context) {
auto [begin, end, signature, hints] = context;

// 256 bit vector containing first signature byte repeated
const auto firstByte = _mm256_set1_epi8(static_cast<int8_t>(*signature[0]));

Expand Down Expand Up @@ -88,33 +90,34 @@ namespace hat::detail {

// Look in remaining bytes that couldn't be grouped into 256 bits
begin = reinterpret_cast<const std::byte*>(vec);
return find_pattern<scan_mode::Single, alignment>(begin, end, signature);
return find_pattern<scan_mode::Single, alignment>({begin, end, signature, hints});
}

template<scan_alignment alignment>
scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, signature_view signature) {
scan_result find_pattern_avx2(const scan_context& context) {
auto& signature = context.signature;
const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value();
const bool veccmp = signature.size() <= 33;

if (cmpeq2 && veccmp) {
return find_pattern_avx2<alignment, true, true>(begin, end, signature);
return find_pattern_avx2<alignment, true, true>(context);
} else if (cmpeq2) {
return find_pattern_avx2<alignment, true, false>(begin, end, signature);
return find_pattern_avx2<alignment, true, false>(context);
} else if (veccmp) {
return find_pattern_avx2<alignment, false, true>(begin, end, signature);
return find_pattern_avx2<alignment, false, true>(context);
} else {
return find_pattern_avx2<alignment, false, false>(begin, end, signature);
return find_pattern_avx2<alignment, false, false>(context);
}
}

template<>
scan_result find_pattern<scan_mode::AVX2, scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature) {
return find_pattern_avx2<scan_alignment::X1>(begin, end, signature);
scan_result find_pattern<scan_mode::AVX2, scan_alignment::X1>(const scan_context& context) {
return find_pattern_avx2<scan_alignment::X1>(context);
}

template<>
scan_result find_pattern<scan_mode::AVX2, scan_alignment::X16>(const std::byte* begin, const std::byte* end, signature_view signature) {
return find_pattern_avx2<scan_alignment::X16>(begin, end, signature);
scan_result find_pattern<scan_mode::AVX2, scan_alignment::X16>(const scan_context& context) {
return find_pattern_avx2<scan_alignment::X16>(context);
}
}
#endif
24 changes: 13 additions & 11 deletions src/arch/x86/AVX512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ namespace hat::detail {
}

template<scan_alignment alignment, bool cmpeq2, bool veccmp>
scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, signature_view signature) {
scan_result find_pattern_avx512(const scan_context& context) {
auto [begin, end, signature, hints] = context;
// 512 bit vector containing first signature byte repeated
const auto firstByte = _mm512_set1_epi8(static_cast<int8_t>(*signature[0]));

Expand Down Expand Up @@ -84,33 +85,34 @@ namespace hat::detail {

// Look in remaining bytes that couldn't be grouped into 512 bits
begin = reinterpret_cast<const std::byte*>(vec);
return find_pattern<scan_mode::Single, alignment>(begin, end, signature);
return find_pattern<scan_mode::Single, alignment>({begin, end, signature, hints});
}

template<scan_alignment alignment>
scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, signature_view signature) {
scan_result find_pattern_avx512(const scan_context& context) {
auto& signature = context.signature;
const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value();
const bool veccmp = signature.size() <= 65;

if (cmpeq2 && veccmp) {
return find_pattern_avx512<alignment, true, true>(begin, end, signature);
return find_pattern_avx512<alignment, true, true>(context);
} else if (cmpeq2) {
return find_pattern_avx512<alignment, true, false>(begin, end, signature);
return find_pattern_avx512<alignment, true, false>(context);
} else if (veccmp) {
return find_pattern_avx512<alignment, false, true>(begin, end, signature);
return find_pattern_avx512<alignment, false, true>(context);
} else {
return find_pattern_avx512<alignment, false, false>(begin, end, signature);
return find_pattern_avx512<alignment, false, false>(context);
}
}

template<>
scan_result find_pattern<scan_mode::AVX512, scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature) {
return find_pattern_avx512<scan_alignment::X1>(begin, end, signature);
scan_result find_pattern<scan_mode::AVX512, scan_alignment::X1>(const scan_context& context) {
return find_pattern_avx512<scan_alignment::X1>(context);
}

template<>
scan_result find_pattern<scan_mode::AVX512, scan_alignment::X16>(const std::byte* begin, const std::byte* end, signature_view signature) {
return find_pattern_avx512<scan_alignment::X16>(begin, end, signature);
scan_result find_pattern<scan_mode::AVX512, scan_alignment::X16>(const scan_context& context) {
return find_pattern_avx512<scan_alignment::X16>(context);
}
}
#endif
25 changes: 14 additions & 11 deletions src/arch/x86/SSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ namespace hat::detail {
}

template<scan_alignment alignment, bool cmpeq2, bool veccmp>
scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, signature_view signature) {
scan_result find_pattern_sse(const scan_context& context) {
auto [begin, end, signature, hints] = context;

// 128 bit vector containing first signature byte repeated
const auto firstByte = _mm_set1_epi8(static_cast<int8_t>(*signature[0]));

Expand Down Expand Up @@ -86,33 +88,34 @@ namespace hat::detail {

// Look in remaining bytes that couldn't be grouped into 128 bits
begin = reinterpret_cast<const std::byte*>(vec);
return find_pattern<scan_mode::Single, alignment>(begin, end, signature);
return find_pattern<scan_mode::Single, alignment>({begin, end, signature, hints});
}

template<scan_alignment alignment>
scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, signature_view signature) {
scan_result find_pattern_sse(const scan_context& context) {
auto& signature = context.signature;
const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value();
const bool veccmp = signature.size() <= 17;

if (cmpeq2 && veccmp) {
return find_pattern_sse<alignment, true, true>(begin, end, signature);
return find_pattern_sse<alignment, true, true>(context);
} else if (cmpeq2) {
return find_pattern_sse<alignment, true, false>(begin, end, signature);
return find_pattern_sse<alignment, true, false>(context);
} else if (veccmp) {
return find_pattern_sse<alignment, false, true>(begin, end, signature);
return find_pattern_sse<alignment, false, true>(context);
} else {
return find_pattern_sse<alignment, false, false>(begin, end, signature);
return find_pattern_sse<alignment, false, false>(context);
}
}

template<>
scan_result find_pattern<scan_mode::SSE, scan_alignment::X1>(const std::byte* begin, const std::byte* end, signature_view signature) {
return find_pattern_sse<scan_alignment::X1>(begin, end, signature);
scan_result find_pattern<scan_mode::SSE, scan_alignment::X1>(const scan_context& context) {
return find_pattern_sse<scan_alignment::X1>(context);
}

template<>
scan_result find_pattern<scan_mode::SSE, scan_alignment::X16>(const std::byte* begin, const std::byte* end, signature_view signature) {
return find_pattern_sse<scan_alignment::X16>(begin, end, signature);
scan_result find_pattern<scan_mode::SSE, scan_alignment::X16>(const scan_context& context) {
return find_pattern_sse<scan_alignment::X16>(context);
}
}
#endif

0 comments on commit 57efa1b

Please sign in to comment.