Skip to content

Commit

Permalink
Fetch actual scanner impl
Browse files Browse the repository at this point in the history
  • Loading branch information
ZeroMemes committed Aug 1, 2024
1 parent 707e10d commit 3ee98fd
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 97 deletions.
46 changes: 27 additions & 19 deletions include/libhat/Scanner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,13 @@ namespace hat {
scan_context() = default;
};

template<scan_alignment alignment>
[[nodiscard]] std::pair<scan_function_t, size_t> resolve_scanner();

void apply_hints(scan_context& context);
// Returns the scanner function chosen for the given context together with a size_t
// that callers store as the context's vectorSize.
[[nodiscard]] std::pair<scan_function_t, size_t> resolve_scanner(const scan_context&);

// Names the available scanner implementation families; used as the template
// argument of get_scanner to pick one of them.
enum class scan_mode {
FastFirst, // std::find + std::equal
SSE, // x86 SSE 4.1
AVX2, // x86 AVX2
AVX512, // x86 AVX512

// Fallback mode to use for SIMD remaining bytes
Single = FastFirst
Single, // std::find + std::equal
SSE, // x86 SSE 4.1
AVX2, // x86 AVX2
AVX512, // x86 AVX512
};

template<scan_alignment alignment>
Expand All @@ -132,8 +126,9 @@ namespace hat {
return mask;
}

template<scan_alignment alignment, auto stride = alignment_stride<alignment>>
template<scan_alignment alignment>
inline const std::byte* next_boundary_align(const std::byte* ptr) {
constexpr auto stride = alignment_stride<alignment>;
if constexpr (stride == 1) {
return ptr;
}
Expand All @@ -142,20 +137,24 @@ namespace hat {
return ptr;
}

// Rounds `ptr` down to the closest address that is a multiple of
// alignment_stride<alignment>. With a stride of 1 every address already
// qualifies, so `ptr` is returned unchanged.
template<scan_alignment alignment, auto stride = alignment_stride<alignment>>
template<scan_alignment alignment>
inline const std::byte* prev_boundary_align(const std::byte* ptr) {
constexpr auto stride = alignment_stride<alignment>;
if constexpr (stride == 1) {
return ptr;
}
// Number of bytes `ptr` lies past the previous stride boundary (0 if already aligned)
uintptr_t mod = reinterpret_cast<uintptr_t>(ptr) % stride;
return ptr - mod;
}

template<scan_mode, scan_alignment>
const_scan_result find_pattern(const std::byte* begin, const std::byte* end, const scan_context&);
// Primary template: returns the scan function to use for the given mode and context.
// Only declared here; each mode provides its own explicit specialization.
template<scan_mode>
scan_function_t get_scanner(const scan_context&);

// Primary template of the per-byte (non-SIMD) scanner over [begin, end), specialized
// per alignment.
template<scan_alignment>
const_scan_result find_pattern_single(const std::byte* begin, const std::byte* end, const scan_context&);

template<>
inline constexpr const_scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X1>(const std::byte* begin, const std::byte* end, const scan_context& context) {
constexpr const_scan_result find_pattern_single<scan_alignment::X1>(const std::byte* begin, const std::byte* end, const scan_context& context) {
const auto signature = context.signature;
const auto firstByte = *signature[0];
const auto scanEnd = end - signature.size() + 1;
Expand All @@ -182,7 +181,7 @@ namespace hat {
}

template<>
inline const_scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X16>(const std::byte* begin, const std::byte* end, const scan_context& context) {
inline const_scan_result find_pattern_single<scan_alignment::X16>(const std::byte* begin, const std::byte* end, const scan_context& context) {
const auto signature = context.signature;
const auto firstByte = *signature[0];

Expand All @@ -206,6 +205,15 @@ namespace hat {
return nullptr;
}

// Picks the per-byte scanner that matches the alignment recorded in the context.
// Every scan_alignment value handled here returns early; reaching the end means the
// context carried a value outside the handled set.
template<>
constexpr scan_function_t get_scanner<scan_mode::Single>(const scan_context& context) {
    const auto requested = context.alignment;
    if (requested == scan_alignment::X1) {
        return &find_pattern_single<scan_alignment::X1>;
    }
    if (requested == scan_alignment::X16) {
        return &find_pattern_single<scan_alignment::X16>;
    }
    std::unreachable();
}

[[nodiscard]] constexpr auto truncate(const signature_view signature) noexcept {
// Truncate the leading wildcards from the signature
size_t offset = 0;
Expand All @@ -229,9 +237,9 @@ namespace hat {
ctx.hints = hints;
ctx.alignment = alignment;
if LIBHAT_IF_CONSTEVAL {
ctx.scanner = &find_pattern<detail::scan_mode::Single, alignment>;
ctx.scanner = get_scanner<scan_mode::Single>(ctx);
} else {
std::tie(ctx.scanner, ctx.vectorSize) = resolve_scanner<alignment>();
std::tie(ctx.scanner, ctx.vectorSize) = resolve_scanner(ctx);
ctx.apply_hints();
}
return ctx;
Expand Down
14 changes: 5 additions & 9 deletions src/Scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,28 @@ namespace hat::detail {

// No-op: this definition has an empty body and leaves the context untouched.
void scan_context::apply_hints() {}

// Chooses a scanner implementation based on the CPU extensions reported by
// get_system().extensions and returns it paired with a size: 64 for AVX512,
// 32 for AVX2, 16 for SSE, and 0 for the per-byte fallback.
// The checks run widest-first, so the first supported implementation wins.
template<scan_alignment alignment>
std::pair<scan_function_t, size_t> resolve_scanner() {
std::pair<scan_function_t, size_t> resolve_scanner(const scan_context& context) {
#if defined(LIBHAT_X86)
const auto& ext = get_system().extensions;
// The AVX2 and AVX512 scanners are only considered when BMI is also reported
if (ext.bmi) {
#if !defined(LIBHAT_DISABLE_AVX512)
// AVX512 needs both the foundation (avx512f) and byte/word (avx512bw) flags
if (ext.avx512f && ext.avx512bw) {
return {&find_pattern<scan_mode::AVX512, alignment>, 64};
return {get_scanner<scan_mode::AVX512>(context), 64};
}
#endif
if (ext.avx2) {
return {&find_pattern<scan_mode::AVX2, alignment>, 32};
return {get_scanner<scan_mode::AVX2>(context), 32};
}
}
#if !defined(LIBHAT_DISABLE_SSE)
// SSE is checked outside the BMI guard, so it does not depend on ext.bmi
if (ext.sse41) {
return {&find_pattern<scan_mode::SSE, alignment>, 16};
return {get_scanner<scan_mode::SSE>(context), 16};
}
#endif
#endif
// If none of the vectorized implementations are available/supported, then fallback to scanning per-byte
return {&find_pattern<scan_mode::Single, alignment>, 0};
return {get_scanner<scan_mode::Single>(context), 0};
}

// Explicit instantiations of the alignment-templated resolve_scanner for X1 and X16
template std::pair<scan_function_t, size_t> resolve_scanner<scan_alignment::X1>();
template std::pair<scan_function_t, size_t> resolve_scanner<scan_alignment::X16>();
}

// Validate return value const-ness for the root find_pattern impl
Expand Down
46 changes: 23 additions & 23 deletions src/arch/x86/AVX2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,34 +90,34 @@ namespace hat::detail {

// Look in remaining bytes that couldn't be grouped into 256 bits
begin = reinterpret_cast<const std::byte*>(vec);
return find_pattern<scan_mode::Single, alignment>(begin, end, context);
return find_pattern_single<alignment>(begin, end, context);
}

// Selects which find_pattern_avx2<alignment, cmpeq2, veccmp> instantiation to use.
//   cmpeq2 - true only for X1 alignment when the signature has more than one
//            element and its second element holds a value.
//   veccmp - true when the signature has at most 33 elements.
template<scan_alignment alignment>
const_scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, const scan_context& context) {
auto& signature = context.signature;
const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value();
template<>
scan_function_t get_scanner<scan_mode::AVX2>(const scan_context& context) {
const auto alignment = context.alignment;
const auto signature = context.signature;
const bool veccmp = signature.size() <= 33;

if (cmpeq2 && veccmp) {
return find_pattern_avx2<alignment, true, true>(begin, end, context);
} else if (cmpeq2) {
return find_pattern_avx2<alignment, true, false>(begin, end, context);
} else if (veccmp) {
return find_pattern_avx2<alignment, false, true>(begin, end, context);
} else {
return find_pattern_avx2<alignment, false, false>(begin, end, context);
if (alignment == scan_alignment::X1) {
const bool cmpeq2 = signature.size() > 1 && signature[1].has_value();
if (cmpeq2 && veccmp) {
return &find_pattern_avx2<scan_alignment::X1, true, true>;
} else if (cmpeq2) {
return &find_pattern_avx2<scan_alignment::X1, true, false>;
} else if (veccmp) {
return &find_pattern_avx2<scan_alignment::X1, false, true>;
} else {
return &find_pattern_avx2<scan_alignment::X1, false, false>;
}
} else if (alignment == scan_alignment::X16) {
// X16 alignment always passes false for cmpeq2; only veccmp varies
if (veccmp) {
return &find_pattern_avx2<scan_alignment::X16, false, true>;
} else {
return &find_pattern_avx2<scan_alignment::X16, false, false>;
}
}
}

template<>
const_scan_result find_pattern<scan_mode::AVX2, scan_alignment::X1>(const std::byte* begin, const std::byte* end, const scan_context& context) {
return find_pattern_avx2<scan_alignment::X1>(begin, end, context);
}

template<>
const_scan_result find_pattern<scan_mode::AVX2, scan_alignment::X16>(const std::byte* begin, const std::byte* end, const scan_context& context) {
return find_pattern_avx2<scan_alignment::X16>(begin, end, context);
std::unreachable();
}
}
#endif
46 changes: 23 additions & 23 deletions src/arch/x86/AVX512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,34 +85,34 @@ namespace hat::detail {

// Look in remaining bytes that couldn't be grouped into 512 bits
begin = reinterpret_cast<const std::byte*>(vec);
return find_pattern<scan_mode::Single, alignment>(begin, end, context);
return find_pattern_single<alignment>(begin, end, context);
}

// Selects which find_pattern_avx512<alignment, cmpeq2, veccmp> instantiation to use.
//   cmpeq2 - true only for X1 alignment when the signature has more than one
//            element and its second element holds a value.
//   veccmp - true when the signature has at most 65 elements.
template<scan_alignment alignment>
const_scan_result find_pattern_avx512(const std::byte* begin, const std::byte* end, const scan_context& context) {
auto& signature = context.signature;
const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value();
template<>
scan_function_t get_scanner<scan_mode::AVX512>(const scan_context& context) {
const auto alignment = context.alignment;
const auto signature = context.signature;
const bool veccmp = signature.size() <= 65;

if (cmpeq2 && veccmp) {
return find_pattern_avx512<alignment, true, true>(begin, end, context);
} else if (cmpeq2) {
return find_pattern_avx512<alignment, true, false>(begin, end, context);
} else if (veccmp) {
return find_pattern_avx512<alignment, false, true>(begin, end, context);
} else {
return find_pattern_avx512<alignment, false, false>(begin, end, context);
if (alignment == scan_alignment::X1) {
const bool cmpeq2 = signature.size() > 1 && signature[1].has_value();
if (cmpeq2 && veccmp) {
return &find_pattern_avx512<scan_alignment::X1, true, true>;
} else if (cmpeq2) {
return &find_pattern_avx512<scan_alignment::X1, true, false>;
} else if (veccmp) {
return &find_pattern_avx512<scan_alignment::X1, false, true>;
} else {
return &find_pattern_avx512<scan_alignment::X1, false, false>;
}
} else if (alignment == scan_alignment::X16) {
// X16 alignment always passes false for cmpeq2; only veccmp varies
if (veccmp) {
return &find_pattern_avx512<scan_alignment::X16, false, true>;
} else {
return &find_pattern_avx512<scan_alignment::X16, false, false>;
}
}
}

template<>
const_scan_result find_pattern<scan_mode::AVX512, scan_alignment::X1>(const std::byte* begin, const std::byte* end, const scan_context& context) {
return find_pattern_avx512<scan_alignment::X1>(begin, end, context);
}

template<>
const_scan_result find_pattern<scan_mode::AVX512, scan_alignment::X16>(const std::byte* begin, const std::byte* end, const scan_context& context) {
return find_pattern_avx512<scan_alignment::X16>(begin, end, context);
std::unreachable();
}
}
#endif
46 changes: 23 additions & 23 deletions src/arch/x86/SSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,34 +88,34 @@ namespace hat::detail {

// Look in remaining bytes that couldn't be grouped into 128 bits
begin = reinterpret_cast<const std::byte*>(vec);
return find_pattern<scan_mode::Single, alignment>(begin, end, context);
return find_pattern_single<alignment>(begin, end, context);
}

// Selects which find_pattern_sse<alignment, cmpeq2, veccmp> instantiation to use.
//   cmpeq2 - true only for X1 alignment when the signature has more than one
//            element and its second element holds a value.
//   veccmp - true when the signature has at most 17 elements.
template<scan_alignment alignment>
const_scan_result find_pattern_sse(const std::byte* begin, const std::byte* end, const scan_context& context) {
auto& signature = context.signature;
const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value();
template<>
scan_function_t get_scanner<scan_mode::SSE>(const scan_context& context) {
const auto alignment = context.alignment;
const auto signature = context.signature;
const bool veccmp = signature.size() <= 17;

if (cmpeq2 && veccmp) {
return find_pattern_sse<alignment, true, true>(begin, end, context);
} else if (cmpeq2) {
return find_pattern_sse<alignment, true, false>(begin, end, context);
} else if (veccmp) {
return find_pattern_sse<alignment, false, true>(begin, end, context);
} else {
return find_pattern_sse<alignment, false, false>(begin, end, context);
if (alignment == scan_alignment::X1) {
const bool cmpeq2 = signature.size() > 1 && signature[1].has_value();
if (cmpeq2 && veccmp) {
return &find_pattern_sse<scan_alignment::X1, true, true>;
} else if (cmpeq2) {
return &find_pattern_sse<scan_alignment::X1, true, false>;
} else if (veccmp) {
return &find_pattern_sse<scan_alignment::X1, false, true>;
} else {
return &find_pattern_sse<scan_alignment::X1, false, false>;
}
} else if (alignment == scan_alignment::X16) {
// X16 alignment always passes false for cmpeq2; only veccmp varies
if (veccmp) {
return &find_pattern_sse<scan_alignment::X16, false, true>;
} else {
return &find_pattern_sse<scan_alignment::X16, false, false>;
}
}
}

template<>
const_scan_result find_pattern<scan_mode::SSE, scan_alignment::X1>(const std::byte* begin, const std::byte* end, const scan_context& context) {
return find_pattern_sse<scan_alignment::X1>(begin, end, context);
}

template<>
const_scan_result find_pattern<scan_mode::SSE, scan_alignment::X16>(const std::byte* begin, const std::byte* end, const scan_context& context) {
return find_pattern_sse<scan_alignment::X16>(begin, end, context);
std::unreachable();
}
}
#endif

0 comments on commit 3ee98fd

Please sign in to comment.