Skip to content

Commit

Permalink
Match iterator constness for find_pattern return value
Browse files Browse the repository at this point in the history
  • Loading branch information
ZeroMemes committed Jul 30, 2024
1 parent fc8cae2 commit 65d0d0b
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 32 deletions.
45 changes: 28 additions & 17 deletions include/libhat/Scanner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@

namespace hat {

class scan_result {
template<typename T> requires (std::is_pointer_v<T> && sizeof(std::remove_pointer_t<T>) == 1)
class scan_result_base {
using rel_t = int32_t;
public:
constexpr scan_result() : result(nullptr) {}
constexpr scan_result(std::nullptr_t) : result(nullptr) {} // NOLINT(google-explicit-constructor)
constexpr scan_result(const std::byte* result) : result(result) {} // NOLINT(google-explicit-constructor)
using underlying_type = T;

constexpr scan_result_base() : result(nullptr) {}
constexpr scan_result_base(std::nullptr_t) : result(nullptr) {} // NOLINT(google-explicit-constructor)
constexpr scan_result_base(T result) : result(result) {} // NOLINT(google-explicit-constructor)

/// Reads an integer of the specified type located at an offset from the signature result
template<std::integral Int>
Expand All @@ -32,25 +35,28 @@ namespace hat {
}

/// Resolve the relative address located at an offset from the signature result
[[nodiscard]] constexpr const std::byte* rel(size_t offset) const {
[[nodiscard]] constexpr T rel(size_t offset) const {
return this->has_result() ? this->result + this->read<rel_t>(offset) + offset + sizeof(rel_t) : nullptr;
}

[[nodiscard]] constexpr bool has_result() const {
return this->result != nullptr;
}

[[nodiscard]] constexpr const std::byte* operator*() const {
[[nodiscard]] constexpr T operator*() const {
return this->result;
}

[[nodiscard]] constexpr const std::byte* get() const {
[[nodiscard]] constexpr T get() const {
return this->result;
}
private:
const std::byte* result;
T result;
};

using scan_result = scan_result_base<std::byte*>;
using const_scan_result = scan_result_base<const std::byte*>;

enum class scan_alignment {
X1 = 1,
X16 = 16
Expand Down Expand Up @@ -122,13 +128,13 @@ namespace hat {
}

template<scan_mode, scan_alignment>
scan_result find_pattern(const scan_context&);
const_scan_result find_pattern(const scan_context&);

template<scan_alignment alignment>
scan_result find_pattern(const scan_context&);
const_scan_result find_pattern(const scan_context&);

template<>
inline constexpr scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X1>(const scan_context& context) {
inline constexpr const_scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X1>(const scan_context& context) {
auto [begin, end, signature, _] = context;
const auto firstByte = *signature[0];
const auto scanEnd = end - signature.size() + 1;
Expand All @@ -155,14 +161,14 @@ namespace hat {
}

template<>
inline scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X16>(const scan_context& context) {
inline constexpr const_scan_result find_pattern<scan_mode::FastFirst, scan_alignment::X16>(const scan_context& context) {
auto [begin, end, signature, _] = context;
const auto firstByte = *signature[0];

const auto scanBegin = next_boundary_align<scan_alignment::X16>(begin);
const auto scanEnd = prev_boundary_align<scan_alignment::X16>(end - signature.size() + 1);
if (scanBegin >= scanEnd) {
return {};
return nullptr;
}

for (auto i = scanBegin; i != scanEnd; i += 16) {
Expand Down Expand Up @@ -198,7 +204,7 @@ namespace hat {

/// Root implementation of find_pattern
template<scan_alignment alignment = scan_alignment::X1, detail::byte_iterator Iter>
constexpr scan_result find_pattern(
constexpr auto find_pattern(
Iter beginIt,
Iter endIt,
signature_view signature,
Expand All @@ -216,17 +222,22 @@ namespace hat {

const auto begin = std::to_address(beginIt) + offset;
const auto end = std::to_address(endIt);

using result_t = std::conditional_t<std::is_const_v<std::remove_pointer_t<decltype(begin)>>, const_scan_result, scan_result>;

if (begin >= end || signature.size() > static_cast<size_t>(std::distance(begin, end))) {
return nullptr;
return result_t{nullptr};
}

hat::scan_result result;
const_scan_result result;
if LIBHAT_IF_CONSTEVAL {
result = detail::find_pattern<detail::scan_mode::Single, alignment>({begin, end, signature, hints});
} else {
result = detail::find_pattern<alignment>({begin, end, signature, hints});
}
return result.has_result() ? result.get() - offset : nullptr;
return result.has_result()
? const_cast<typename result_t::underlying_type>(result.get() - offset)
: result_t{nullptr};
}
}

Expand Down
19 changes: 16 additions & 3 deletions src/Scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
namespace hat::detail {

template<scan_alignment alignment>
scan_result find_pattern(const scan_context& context) {
const_scan_result find_pattern(const scan_context& context) {
#if defined(LIBHAT_X86)
const auto& ext = get_system().extensions;
if (ext.bmi) {
Expand All @@ -29,6 +29,19 @@ namespace hat::detail {
return find_pattern<scan_mode::Single, alignment>(context);
}

template scan_result find_pattern<scan_alignment::X1>(const scan_context& context);
template scan_result find_pattern<scan_alignment::X16>(const scan_context& context);
template const_scan_result find_pattern<scan_alignment::X1>(const scan_context& context);
template const_scan_result find_pattern<scan_alignment::X16>(const scan_context& context);
}

// Validate return value const-ness for the root find_pattern impl
namespace hat {
static_assert(std::is_same_v<scan_result, decltype(find_pattern(
std::declval<std::byte*>(),
std::declval<std::byte*>(),
std::declval<signature_view>()))>);

static_assert(std::is_same_v<const_scan_result, decltype(find_pattern(
std::declval<const std::byte*>(),
std::declval<const std::byte*>(),
std::declval<signature_view>()))>);
}
8 changes: 4 additions & 4 deletions src/arch/x86/AVX2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace hat::detail {
}

template<scan_alignment alignment, bool cmpeq2, bool veccmp>
scan_result find_pattern_avx2(const scan_context& context) {
const_scan_result find_pattern_avx2(const scan_context& context) {
auto [begin, end, signature, hints] = context;

// 256 bit vector containing first signature byte repeated
Expand Down Expand Up @@ -94,7 +94,7 @@ namespace hat::detail {
}

template<scan_alignment alignment>
scan_result find_pattern_avx2(const scan_context& context) {
const_scan_result find_pattern_avx2(const scan_context& context) {
auto& signature = context.signature;
const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value();
const bool veccmp = signature.size() <= 33;
Expand All @@ -111,12 +111,12 @@ namespace hat::detail {
}

template<>
scan_result find_pattern<scan_mode::AVX2, scan_alignment::X1>(const scan_context& context) {
const_scan_result find_pattern<scan_mode::AVX2, scan_alignment::X1>(const scan_context& context) {
return find_pattern_avx2<scan_alignment::X1>(context);
}

template<>
scan_result find_pattern<scan_mode::AVX2, scan_alignment::X16>(const scan_context& context) {
const_scan_result find_pattern<scan_mode::AVX2, scan_alignment::X16>(const scan_context& context) {
return find_pattern_avx2<scan_alignment::X16>(context);
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/arch/x86/AVX512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace hat::detail {
}

template<scan_alignment alignment, bool cmpeq2, bool veccmp>
scan_result find_pattern_avx512(const scan_context& context) {
const_scan_result find_pattern_avx512(const scan_context& context) {
auto [begin, end, signature, hints] = context;
// 512 bit vector containing first signature byte repeated
const auto firstByte = _mm512_set1_epi8(static_cast<int8_t>(*signature[0]));
Expand Down Expand Up @@ -89,7 +89,7 @@ namespace hat::detail {
}

template<scan_alignment alignment>
scan_result find_pattern_avx512(const scan_context& context) {
const_scan_result find_pattern_avx512(const scan_context& context) {
auto& signature = context.signature;
const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value();
const bool veccmp = signature.size() <= 65;
Expand All @@ -106,12 +106,12 @@ namespace hat::detail {
}

template<>
scan_result find_pattern<scan_mode::AVX512, scan_alignment::X1>(const scan_context& context) {
const_scan_result find_pattern<scan_mode::AVX512, scan_alignment::X1>(const scan_context& context) {
return find_pattern_avx512<scan_alignment::X1>(context);
}

template<>
scan_result find_pattern<scan_mode::AVX512, scan_alignment::X16>(const scan_context& context) {
const_scan_result find_pattern<scan_mode::AVX512, scan_alignment::X16>(const scan_context& context) {
return find_pattern_avx512<scan_alignment::X16>(context);
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/arch/x86/SSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace hat::detail {
}

template<scan_alignment alignment, bool cmpeq2, bool veccmp>
scan_result find_pattern_sse(const scan_context& context) {
const_scan_result find_pattern_sse(const scan_context& context) {
auto [begin, end, signature, hints] = context;

// 256 bit vector containing first signature byte repeated
Expand Down Expand Up @@ -92,7 +92,7 @@ namespace hat::detail {
}

template<scan_alignment alignment>
scan_result find_pattern_sse(const scan_context& context) {
const_scan_result find_pattern_sse(const scan_context& context) {
auto& signature = context.signature;
const bool cmpeq2 = alignment == scan_alignment::X1 && signature.size() > 1 && signature[1].has_value();
const bool veccmp = signature.size() <= 17;
Expand All @@ -109,12 +109,12 @@ namespace hat::detail {
}

template<>
scan_result find_pattern<scan_mode::SSE, scan_alignment::X1>(const scan_context& context) {
const_scan_result find_pattern<scan_mode::SSE, scan_alignment::X1>(const scan_context& context) {
return find_pattern_sse<scan_alignment::X1>(context);
}

template<>
scan_result find_pattern<scan_mode::SSE, scan_alignment::X16>(const scan_context& context) {
const_scan_result find_pattern<scan_mode::SSE, scan_alignment::X16>(const scan_context& context) {
return find_pattern_sse<scan_alignment::X16>(context);
}
}
Expand Down

0 comments on commit 65d0d0b

Please sign in to comment.