Mark Hashes in RowLayout to mitigate nullptr write #9

Open · wants to merge 6 commits into base: main
133 changes: 104 additions & 29 deletions src/execution/join_hashtable.cpp
@@ -101,8 +101,11 @@ JoinHashTable::JoinHashTable(ClientContext &context, const vector<JoinCondition>
sink_collection =
make_uniq<RadixPartitionedTupleData>(buffer_manager, layout, radix_bits, layout.ColumnCount() - 1);

// create one single entry that serves as a dead end and set its next pointer to an invalid (marked) value
dead_end = make_unsafe_uniq_array_uninitialized<data_t>(layout.GetRowWidth());
memset(dead_end.get(), 0, layout.GetRowWidth());
// todo: no magic number! (this is the most significant bit, the hash mark / collision bit)
Store<hash_t>(0x8000000000000000, dead_end.get() + pointer_offset);

if (join_type == JoinType::SINGLE) {
auto &config = ClientConfig::GetConfig(context);
@@ -222,19 +225,23 @@ static inline void GetRowPointersInternal(DataChunk &keys, TupleDataChunkState &
const auto row_index = remaining_sel->get_index(i);

idx_t &ht_offset = ht_offsets[row_index];

bool occupied;
bool salt_match;

ht_entry_t entry;

if (USE_SALTS) {
hash_t row_salt = salts[row_index];
// keep incrementing the ht_offset as long as the entry is occupied and the salt does not match
while (true) {
entry = entries[ht_offset];

occupied = entry.IsOccupied();
salt_match = entry.GetSalt() == row_salt;

// condition for incrementing the ht_offset: the entry is occupied, the salt does not match, and the
// entry has the collision bit set -> reverse this condition to break out of the loop
if (!occupied || salt_match) {
break;
}
@@ -249,11 +256,18 @@ static inline void GetRowPointersInternal(DataChunk &keys, TupleDataChunkState &
// the entries we need to process in the next iteration are the ones that are occupied and whose
// row_salt does not match; the ones that are empty need no further processing
state.salt_match_sel.set_index(salt_match_count, row_index);

if (USE_SALTS) {
// here we stopped probing because either
// (a) we found an empty entry -> no result
// (b) we found a matching salt -> we need to compare the keys
// (c) there was no salt match and no collision -> no result
row_ptr_insert_to[row_index] = entry.GetPointerOrNull();
salt_match_count += occupied && salt_match;
} else {
row_ptr_insert_to[row_index] = entry.GetPointerOrNull();
salt_match_count += occupied;
}
}
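Aside: the salted probing above, reduced to a minimal stand-alone sketch. Plain structs and an int key stand in for ht_entry_t and row pointers, and the salt-in-the-top-16-bits layout is an assumption here; unlike the vectorized code above, which collects salt matches in salt_match_sel for a separate key-comparison pass, this sketch folds the key comparison directly into the loop:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

struct Entry {
    bool occupied = false;
    uint16_t salt = 0;
    int key = 0; // stands in for the row payload behind the pointer
};

// probe until we hit an empty slot (key absent) or a full match
static size_t Probe(const std::vector<Entry> &ht, uint64_t hash, int key) {
    const size_t mask = ht.size() - 1;                   // capacity is a power of two
    const auto salt = static_cast<uint16_t>(hash >> 48); // assumed: top 16 bits as salt
    size_t pos = hash & mask;
    while (true) {
        const Entry &e = ht[pos];
        // cheap salt gate first; the key is only compared on a salt match
        if (!e.occupied || (e.salt == salt && e.key == key)) {
            return pos;
        }
        pos = (pos + 1) & mask; // occupied and salt mismatch: linear probing
    }
}

int main() {
    std::vector<Entry> ht(8);
    const uint64_t h = 0xABCD000000000003ULL; // lands in slot 3, salt 0xABCD
    ht[3] = {true, 0xABCD, 42};
    std::cout << "found at slot " << Probe(ht, h, 42) << "\n"; // -> 3
}
```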

if (salt_match_count != 0) {
@@ -305,19 +319,55 @@ void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_sta
}
}

template <bool HAS_SEL>
static inline void MarkHashes(Vector &hashes_v, const SelectionVector *sel, idx_t count) {
// hashes can be either a flat vector or a constant vector after being processed by VectorOperations::Hash
if (hashes_v.GetVectorType() == VectorType::CONSTANT_VECTOR) {
auto constant_hash = ConstantVector::GetData<hash_t>(hashes_v);
*constant_hash |= ht_entry_t::COLLISION_BIT_MASK;
} else {
D_ASSERT(hashes_v.GetVectorType() == VectorType::FLAT_VECTOR);
auto hashes = FlatVector::GetData<hash_t>(hashes_v);

for (idx_t i = 0; i < count; i++) {
auto idx = HAS_SEL ? sel->get_index(i) : i;
hashes[idx] |= ht_entry_t::COLLISION_BIT_MASK;
}
}
}
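For illustration, here is the same HAS_SEL dispatch on plain arrays; kCollisionBit stands in for ht_entry_t::COLLISION_BIT_MASK, assumed here to be the most significant bit. The compile-time flag lets the compiler drop the selection-vector indirection from the loop entirely when it is not needed:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>

using hash_t = uint64_t;
constexpr hash_t kCollisionBit = 1ULL << 63; // stand-in for COLLISION_BIT_MASK

template <bool HAS_SEL>
void MarkHashes(hash_t *hashes, const uint32_t *sel, size_t count) {
    for (size_t i = 0; i < count; i++) {
        const size_t idx = HAS_SEL ? sel[i] : i; // constant-folded per instantiation
        hashes[idx] |= kCollisionBit;
    }
}

int main() {
    hash_t hashes[4] = {0x1, 0x2, 0x3, 0x4};
    const uint32_t sel[2] = {0, 2};
    MarkHashes<true>(hashes, sel, 2);      // marks rows 0 and 2 only
    MarkHashes<false>(hashes, nullptr, 4); // marks all rows; OR-ing twice is harmless
    for (hash_t h : hashes) {
        std::cout << std::hex << h << "\n"; // every value now has the top bit set
    }
}
```

JoinHashTable::Hash below plays the same trick with its mark_hashes flag, converting the runtime bool into the MARK_HASHES template parameter once, outside the hot loop.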

template <bool MARK_HASHES>
static inline void HashInternal(DataChunk &keys, const SelectionVector &sel, idx_t count, Vector &hashes,
idx_t n_equality_types) {
if (count == keys.size()) {
// no null values are filtered: use regular hash functions
VectorOperations::Hash(keys.data[0], hashes, keys.size());
for (idx_t i = 1; i < n_equality_types; i++) {
VectorOperations::CombineHash(hashes, keys.data[i], keys.size());
}

if (MARK_HASHES) {
MarkHashes<false>(hashes, nullptr, count);
}
} else {
// null values were filtered: use selection vector
VectorOperations::Hash(keys.data[0], hashes, sel, count);
for (idx_t i = 1; i < n_equality_types; i++) {
VectorOperations::CombineHash(hashes, keys.data[i], sel, count);
}

if (MARK_HASHES) {
MarkHashes<true>(hashes, &sel, count);
}
}
}

void JoinHashTable::Hash(DataChunk &keys, const SelectionVector &sel, idx_t count, Vector &hashes, bool mark_hashes) {
if (mark_hashes) {
HashInternal<true>(keys, sel, count, hashes, equality_types.size());
} else {
HashInternal<false>(keys, sel, count, hashes, equality_types.size());
}
}
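A rough sketch of the per-row pipeline this implements, with std::hash and a generic mixer as stand-ins (DuckDB's VectorOperations::Hash and CombineHash are vectorized and use different mixing internally): hash the first equality column, fold in the remaining ones, and set the mark bit last:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

constexpr uint64_t kCollisionBit = 1ULL << 63; // assumed mark bit (MSB)

// illustrative mixer, not DuckDB's CombineHash formula
uint64_t CombineHash(uint64_t h, uint64_t other) {
    return h ^ (other + 0x9e3779b97f4a7c15ULL + (h << 6) + (h >> 2));
}

int main() {
    const std::vector<uint64_t> row = {7, 11, 13}; // three equality columns
    uint64_t h = std::hash<uint64_t>{}(row[0]);    // Hash on the first column
    for (size_t i = 1; i < row.size(); i++) {
        h = CombineHash(h, std::hash<uint64_t>{}(row[i])); // CombineHash on the rest
    }
    h |= kCollisionBit; // mark_hashes == true: reserve the top bit for chain marking
    std::cout << std::hex << h << "\n";
}
```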

@@ -399,7 +449,7 @@ void JoinHashTable::Build(PartitionedTupleDataAppendState &append_state, DataChu

// hash the keys and obtain an entry in the list
// note that we only hash the keys used in the equality comparison
Hash(keys, *current_sel, added_count, hash_values, true);

// Re-reference and ToUnifiedFormat the hash column after computing it
source_chunk.data[col_offset].Reference(hash_values);
@@ -435,6 +485,9 @@ idx_t JoinHashTable::PrepareKeys(DataChunk &keys, vector<TupleDataVectorFormat>
}

static void StorePointer(const_data_ptr_t pointer, data_ptr_t target) {
// the pointer can no longer be a nullptr: chain ends are now marked via the hash mark instead of a
// nullptr, so every pointer stored here is a valid row pointer
D_ASSERT(pointer != nullptr);
Store<uint64_t>(cast_pointer_to_uint64(pointer), target);
}

@@ -453,10 +506,8 @@ static inline data_ptr_t InsertRowToEntry(atomic<ht_entry_t> &entry, const data_
// if we expect the entry to be empty, if the operation fails we need to cancel the whole operation as another
// key might have been inserted in the meantime that does not match the current key
if (EXPECT_EMPTY) {
ht_entry_t new_empty_entry = ht_entry_t::GetNewEntry(row_ptr_to_insert, salt);
ht_entry_t expected_empty_entry = ht_entry_t::GetEmptyEntry();
entry.compare_exchange_strong(expected_empty_entry, new_empty_entry, std::memory_order_acquire,
std::memory_order_relaxed);
@@ -469,23 +520,31 @@ static inline data_ptr_t InsertRowToEntry(atomic<ht_entry_t> &entry, const data_
// if we expect the entry to be full, we know that even if the insert fails the keys still match so we can
// just keep trying until we succeed
ht_entry_t expected_current_entry = entry.load(std::memory_order_relaxed);
D_ASSERT(expected_current_entry.IsOccupied());

ht_entry_t desired_updated_entry = ht_entry_t::UpdateWithPointer(expected_current_entry, row_ptr_to_insert);

do {
data_ptr_t current_row_pointer = expected_current_entry.GetPointer();
StorePointer(current_row_pointer, row_ptr_to_insert + pointer_offset);
} while (!entry.compare_exchange_weak(expected_current_entry, desired_updated_entry,
std::memory_order_release, std::memory_order_relaxed));

return nullptr;
}
} else {
// if we are not in parallel mode, we can just do the operation without any checks
ht_entry_t current_entry = entry.load(std::memory_order_relaxed);

if (EXPECT_EMPTY) {
entry = ht_entry_t::GetNewEntry(row_ptr_to_insert, salt);
} else {
data_ptr_t current_row_pointer = current_entry.GetPointer();
StorePointer(current_row_pointer, row_ptr_to_insert + pointer_offset);

entry = ht_entry_t::UpdateWithPointer(current_entry, row_ptr_to_insert);
}
return nullptr;
}
}
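The EXPECT_EMPTY == false branch above is a classic lock-free list prepend. A minimal version on a plain std::atomic<Node *> head (no salts, no ht_entry_t packing): a failed CAS only means another thread prepended first, and because the keys are known to match, re-linking against the new head and retrying is always safe:

```cpp
#include <atomic>
#include <iostream>
#include <thread>

struct Node {
    int value;
    Node *next;
};

void Prepend(std::atomic<Node *> &head, Node *row) {
    Node *expected = head.load(std::memory_order_relaxed);
    do {
        row->next = expected; // re-link on every retry, like the StorePointer in the loop
    } while (!head.compare_exchange_weak(expected, row, std::memory_order_release,
                                         std::memory_order_relaxed));
}

int main() {
    Node tail{0, nullptr};
    std::atomic<Node *> head{&tail};
    Node a{1, nullptr}, b{2, nullptr};
    std::thread t1(Prepend, std::ref(head), &a);
    std::thread t2(Prepend, std::ref(head), &b);
    t1.join();
    t2.join();
    for (Node *n = head.load(); n != nullptr; n = n->next) {
        std::cout << n->value << "\n"; // 1 and 2 in either order, then 0
    }
}
```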
@@ -520,9 +579,9 @@ static inline void InsertMatchesAndIncrementMisses(atomic<ht_entry_t> entries[],
JoinHashTable &ht, const data_ptr_t lhs_row_locations[],
idx_t ht_offsets_and_salts[], const idx_t capacity_mask,
const idx_t key_match_count, const idx_t key_no_match_count) {
// mark the ht as having chains longer than one if we have a key match
ht.chains_longer_than_one = ht.chains_longer_than_one || key_match_count != 0;

// Insert the rows that match
for (idx_t i = 0; i < key_match_count; i++) {
@@ -537,21 +596,24 @@
InsertRowToEntry<PARALLEL, false>(entry, row_ptr_to_insert, salt, ht.pointer_offset);
}

// Linear probing: each of the entries that do not match moves to the next entry in the HT; we also
// mark them with the collision bit
for (idx_t i = 0; i < key_no_match_count; i++) {
const auto need_compare_idx = state.key_no_match_sel.get_index(i);
const auto entry_index = state.salt_match_sel.get_index(need_compare_idx);

// increment the ht_offset of the entry
idx_t &ht_offset_and_salt = ht_offsets_and_salts[entry_index];
IncrementAndWrap(ht_offset_and_salt, capacity_mask);

// add the entry to the remaining sel vector to get processed in the next loop iteration
state.remaining_sel.set_index(i, entry_index);
}
}

template <bool PARALLEL>
static void InsertHashesLoop(atomic<ht_entry_t> entries[], Vector &row_locations, Vector &hashes_v, const idx_t &count,
JoinHashTable::InsertState &state, TupleDataCollection &data_collection,
JoinHashTable &ht) {
D_ASSERT(hashes_v.GetType().id() == LogicalType::HASH);
ApplyBitmaskAndGetSaltBuild(hashes_v, count, ht.bitmask);
@@ -722,6 +784,8 @@ void JoinHashTable::Finalize(idx_t chunk_idx_from, idx_t chunk_idx_to, bool para
const auto count = iterator.GetCurrentChunkCount();
for (idx_t i = 0; i < count; i++) {
hash_data[i] = Load<hash_t>(row_locations[i] + pointer_offset);
// the hashes must have the HASK_MARK_MASK bit set to 1 to distinguish them from pointers later
D_ASSERT((hash_data[i] & ht_entry_t::HASK_MARK_MASK) == ht_entry_t::HASK_MARK_MASK);
}
TupleDataChunkState &chunk_state = iterator.GetChunkState();

@@ -760,7 +824,7 @@ void JoinHashTable::Probe(ScanStructure &scan_structure, DataChunk &keys, TupleD
} else {
Vector hashes(LogicalType::HASH);
// hash all the keys
Hash(keys, *current_sel, scan_structure.count, hashes, false);

// now initialize the pointers of the scan structure based on the hashes
GetRowPointers(keys, key_state, probe_state, hashes, *current_sel, scan_structure.count,
@@ -862,6 +926,16 @@ idx_t ScanStructure::ScanInnerJoin(DataChunk &keys, SelectionVector &result_vect
}
}

static inline bool IsNextPointer(const data_ptr_t ptr) {
// the chain has a next pointer if the hash that was stored in the slot has been overwritten with a
// next pointer. All hashes have been marked with the most significant bit set to 1, so we can check
// whether this bit is set to determine if the value is a next pointer or a (marked) hash
const hash_t masked = cast_pointer_to_uint64(ptr) & ht_entry_t::COLLISION_BIT_MASK;

// if the bit is not set, then the value is a pointer to the next entry
return masked == 0;
}
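Putting the pieces together in a stand-alone sketch (illustrative row layout, not DuckDB's): during the build, each row's next slot initially holds the row's own marked hash (this is what the Finalize assertion above checks), and the slot is only overwritten with a real pointer once the chain grows, so a set top bit terminates the walk without any nullptr. This assumes, as the marking scheme does, that user-space pointers never have the most significant bit set:

```cpp
#include <cstdint>
#include <iostream>

constexpr uint64_t kCollisionBit = 1ULL << 63; // assumed mark bit (MSB)

struct Row {
    int key;
    uint64_t next_or_hash; // Row* if the top bit is clear, marked hash otherwise
};

bool IsNextPointer(uint64_t slot) {
    return (slot & kCollisionBit) == 0; // top bit clear -> real pointer
}

int main() {
    Row tail{2, 0x1234 | kCollisionBit};            // end of chain: still holds its marked hash
    Row head{1, reinterpret_cast<uint64_t>(&tail)}; // hash was overwritten by a next pointer
    for (Row *r = &head;;) {
        std::cout << "key=" << r->key << "\n";
        if (!IsNextPointer(r->next_or_hash)) {
            break; // marked hash reached: chain ends, no nullptr dereferenced
        }
        r = reinterpret_cast<Row *>(r->next_or_hash);
    }
}
```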

void ScanStructure::AdvancePointers(const SelectionVector &sel, const idx_t sel_count) {

if (!ht.chains_longer_than_one) {
@@ -875,7 +949,7 @@ void ScanStructure::AdvancePointers(const SelectionVector &sel, const idx_t sel_
for (idx_t i = 0; i < sel_count; i++) {
auto idx = sel.get_index(i);
ptrs[idx] = LoadPointer(ptrs[idx] + ht.pointer_offset);
if (IsNextPointer(ptrs[idx])) {
this->sel_vector.set_index(new_count++, idx);
}
}
@@ -1011,6 +1085,7 @@ void ScanStructure::NextRightSemiOrAntiJoin(DataChunk &keys) {
for (idx_t i = 0; i < result_count; i++) {
const auto idx = chain_match_sel_vector.get_index(i);
auto &ptr = ptrs[idx];

if (Load<bool>(ptr + ht.tuple_size)) { // Early out: chain has been fully marked as found before
ptr = ht.dead_end.get();
continue;
@@ -1022,7 +1097,7 @@
// Technically it is, but it does not matter, since the only value that can be written is "true"
Store<bool>(true, ptr + ht.tuple_size);
auto next_ptr = LoadPointer(ptr + ht.pointer_offset);
if (!IsNextPointer(next_ptr)) {
break;
}
ptr = next_ptr;
@@ -1486,7 +1561,7 @@ void JoinHashTable::ProbeAndSpill(ScanStructure &scan_structure, DataChunk &keys
ProbeSpillLocalAppendState &spill_state, DataChunk &spill_chunk) {
// hash all the keys
Vector hashes(LogicalType::HASH);
Hash(keys, *FlatVector::IncrementalSelectionVector(), keys.size(), hashes, false);

// find out which keys we can match with the current pinned partitions
SelectionVector true_sel;
10 changes: 10 additions & 0 deletions src/include/duckdb/common/types/hash.hpp
@@ -20,7 +20,17 @@ struct interval_t; // NOLINT
// bias
// see: https://nullprogram.com/blog/2018/07/31/

inline hash_t TempMod10(uint64_t x) {
// deliberately weak temporary hash: only ten distinct bucket values, which forces long collision
// chains for testing (note the commented-out call in MurmurHash64 below)
uint64_t modulo = x % 10;
uint64_t hash = modulo;
// also add a salt: the low 16 bits of x shifted into the upper bits (bit 48 and above)
uint64_t salt = x << 48;

return hash + salt;
}

inline hash_t MurmurHash64(uint64_t x) {
// return TempMod10(x);
x ^= x >> 32;
x *= 0xd6e8feb86659fd93U;
x ^= x >> 32;
Expand Down