Skip to content

Commit

Permalink
Implement scan hint optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
ZeroMemes committed Aug 2, 2024
1 parent e0e4801 commit cbb522a
Show file tree
Hide file tree
Showing 7 changed files with 254 additions and 89 deletions.
16 changes: 12 additions & 4 deletions include/libhat/Defines.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@
#if __cpp_lib_unreachable >= 202202L
#include <utility>
#define LIBHAT_UNREACHABLE() std::unreachable()
#elif defined(__GNUC__) || defined(__clang__)
#define LIBHAT_UNREACHABLE() __builtin_unreachable()
#elif defined(_MSC_VER)
#define LIBHAT_UNREACHABLE() __assume(false)
#elif defined(__GNUC__)
#define LIBHAT_UNREACHABLE() __builtin_unreachable()
#else
#include <cstdlib>
namespace hat::detail {
Expand All @@ -93,10 +93,10 @@

#if __has_cpp_attribute(assume)
#define LIBHAT_ASSUME(...) [[assume(__VA_ARGS__)]]
#elif defined(_MSC_VER)
#define LIBHAT_ASSUME(...) __assume(__VA_ARGS__)
#elif defined(__clang__)
#define LIBHAT_ASSUME(...) __builtin_assume(__VA_ARGS__)
#elif defined(_MSC_VER)
#define LIBHAT_ASSUME(...) __assume(__VA_ARGS__)
#else
#define LIBHAT_ASSUME(...) \
do { \
Expand All @@ -105,3 +105,11 @@
} \
} while (0)
#endif

#if defined(__GNUC__) || defined(__clang__)
#define LIBHAT_FORCEINLINE inline __attribute__((always_inline))
#elif defined(_MSC_VER)
#define LIBHAT_FORCEINLINE __forceinline
#else
#define LIBHAT_FORCEINLINE inline
#endif
69 changes: 54 additions & 15 deletions include/libhat/Scanner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ namespace hat {
};

enum class scan_hint : uint64_t {
none = 0, // no hints
x86_64 = 1 << 0, // The data being scanned is x86_64 machine code
none = 0, // no hints
x86_64 = 1 << 0, // The data being scanned is x86_64 machine code
pair0 = 1 << 1, // Only utilize byte pair based scanning if the signature starts with a byte pair
};

constexpr scan_hint operator|(scan_hint lhs, scan_hint rhs) {
Expand All @@ -85,27 +86,30 @@ namespace hat {

using scan_function_t = const_scan_result(*)(const std::byte* begin, const std::byte* end, const scan_context& context);

struct scanner_context {
size_t vectorSize{};
};

class scan_context {
public:
signature_view signature{};
scan_function_t scanner{};
scan_alignment alignment{};
size_t vectorSize{};
scan_hint hints{};
std::optional<size_t> pairIndex{};

[[nodiscard]] constexpr const_scan_result scan(const std::byte* begin, const std::byte* end) const {
return this->scanner(begin, end, *this);
}

void apply_hints();
void auto_resolve_scanner();
void apply_hints(const scanner_context&);

static constexpr scan_context create(signature_view signature, scan_alignment alignment, scan_hint hints);
private:
scan_context() = default;
};

[[nodiscard]] std::pair<scan_function_t, size_t> resolve_scanner(const scan_context&);

enum class scan_mode {
Single, // std::find + std::equal
SSE, // x86 SSE 4.1
Expand All @@ -117,7 +121,7 @@ namespace hat {
inline constexpr auto alignment_stride = static_cast<std::underlying_type_t<scan_alignment>>(alignment);

template<std::integral type, scan_alignment alignment>
inline consteval auto create_alignment_mask() {
LIBHAT_FORCEINLINE consteval auto create_alignment_mask() {
type mask{};
for (size_t i = 0; i < sizeof(type) * 8; i += alignment_stride<alignment>) {
mask |= (type(1) << i);
Expand All @@ -126,7 +130,7 @@ namespace hat {
}

template<scan_alignment alignment>
inline const std::byte* next_boundary_align(const std::byte* ptr) {
LIBHAT_FORCEINLINE const std::byte* next_boundary_align(const std::byte* ptr) {
constexpr auto stride = alignment_stride<alignment>;
if constexpr (stride == 1) {
return ptr;
Expand All @@ -137,17 +141,53 @@ namespace hat {
}

template<scan_alignment alignment>
inline const std::byte* prev_boundary_align(const std::byte* ptr) {
LIBHAT_FORCEINLINE const std::byte* prev_boundary_align(const std::byte* ptr) {
constexpr auto stride = alignment_stride<alignment>;
if constexpr (stride == 1) {
return ptr;
}
uintptr_t mod = reinterpret_cast<uintptr_t>(ptr) % stride;
const uintptr_t mod = reinterpret_cast<uintptr_t>(ptr) % stride;
return std::assume_aligned<stride>(ptr - mod);
}

template<typename Type>
LIBHAT_FORCEINLINE const std::byte* align_pointer_as(const std::byte* ptr) {
constexpr size_t alignment = alignof(Type);
const uintptr_t mod = reinterpret_cast<uintptr_t>(ptr) % alignment;
ptr += mod ? alignment - mod : 0;
return std::assume_aligned<alignment>(ptr);
}

template<typename Vector>
LIBHAT_FORCEINLINE auto segment_scan(
const std::byte* begin,
const std::byte* end,
const size_t signatureSize,
const size_t cmpOffset
) -> std::tuple<std::span<const std::byte>, std::span<const Vector>, std::span<const std::byte>> {
const auto preBegin = begin;
const auto vecBegin = reinterpret_cast<const Vector*>(align_pointer_as<Vector>(preBegin + cmpOffset));
const auto vecEnd = vecBegin + (static_cast<size_t>(end - reinterpret_cast<const std::byte*>(vecBegin)) - signatureSize) / sizeof(Vector);
const auto preEnd = reinterpret_cast<const std::byte*>(vecBegin) - cmpOffset + signatureSize;
const auto postBegin = reinterpret_cast<const std::byte*>(vecEnd);
const auto postEnd = end;

auto validateRange = [signatureSize](const std::byte* b, const std::byte* e) -> std::span<const std::byte> {
if (b <= e && static_cast<size_t>(e - b) >= signatureSize) {
return {b, e};
}
return {};
};

return {
validateRange(preBegin, preEnd),
std::span{vecBegin, vecEnd},
validateRange(postBegin, postEnd)
};
}

template<scan_mode>
scan_function_t get_scanner(const scan_context&);
scan_function_t resolve_scanner(scan_context&);

template<scan_alignment>
const_scan_result find_pattern_single(const std::byte* begin, const std::byte* end, const scan_context&);
Expand Down Expand Up @@ -205,7 +245,7 @@ namespace hat {
}

template<>
constexpr scan_function_t get_scanner<scan_mode::Single>(const scan_context& context) {
constexpr scan_function_t resolve_scanner<scan_mode::Single>(scan_context& context) {
switch (context.alignment) {
case scan_alignment::X1: return &find_pattern_single<scan_alignment::X1>;
case scan_alignment::X16: return &find_pattern_single<scan_alignment::X16>;
Expand Down Expand Up @@ -235,10 +275,9 @@ namespace hat {
ctx.alignment = alignment;
ctx.hints = hints;
if LIBHAT_IF_CONSTEVAL {
ctx.scanner = get_scanner<scan_mode::Single>(ctx);
ctx.scanner = resolve_scanner<scan_mode::Single>(ctx);
} else {
std::tie(ctx.scanner, ctx.vectorSize) = resolve_scanner(ctx);
ctx.apply_hints();
ctx.auto_resolve_scanner();
}
return ctx;
}
Expand Down
66 changes: 60 additions & 6 deletions src/Scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,85 @@
#include <libhat/Defines.hpp>
#include <libhat/System.hpp>

#include "arch/x86/Frequency.hpp"

namespace hat::detail {

void scan_context::apply_hints() {}
void scan_context::apply_hints(const scanner_context& scanner) {
const bool x86_64 = static_cast<bool>(this->hints & scan_hint::x86_64);
const bool pair0 = static_cast<bool>(this->hints & scan_hint::pair0);

if (x86_64 && !pair0 && scanner.vectorSize && this->alignment == hat::scan_alignment::X1) {
const auto get_score = [this](const std::byte a, const std::byte b) {
constexpr auto& pairs = hat::detail::x86_64::pairs_x1;
const auto it = std::ranges::find(pairs, std::pair{a, b});
return it == pairs.end() ? pairs.size() : pairs.size() - static_cast<size_t>(it - pairs.begin()) - 1;
};

const auto score_pair = [&](auto&& tup) {
auto [a, b] = std::get<1>(tup);
return std::make_tuple(std::get<0>(tup), get_score(a.value(), b.value()));
};

static constexpr auto is_complete_pair = [](auto&& tup) {
auto [a, b] = std::get<1>(tup);
return a.has_value() && b.has_value();
};

auto valid_pairs = this->signature
| std::views::take(scanner.vectorSize)
| std::views::adjacent<2>
| std::views::enumerate
| std::views::filter(is_complete_pair)
| std::views::transform(score_pair);

if (!valid_pairs.empty()) {
this->pairIndex = std::get<0>(std::ranges::max(valid_pairs, std::ranges::less{}, [](auto&& tup) {
return std::get<1>(tup);
}));
}
}

// If no "optimal" pair was found, find the first byte pair in the signature
if (!this->pairIndex.has_value()) {
size_t i{};
for (auto&& [a, b] : this->signature | std::views::adjacent<2>) {
if (a.has_value() && b.has_value()) {
this->pairIndex = i;
break;
}
if (i == 0 && pair0) {
break;
}
i++;
}
}
}

std::pair<scan_function_t, size_t> resolve_scanner(const scan_context& context) {
void scan_context::auto_resolve_scanner() {
#if defined(LIBHAT_X86)
const auto& ext = get_system().extensions;
if (ext.bmi) {
#if !defined(LIBHAT_DISABLE_AVX512)
if (ext.avx512f && ext.avx512bw) {
return {get_scanner<scan_mode::AVX512>(context), 64};
this->scanner = resolve_scanner<scan_mode::AVX512>(*this);
return;
}
#endif
if (ext.avx2) {
return {get_scanner<scan_mode::AVX2>(context), 32};
this->scanner = resolve_scanner<scan_mode::AVX2>(*this);
return;
}
}
#if !defined(LIBHAT_DISABLE_SSE)
if (ext.sse41) {
return {get_scanner<scan_mode::SSE>(context), 16};
this->scanner = resolve_scanner<scan_mode::SSE>(*this);
return;
}
#endif
#endif
// If none of the vectorized implementations are available/supported, then fallback to scanning per-byte
return {get_scanner<scan_mode::Single>(context), 0};
this->scanner = resolve_scanner<scan_mode::Single>(*this);
}
}

Expand Down
52 changes: 31 additions & 21 deletions src/arch/x86/AVX2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@

namespace hat::detail {

inline auto load_signature_256(signature_view signature) {
inline auto load_signature_256(const signature_view signature) {
std::byte byteBuffer[32]{}; // The remaining signature bytes
std::byte maskBuffer[32]{}; // A bitmask for the signature bytes we care about
for (size_t i = 1; i < signature.size(); i++) {
for (size_t i = 0; i < signature.size(); i++) {
auto e = signature[i];
if (e.has_value()) {
byteBuffer[i - 1] = *e;
maskBuffer[i - 1] = std::byte{0xFFu};
byteBuffer[i] = *e;
maskBuffer[i] = std::byte{0xFFu};
}
}
return std::make_tuple(
Expand All @@ -28,13 +28,15 @@ namespace hat::detail {
template<scan_alignment alignment, bool cmpeq2, bool veccmp>
const_scan_result find_pattern_avx2(const std::byte* begin, const std::byte* end, const scan_context& context) {
const auto signature = context.signature;
const auto cmpIndex = cmpeq2 ? *context.pairIndex : 0;
LIBHAT_ASSUME(cmpIndex < 32);

// 256 bit vector containing first signature byte repeated
const auto firstByte = _mm256_set1_epi8(static_cast<int8_t>(*signature[0]));
const auto firstByte = _mm256_set1_epi8(static_cast<int8_t>(*signature[cmpIndex]));

__m256i secondByte;
if constexpr (cmpeq2) {
secondByte = _mm256_set1_epi8(static_cast<int8_t>(*signature[1]));
secondByte = _mm256_set1_epi8(static_cast<int8_t>(*signature[cmpIndex + 1]));
}

__m256i signatureBytes, signatureMask;
Expand All @@ -47,19 +49,24 @@ namespace hat::detail {
return {};
}

auto vec = reinterpret_cast<const __m256i*>(begin);
const auto n = static_cast<size_t>(end - signature.size() - begin) / sizeof(__m256i);
const auto e = vec + n;
auto [pre, vec, post] = segment_scan<__m256i>(begin, end, signature.size(), cmpIndex);

for (; vec != e; vec++) {
const auto cmp = _mm256_cmpeq_epi8(firstByte, _mm256_loadu_si256(vec));
if (!pre.empty()) {
const auto result = find_pattern_single<alignment>(pre.data(), pre.data() + pre.size(), context);
if (result.has_result()) {
return result;
}
}

for (auto& it : vec) {
const auto cmp = _mm256_cmpeq_epi8(firstByte, _mm256_loadu_si256(&it));
auto mask = static_cast<uint32_t>(_mm256_movemask_epi8(cmp));

if constexpr (alignment != scan_alignment::X1) {
mask &= create_alignment_mask<uint32_t, alignment>();
if (!mask) continue;
} else if constexpr (cmpeq2) {
const auto cmp2 = _mm256_cmpeq_epi8(secondByte, _mm256_loadu_si256(vec));
const auto cmp2 = _mm256_cmpeq_epi8(secondByte, _mm256_loadu_si256(&it));
auto mask2 = static_cast<uint32_t>(_mm256_movemask_epi8(cmp2));
// avoid loading unaligned memory by letting a match of the first signature byte in the last
// position imply that the second byte also matched
Expand All @@ -68,16 +75,16 @@ namespace hat::detail {

while (mask) {
const auto offset = _tzcnt_u32(mask);
const auto i = reinterpret_cast<const std::byte*>(vec) + offset;
const auto i = reinterpret_cast<const std::byte*>(&it) + offset - cmpIndex;
if constexpr (veccmp) {
const auto data = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(i + 1));
const auto data = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(i));
const auto cmpToSig = _mm256_cmpeq_epi8(signatureBytes, data);
const auto matched = _mm256_testc_si256(cmpToSig, signatureMask);
if (matched) LIBHAT_UNLIKELY {
return i;
}
} else {
auto match = std::equal(signature.begin() + 1, signature.end(), i + 1, [](auto opt, auto byte) {
auto match = std::equal(signature.begin(), signature.end(), i, [](auto opt, auto byte) {
return !opt.has_value() || *opt == byte;
});
if (match) LIBHAT_UNLIKELY {
Expand All @@ -88,19 +95,22 @@ namespace hat::detail {
}
}

// Look in remaining bytes that couldn't be grouped into 256 bits
begin = reinterpret_cast<const std::byte*>(vec);
return find_pattern_single<alignment>(begin, end, context);
if (!post.empty()) {
return find_pattern_single<alignment>(post.data(), post.data() + post.size(), context);
}
return {};
}

template<>
scan_function_t get_scanner<scan_mode::AVX2>(const scan_context& context) {
scan_function_t resolve_scanner<scan_mode::AVX2>(scan_context& context) {
context.apply_hints({.vectorSize = 32});

const auto alignment = context.alignment;
const auto signature = context.signature;
const bool veccmp = signature.size() <= 33;
const bool veccmp = signature.size() <= 32;

if (alignment == scan_alignment::X1) {
const bool cmpeq2 = signature.size() > 1 && signature[1].has_value();
const bool cmpeq2 = context.pairIndex.has_value();
if (cmpeq2 && veccmp) {
return &find_pattern_avx2<scan_alignment::X1, true, true>;
} else if (cmpeq2) {
Expand Down
Loading

0 comments on commit cbb522a

Please sign in to comment.