
Commit 14225fa
added option to dynamically disable salt comparison during probing
gropaul committed Apr 2, 2024
1 parent e6b477a commit 14225fa
Showing 2 changed files with 62 additions and 29 deletions.
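Before the diff itself, a short orientation for reviewers: the change splits the probe into a templated GetRowPointersInternal<USE_SALTS>, and a new UseSalt() heuristic picks the instantiation once per probe call, so the salt comparison is compiled in or out rather than branched on per row. Below is a minimal, self-contained sketch of that pattern under simplified assumptions — Entry, ProbeOne and Probe are illustrative stand-ins, not DuckDB's actual types or API, and the 16-bit salt width is an assumption mirroring aggr_ht_entry_t::ExtractSalt.

#include <cstdint>
#include <vector>

// Illustrative stand-in only: the real code works on DuckDB's aggr_ht_entry_t,
// ProbeState and selection vectors. The upper 16 bits of an entry are assumed
// to hold the salt.
struct Entry {
	uint64_t value = 0;
	bool IsOccupied() const { return value != 0; }
	uint64_t GetSalt() const { return value >> 48; }
};

// USE_SALTS is a compile-time flag, so the salt comparison is removed entirely
// in the <false> instantiation instead of being branched on for every row.
template <bool USE_SALTS>
static uint64_t ProbeOne(const std::vector<Entry> &entries, uint64_t hash, uint64_t bitmask) {
	const uint64_t row_salt = hash >> 48;
	uint64_t ht_offset = hash & bitmask;
	while (true) {
		const Entry &entry = entries[ht_offset];
		if (!entry.IsOccupied()) {
			return ht_offset; // empty slot: the key is not in the table
		}
		if (!USE_SALTS || entry.GetSalt() == row_salt) {
			return ht_offset; // candidate slot: the full key comparison happens afterwards
		}
		ht_offset = (ht_offset + 1) & bitmask; // linear probing: advance and wrap
	}
}

// Runtime dispatch mirroring JoinHashTable::GetRowPointers/UseSalt: only apply
// the salt filter when the table is larger than the CPU cache and there is a
// single equality condition to check.
static uint64_t Probe(const std::vector<Entry> &entries, uint64_t hash, uint64_t bitmask,
                      uint64_t capacity, uint64_t equality_columns) {
	if (capacity > 8192 && equality_columns == 1) {
		return ProbeOne<true>(entries, hash, bitmask);
	}
	return ProbeOne<false>(entries, hash, bitmask);
}
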
87 changes: 59 additions & 28 deletions src/execution/join_hashtable.cpp
@@ -151,9 +151,13 @@ inline void IncrementAndWrap(idx_t &value, const uint64_t &capacity_mask) {
value &= capacity_mask;
}

void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_state, ProbeState &state, Vector &hashes_v,
const SelectionVector &sel, idx_t &count, Vector &pointers_result_v,
SelectionVector &match_sel) {
//! Gets a pointer to the entry in the HT for each of the hashes_v using linear probing. Will update the match_sel
//! vector and the count argument to the number and position of the matches
template <bool USE_SALTS>
static inline void
GetRowPointersInternal(DataChunk &keys, TupleDataChunkState &key_state, JoinHashTable::ProbeState &state,
Vector &hashes_v, const SelectionVector &sel, idx_t &count, JoinHashTable *ht,
aggr_ht_entry_t *entries, Vector &pointers_result_v, SelectionVector &match_sel) {

UnifiedVectorFormat hashes_v_unified;
hashes_v.ToUnifiedFormat(count, hashes_v_unified);
@@ -168,14 +172,15 @@ void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_sta

// first, filter out the empty rows and calculate the offset
for (idx_t i = 0; i < count; i++) {

const auto row_index = sel.get_index(i);
auto uvf_index = hashes_v_unified.sel->get_index(row_index);
auto ht_offset = hashes[uvf_index] & bitmask;
auto ht_offset = hashes[uvf_index] & ht->bitmask;
ht_offsets_dense[i] = ht_offset;
ht_offsets[row_index] = ht_offset;
}

// have a dense loop to have as few instructions as possible while producing cache misses as this is the
// first location where we access the big entries array
for (idx_t i = 0; i < count; i++) {
idx_t ht_offset = ht_offsets_dense[i];
auto &entry = entries[ht_offset];
@@ -185,14 +190,17 @@ void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_sta
}

for (idx_t i = 0; i < non_empty_count; i++) {
// transform the dense index to the actual index in the sel vector
idx_t dense_index = state.non_empty_sel.get_index(i);
const auto row_index = sel.get_index(dense_index);
state.non_empty_sel.set_index(i, row_index);

auto uvf_index = hashes_v_unified.sel->get_index(row_index);
auto hash = hashes[uvf_index];
hash_t row_salt = aggr_ht_entry_t::ExtractSalt(hash);
salts[row_index] = row_salt;
if (USE_SALTS) {
auto uvf_index = hashes_v_unified.sel->get_index(row_index);
auto hash = hashes[uvf_index];
hash_t row_salt = aggr_ht_entry_t::ExtractSalt(hash);
salts[row_index] = row_salt;
}
}

auto pointers_result = FlatVector::GetData<data_ptr_t>(pointers_result_v);
@@ -216,26 +224,30 @@ void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_sta
const auto row_index = remaining_sel->get_index(i);

idx_t &ht_offset = ht_offsets[row_index];
bool occupied;
aggr_ht_entry_t entry;

hash_t row_salt = salts[row_index];
if (USE_SALTS) {

bool occupied;
bool salt_match;
hash_t row_salt = salts[row_index];

aggr_ht_entry_t entry;
// increment the ht_offset of the entry as long as next entry is occupied and salt does not match
while (true) {
entry = entries[ht_offset];
occupied = entry.IsOccupied();
bool salt_match = entry.GetSalt() == row_salt;

// increment the ht_offset of the entry as long as next entry is occupied and salt does not match
while (true) {
entry = entries[ht_offset];
occupied = entry.IsOccupied();
salt_match = entry.GetSalt() == row_salt;
// condition for incrementing the ht_offset: occupied and row_salt does not match -> move to next
// entry
if (!occupied || salt_match) {
break;
}

// condition for incrementing the ht_offset: occupied and row_salt does not match -> move to next entry
if (!occupied || salt_match) {
break;
IncrementAndWrap(ht_offset, ht->bitmask);
}

IncrementAndWrap(ht_offset, bitmask);
} else {
entry = entries[ht_offset];
occupied = entry.IsOccupied();
}

// the entries we need to process in the next iteration are the ones that are occupied and the row_salt
@@ -252,9 +264,9 @@ void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_sta

// Perform row comparisons, after function call salt_match_sel will point to the keys that match

idx_t key_match_count =
row_matcher_build.Match(keys, key_state.vector_data, state.salt_match_sel, salt_match_count, layout,
state.row_ptr_insert_to_v, &state.key_no_match_sel, key_no_match_count);
idx_t key_match_count = ht->row_matcher_build.Match(keys, key_state.vector_data, state.salt_match_sel,
salt_match_count, ht->layout, state.row_ptr_insert_to_v,
&state.key_no_match_sel, key_no_match_count);

D_ASSERT(key_match_count + key_no_match_count == salt_match_count);

@@ -272,7 +284,7 @@ void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_sta
const auto row_index = state.key_no_match_sel.get_index(i);
auto &ht_offset = ht_offsets[row_index];

IncrementAndWrap(ht_offset, bitmask);
IncrementAndWrap(ht_offset, ht->bitmask);
}
}

@@ -281,6 +293,26 @@ void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_sta
}
}

inline bool JoinHashTable::UseSalt() const {
// only compare salts with the ht entries if the capacity is larger than 8192 so
// that it does not fit into the CPU cache and if there is only one equality condition as otherwise
// we potentially need to compare multiple keys
return this->capacity > 8192 && this->equality_predicate_columns.size() == 1;
}

void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_state, ProbeState &state, Vector &hashes_v,
const SelectionVector &sel, idx_t &count, Vector &pointers_result_v,
SelectionVector &match_sel) {

if (UseSalt()) {
GetRowPointersInternal<true>(keys, key_state, state, hashes_v, sel, count, this, entries, pointers_result_v,
match_sel);
} else {
GetRowPointersInternal<false>(keys, key_state, state, hashes_v, sel, count, this, entries, pointers_result_v,
match_sel);
}
}

void JoinHashTable::Hash(DataChunk &keys, const SelectionVector &sel, idx_t count, Vector &hashes) {
if (count == keys.size()) {
// no null values are filtered: use regular hash functions
@@ -494,7 +526,6 @@ static void InsertHashesLoop(atomic<aggr_ht_entry_t> entries[], Vector row_locat

idx_t &ht_offset_and_salt = ht_offsets_and_salt[row_index];
const hash_t salt = aggr_ht_entry_t::ExtractSalt(ht_offset_and_salt);
idx_t increment;
idx_t ht_offset;

bool occupied;
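One comment in the hunks above deserves a note: the "dense loop" that first touches the entries array exists purely to isolate the expensive cache misses — all offsets are computed first, then a tight loop does nothing but test IsOccupied() and collect the non-empty slots, and only afterwards are salts and keys compared. A rough sketch of that first pass, again with hypothetical stand-ins (Entry and FilterOccupied are illustrative names, not DuckDB's API):

#include <cstddef>
#include <cstdint>
#include <vector>

struct Entry {
	uint64_t value = 0;
	bool IsOccupied() const { return value != 0; }
};

// First pass over the (potentially huge) entries array: record which probes hit
// an occupied slot. The loop body stays minimal so the cache misses it produces
// are not interleaved with salt or key comparison work.
// non_empty_sel must be pre-sized to at least ht_offsets_dense.size().
static size_t FilterOccupied(const std::vector<Entry> &entries, const std::vector<uint64_t> &ht_offsets_dense,
                             std::vector<size_t> &non_empty_sel) {
	size_t non_empty_count = 0;
	for (size_t i = 0; i < ht_offsets_dense.size(); i++) {
		const Entry &entry = entries[ht_offsets_dense[i]];
		non_empty_sel[non_empty_count] = i;            // speculatively append i
		non_empty_count += entry.IsOccupied() ? 1 : 0; // keep it only if the slot is occupied
	}
	return non_empty_count;
}
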
4 changes: 3 additions & 1 deletion src/include/duckdb/execution/join_hashtable.hpp
@@ -265,8 +265,10 @@ class JoinHashTable {
const SelectionVector *&current_sel);
void Hash(DataChunk &keys, const SelectionVector &sel, idx_t count, Vector &hashes);

bool UseSalt() const;

//! Gets a pointer to the entry in the HT for each of the hashes_v using linear probing. Will update the match_sel
//! vectorand the count argument to the number and position of the matches
//! vector and the count argument to the number and position of the matches
void GetRowPointers(DataChunk &keys, TupleDataChunkState &key_state, ProbeState &state, Vector &hashes_v,
const SelectionVector &sel, idx_t &count, Vector &pointers_result_v,
SelectionVector &match_sel);
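As for why the salt filter is gated on table size: assuming the salt occupies the upper 16 bits of the hash (which is what aggr_ht_entry_t::ExtractSalt suggests), two distinct keys that collide in the same probe chain still agree on the salt with probability about 1/2^16 ≈ 0.0015%, so nearly every non-matching slot is rejected without dereferencing its row pointer. That only pays off when the dereference would actually miss the CPU cache, which is what the capacity > 8192 condition in UseSalt() approximates.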
