Skip to content

Commit

Permalink
Separate begin, end from scan_context
Browse files Browse the repository at this point in the history
  • Loading branch information
ZeroMemes committed Jul 31, 2024
1 parent 215ca47 commit 0caead9
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 57 deletions.
45 changes: 31 additions & 14 deletions include/libhat/Scanner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,21 @@ namespace hat {
namespace detail {

struct scan_context {
const std::byte* begin{};
const std::byte* end{};
signature_view signature{};
scan_hint hints{};

static constexpr scan_context create(const signature_view signature, const scan_hint hints) {
scan_context ctx{};
ctx.signature = signature;
ctx.hints = hints;
if LIBHAT_IF_CONSTEVAL {} else {
ctx.apply_hints();
}
return ctx;
}
private:
scan_context() = default;
void apply_hints();
};

enum class scan_mode {
Expand Down Expand Up @@ -130,14 +141,14 @@ namespace hat {
}

template<scan_mode, scan_alignment>
const_scan_result find_pattern(const scan_context&);
const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context&);

template<scan_alignment alignment>
const_scan_result find_pattern(const scan_context&);
const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context&);

template<>
inline constexpr const_scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X1>(const scan_context& context) {
auto [begin, end, signature, _] = context;
inline constexpr const_scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X1>(const std::byte* begin, const std::byte* end, const scan_context& context) {
const auto signature = context.signature;
const auto firstByte = *signature[0];
const auto scanEnd = end - signature.size() + 1;

Expand All @@ -163,8 +174,8 @@ namespace hat {
}

template<>
inline const_scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X16>(const scan_context& context) {
auto [begin, end, signature, _] = context;
inline const_scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X16>(const std::byte* begin, const std::byte* end, const scan_context& context) {
const auto signature = context.signature;
const auto firstByte = *signature[0];

const auto scanBegin = next_boundary_align<scan_alignment::X16>(begin);
Expand Down Expand Up @@ -236,11 +247,13 @@ namespace hat {
return {nullptr};
}

const auto context = detail::scan_context::create(trunc, hints);

const_scan_result result;
if LIBHAT_IF_CONSTEVAL {
result = detail::find_pattern<detail::scan_mode::Single, alignment>({begin, end, trunc, hints});
result = detail::find_pattern<detail::scan_mode::Single, alignment>(begin, end, context);
} else {
result = detail::find_pattern<alignment>({begin, end, trunc, hints});
result = detail::find_pattern<alignment>(begin, end, context);
}
return result.has_result()
? const_cast<typename detail::result_type_for<Iter>::underlying_type>(result.get() - offset)
Expand All @@ -267,12 +280,14 @@ namespace hat {
auto i = begin;
auto out = beginOut;

const auto context = detail::scan_context::create(trunc, hints);

while (i < end && out != endOut && trunc.size() <= static_cast<size_t>(std::distance(i, end))) {
const_scan_result result;
if LIBHAT_IF_CONSTEVAL {
result = detail::find_pattern<detail::scan_mode::Single, alignment>({i, end, trunc, hints});
result = detail::find_pattern<detail::scan_mode::Single, alignment>(i, end, context);
} else {
result = detail::find_pattern<alignment>({i, end, trunc, hints});
result = detail::find_pattern<alignment>(i, end, context);
}
if (!result.has_result()) {
i = end;
Expand Down Expand Up @@ -303,12 +318,14 @@ namespace hat {
auto out = outIn;
size_t matches{};

const auto context = detail::scan_context::create(trunc, hints);

while (begin < end && trunc.size() <= static_cast<size_t>(std::distance(i, end))) {
const_scan_result result;
if LIBHAT_IF_CONSTEVAL {
result = detail::find_pattern<detail::scan_mode::Single, alignment>({i, end, trunc, hints});
result = detail::find_pattern<detail::scan_mode::Single, alignment>(i, end, context);
} else {
result = detail::find_pattern<alignment>({i, end, trunc, hints});
result = detail::find_pattern<alignment>(i, end, context);
}
if (!result.has_result()) {
break;
Expand Down
16 changes: 9 additions & 7 deletions src/Scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,34 @@

namespace hat::detail {

void scan_context::apply_hints() {}

template<scan_alignment alignment>
const_scan_result find_pattern(const scan_context& context) {
const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context& context) {
#if defined(LIBHAT_X86)
const auto& ext = get_system().extensions;
if (ext.bmi) {
#if !defined(LIBHAT_DISABLE_AVX512)
if (ext.avx512f && ext.avx512bw) {
return find_pattern<scan_mode::AVX512, alignment>(context);
return find_pattern<scan_mode::AVX512, alignment>(begin, end, context);
}
#endif
if (ext.avx2) {
return find_pattern<scan_mode::AVX2, alignment>(context);
return find_pattern<scan_mode::AVX2, alignment>(begin, end, context);
}
}
#if !defined(LIBHAT_DISABLE_SSE)
if (ext.sse41) {
return find_pattern<scan_mode::SSE, alignment>(context);
return find_pattern<scan_mode::SSE, alignment>(begin, end, context);
}
#endif
#endif
// If none of the vectorized implementations are available/supported, then fallback to scanning per-byte
return find_pattern<scan_mode::Single, alignment>(context);
return find_pattern<scan_mode::Single, alignment>(begin, end, context);
}

template const_scan_result find_pattern<scan_alignment::X1>(const scan_context& context);
template const_scan_result find_pattern<scan_alignment::X16>(const scan_context& context);
template const_scan_result find_pattern<scan_alignment::X1>(const std::byte* begin, const std::byte* end, const scan_context& context);
template const_scan_result find_pattern<scan_alignment::X16>(const std::byte* begin, const std::byte* end, const scan_context& context);
}

// Validate return value const-ness for the root find_pattern impl
Expand Down
24 changes: 12 additions & 12 deletions src/arch/x86/AVX2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ namespace hat::detail {
}

template<scan_alignment alignment, bool cmpeq2, bool veccmp>
const_scan_result find_pattern_avx2(const scan_context& context) {
auto [begin, end, signature, hints] = context;
const_scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, const scan_context& context) {
const auto signature = context.signature;

// 256 bit vector containing first signature byte repeated
const auto firstByte = _mm256_set1_epi8(static_cast<int8_t>(*signature[0]));
Expand Down Expand Up @@ -90,34 +90,34 @@ namespace hat::detail {

// Look in remaining bytes that couldn't be grouped into 256 bits
begin = reinterpret_cast<const std::byte*>(vec);
return find_pattern<scan_mode::Single, alignment>({begin, end, signature, hints});
return find_pattern<scan_mode::Single, alignment>(begin, end, context);
}

template<scan_alignment alignment>
const_scan_result find_pattern_avx2(const scan_context& context) {
const_scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, const scan_context& context) {
auto& signature = context.signature;
const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value();
const bool veccmp = signature.size() <= 33;

if (cmpeq2 && veccmp) {
return find_pattern_avx2<alignment, true, true>(context);
return find_pattern_avx2<alignment, true, true>(begin, end, context);
} else if (cmpeq2) {
return find_pattern_avx2<alignment, true, false>(context);
return find_pattern_avx2<alignment, true, false>(begin, end, context);
} else if (veccmp) {
return find_pattern_avx2<alignment, false, true>(context);
return find_pattern_avx2<alignment, false, true>(begin, end, context);
} else {
return find_pattern_avx2<alignment, false, false>(context);
return find_pattern_avx2<alignment, false, false>(begin, end, context);
}
}

template<>
const_scan_result find_pattern<scan_mode::AVX2, scan_alignment::X1>(const scan_context& context) {
return find_pattern_avx2<scan_alignment::X1>(context);
const_scan_result find_pattern<scan_mode::AVX2, scan_alignment::X1>(const std::byte* begin, const std::byte* end, const scan_context& context) {
return find_pattern_avx2<scan_alignment::X1>(begin, end, context);
}

template<>
const_scan_result find_pattern<scan_mode::AVX2, scan_alignment::X16>(const scan_context& context) {
return find_pattern_avx2<scan_alignment::X16>(context);
const_scan_result find_pattern<scan_mode::AVX2, scan_alignment::X16>(const std::byte* begin, const std::byte* end, const scan_context& context) {
return find_pattern_avx2<scan_alignment::X16>(begin, end, context);
}
}
#endif
24 changes: 12 additions & 12 deletions src/arch/x86/AVX512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ namespace hat::detail {
}

template<scan_alignment alignment, bool cmpeq2, bool veccmp>
const_scan_result find_pattern_avx512(const scan_context& context) {
auto [begin, end, signature, hints] = context;
const_scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, const scan_context& context) {
const auto signature = context.signature;
// 512 bit vector containing first signature byte repeated
const auto firstByte = _mm512_set1_epi8(static_cast<int8_t>(*signature[0]));

Expand Down Expand Up @@ -85,34 +85,34 @@ namespace hat::detail {

// Look in remaining bytes that couldn't be grouped into 512 bits
begin = reinterpret_cast<const std::byte*>(vec);
return find_pattern<scan_mode::Single, alignment>({begin, end, signature, hints});
return find_pattern<scan_mode::Single, alignment>(begin, end, context);
}

template<scan_alignment alignment>
const_scan_result find_pattern_avx512(const scan_context& context) {
const_scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, const scan_context& context) {
auto& signature = context.signature;
const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value();
const bool veccmp = signature.size() <= 65;

if (cmpeq2 && veccmp) {
return find_pattern_avx512<alignment, true, true>(context);
return find_pattern_avx512<alignment, true, true>(begin, end, context);
} else if (cmpeq2) {
return find_pattern_avx512<alignment, true, false>(context);
return find_pattern_avx512<alignment, true, false>(begin, end, context);
} else if (veccmp) {
return find_pattern_avx512<alignment, false, true>(context);
return find_pattern_avx512<alignment, false, true>(begin, end, context);
} else {
return find_pattern_avx512<alignment, false, false>(context);
return find_pattern_avx512<alignment, false, false>(begin, end, context);
}
}

template<>
const_scan_result find_pattern<scan_mode::AVX512, scan_alignment::X1>(const scan_context& context) {
return find_pattern_avx512<scan_alignment::X1>(context);
const_scan_result find_pattern<scan_mode::AVX512, scan_alignment::X1>(const std::byte* begin, const std::byte* end, const scan_context& context) {
return find_pattern_avx512<scan_alignment::X1>(begin, end, context);
}

template<>
const_scan_result find_pattern<scan_mode::AVX512, scan_alignment::X16>(const scan_context& context) {
return find_pattern_avx512<scan_alignment::X16>(context);
const_scan_result find_pattern<scan_mode::AVX512, scan_alignment::X16>(const std::byte* begin, const std::byte* end, const scan_context& context) {
return find_pattern_avx512<scan_alignment::X16>(begin, end, context);
}
}
#endif
24 changes: 12 additions & 12 deletions src/arch/x86/SSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ namespace hat::detail {
}

template<scan_alignment alignment, bool cmpeq2, bool veccmp>
const_scan_result find_pattern_sse(const scan_context& context) {
auto [begin, end, signature, hints] = context;
const_scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, const scan_context& context) {
const auto signature = context.signature;

// 256 bit vector containing first signature byte repeated
const auto firstByte = _mm_set1_epi8(static_cast<int8_t>(*signature[0]));
Expand Down Expand Up @@ -88,34 +88,34 @@ namespace hat::detail {

// Look in remaining bytes that couldn't be grouped into 128 bits
begin = reinterpret_cast<const std::byte*>(vec);
return find_pattern<scan_mode::Single, alignment>({begin, end, signature, hints});
return find_pattern<scan_mode::Single, alignment>(begin, end, context);
}

template<scan_alignment alignment>
const_scan_result find_pattern_sse(const scan_context& context) {
const_scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, const scan_context& context) {
auto& signature = context.signature;
const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value();
const bool veccmp = signature.size() <= 17;

if (cmpeq2 && veccmp) {
return find_pattern_sse<alignment, true, true>(context);
return find_pattern_sse<alignment, true, true>(begin, end, context);
} else if (cmpeq2) {
return find_pattern_sse<alignment, true, false>(context);
return find_pattern_sse<alignment, true, false>(begin, end, context);
} else if (veccmp) {
return find_pattern_sse<alignment, false, true>(context);
return find_pattern_sse<alignment, false, true>(begin, end, context);
} else {
return find_pattern_sse<alignment, false, false>(context);
return find_pattern_sse<alignment, false, false>(begin, end, context);
}
}

template<>
const_scan_result find_pattern<scan_mode::SSE, scan_alignment::X1>(const scan_context& context) {
return find_pattern_sse<scan_alignment::X1>(context);
const_scan_result find_pattern<scan_mode::SSE, scan_alignment::X1>(const std::byte* begin, const std::byte* end, const scan_context& context) {
return find_pattern_sse<scan_alignment::X1>(begin, end, context);
}

template<>
const_scan_result find_pattern<scan_mode::SSE, scan_alignment::X16>(const scan_context& context) {
return find_pattern_sse<scan_alignment::X16>(context);
const_scan_result find_pattern<scan_mode::SSE, scan_alignment::X16>(const std::byte* begin, const std::byte* end, const scan_context& context) {
return find_pattern_sse<scan_alignment::X16>(begin, end, context);
}
}
#endif

0 comments on commit 0caead9

Please sign in to comment.