Skip to content

Commit

Permalink
Resolve scanner function in scan_context
Browse files Browse the repository at this point in the history
  • Loading branch information
ZeroMemes committed Aug 1, 2024
1 parent f9e2800 commit 707e10d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 43 deletions.
77 changes: 42 additions & 35 deletions include/libhat/Scanner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,24 +81,35 @@ namespace hat {

namespace detail {

struct scan_context {
class scan_context;

using scan_function_t = const_scan_result(*)(const std::byte* begin, const std::byte* end, const scan_context& context);

class scan_context {
public:
signature_view signature{};
scan_function_t scanner{};
scan_alignment alignment{};
size_t vectorSize{};
scan_hint hints{};

static constexpr scan_context create(const signature_view signature, const scan_hint hints, const scan_alignment alignment) {
scan_context ctx{};
ctx.signature = signature;
ctx.hints = hints;
if LIBHAT_IF_CONSTEVAL {} else {
ctx.apply_hints(alignment);
}
return ctx;
[[nodiscard]] constexpr const_scan_result scan(const std::byte* begin, const std::byte* end) const {
return this->scanner(begin, end, *this);
}

void apply_hints();

template<scan_alignment alignment>
static constexpr scan_context create(signature_view signature, scan_hint hints);
private:
scan_context() = default;
void apply_hints(scan_alignment alignment);
};

template<scan_alignment alignment>
[[nodiscard]] std::pair<scan_function_t, size_t> resolve_scanner();

void apply_hints(scan_context& context);

enum class scan_mode {
FastFirst, // std::find + std::equal
SSE, // x86 SSE 4.1
Expand Down Expand Up @@ -143,9 +154,6 @@ namespace hat {
template<scan_mode, scan_alignment>
const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context&);

template<scan_alignment alignment>
const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context&);

template<>
inline constexpr const_scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X1>(const std::byte* begin, const std::byte* end, const scan_context& context) {
const auto signature = context.signature;
Expand Down Expand Up @@ -213,6 +221,21 @@ namespace hat {
template<byte_input_iterator T>
using result_type_for = std::conditional_t<std::is_const_v<std::remove_reference_t<std::iter_reference_t<T>>>,
const_scan_result, scan_result>;

template<scan_alignment alignment>
constexpr scan_context scan_context::create(const signature_view signature, const scan_hint hints) {
scan_context ctx{};
ctx.signature = signature;
ctx.hints = hints;
ctx.alignment = alignment;
if LIBHAT_IF_CONSTEVAL {
ctx.scanner = &find_pattern<detail::scan_mode::Single, alignment>;
} else {
std::tie(ctx.scanner, ctx.vectorSize) = resolve_scanner<alignment>();
ctx.apply_hints();
}
return ctx;
}
}

/// Perform a signature scan on the entirety of the process module or a specified module
Expand Down Expand Up @@ -247,14 +270,8 @@ namespace hat {
return {nullptr};
}

const auto context = detail::scan_context::create(trunc, hints, alignment);

const_scan_result result;
if LIBHAT_IF_CONSTEVAL {
result = detail::find_pattern<detail::scan_mode::Single, alignment>(begin, end, context);
} else {
result = detail::find_pattern<alignment>(begin, end, context);
}
const auto context = detail::scan_context::create<alignment>(trunc, hints);
const const_scan_result result = context.scan(begin, end);
return result.has_result()
? const_cast<typename detail::result_type_for<Iter>::underlying_type>(result.get() - offset)
: nullptr;
Expand All @@ -280,15 +297,10 @@ namespace hat {
auto i = begin;
auto out = beginOut;

const auto context = detail::scan_context::create(trunc, hints, alignment);
const auto context = detail::scan_context::create<alignment>(trunc, hints);

while (i < end && out != endOut && trunc.size() <= static_cast<size_t>(std::distance(i, end))) {
const_scan_result result;
if LIBHAT_IF_CONSTEVAL {
result = detail::find_pattern<detail::scan_mode::Single, alignment>(i, end, context);
} else {
result = detail::find_pattern<alignment>(i, end, context);
}
const auto result = context.scan(i, end);
if (!result.has_result()) {
i = end;
break;
Expand Down Expand Up @@ -318,15 +330,10 @@ namespace hat {
auto out = outIn;
size_t matches{};

const auto context = detail::scan_context::create(trunc, hints, alignment);
const auto context = detail::scan_context::create<alignment>(trunc, hints);

while (begin < end && trunc.size() <= static_cast<size_t>(std::distance(i, end))) {
const_scan_result result;
if LIBHAT_IF_CONSTEVAL {
result = detail::find_pattern<detail::scan_mode::Single, alignment>(i, end, context);
} else {
result = detail::find_pattern<alignment>(i, end, context);
}
const auto result = context.scan(i, end);
if (!result.has_result()) {
break;
}
Expand Down
16 changes: 8 additions & 8 deletions src/Scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,34 @@

namespace hat::detail {

void scan_context::apply_hints([[maybe_unused]] const scan_alignment alignment) {}
void scan_context::apply_hints() {}

template<scan_alignment alignment>
const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) {
std::pair<scan_function_t, size_t> resolve_scanner() {
#if defined(LIBHAT_X86)
const auto& ext = get_system().extensions;
if (ext.bmi) {
#if !defined(LIBHAT_DISABLE_AVX512)
if (ext.avx512f && ext.avx512bw) {
return find_pattern<scan_mode::AVX512, alignment>(begin, end, context);
return {&find_pattern<scan_mode::AVX512, alignment>, 64};
}
#endif
if (ext.avx2) {
return find_pattern<scan_mode::AVX2, alignment>(begin, end, context);
return {&find_pattern<scan_mode::AVX2, alignment>, 32};
}
}
#if !defined(LIBHAT_DISABLE_SSE)
if (ext.sse41) {
return find_pattern<scan_mode::SSE, alignment>(begin, end, context);
return {&find_pattern<scan_mode::SSE, alignment>, 16};
}
#endif
#endif
// If none of the vectorized implementations are available/supported, then fallback to scanning per-byte
return find_pattern<scan_mode::Single, alignment>(begin, end, context);
return {&find_pattern<scan_mode::Single, alignment>, 0};
}

template const_scan_result find_pattern<scan_alignment::X1>(const std::byte* begin, const std::byte* end, const scan_context& context);
template const_scan_result find_pattern<scan_alignment::X16>(const std::byte* begin, const std::byte* end, const scan_context& context);
template std::pair<scan_function_t, size_t> resolve_scanner<scan_alignment::X1>();
template std::pair<scan_function_t, size_t> resolve_scanner<scan_alignment::X16>();
}

// Validate return value const-ness for the root find_pattern impl
Expand Down

0 comments on commit 707e10d

Please sign in to comment.