From 14225fae964310837cafb62a033b452dadc40c7e Mon Sep 17 00:00:00 2001 From: PGross Date: Sat, 30 Mar 2024 11:38:07 +0100 Subject: [PATCH] added option to dynamically disable salt comparison during probing --- src/execution/join_hashtable.cpp | 87 +++++++++++++------ .../duckdb/execution/join_hashtable.hpp | 4 +- 2 files changed, 62 insertions(+), 29 deletions(-) diff --git a/src/execution/join_hashtable.cpp b/src/execution/join_hashtable.cpp index a3965e59707d..f203d5e897fc 100644 --- a/src/execution/join_hashtable.cpp +++ b/src/execution/join_hashtable.cpp @@ -151,9 +151,13 @@ inline void IncrementAndWrap(idx_t &value, const uint64_t &capacity_mask) { value &= capacity_mask; } -void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_state, ProbeState &state, Vector &hashes_v, - const SelectionVector &sel, idx_t &count, Vector &pointers_result_v, - SelectionVector &match_sel) { +//! Gets a pointer to the entry in the HT for each of the hashes_v using linear probing. Will update the match_sel +//! vector and the count argument to the number and position of the matches +template +static inline void +GetRowPointersInternal(DataChunk &keys, TupleDataChunkState &key_state, JoinHashTable::ProbeState &state, + Vector &hashes_v, const SelectionVector &sel, idx_t &count, JoinHashTable *ht, + aggr_ht_entry_t *entries, Vector &pointers_result_v, SelectionVector &match_sel) { UnifiedVectorFormat hashes_v_unified; hashes_v.ToUnifiedFormat(count, hashes_v_unified); @@ -168,14 +172,15 @@ void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_sta // first, filter out the empty rows and calculate the offset for (idx_t i = 0; i < count; i++) { - const auto row_index = sel.get_index(i); auto uvf_index = hashes_v_unified.sel->get_index(row_index); - auto ht_offset = hashes[uvf_index] & bitmask; + auto ht_offset = hashes[uvf_index] & ht->bitmask; ht_offsets_dense[i] = ht_offset; ht_offsets[row_index] = ht_offset; } + // have a dense loop to have as few instructions as possible while producing cache misses as this is the + // first location where we access the big entries array for (idx_t i = 0; i < count; i++) { idx_t ht_offset = ht_offsets_dense[i]; auto &entry = entries[ht_offset]; @@ -185,14 +190,17 @@ void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_sta } for (idx_t i = 0; i < non_empty_count; i++) { + // transform the dense index to the actual index in the sel vector idx_t dense_index = state.non_empty_sel.get_index(i); const auto row_index = sel.get_index(dense_index); state.non_empty_sel.set_index(i, row_index); - auto uvf_index = hashes_v_unified.sel->get_index(row_index); - auto hash = hashes[uvf_index]; - hash_t row_salt = aggr_ht_entry_t::ExtractSalt(hash); - salts[row_index] = row_salt; + if (USE_SALTS) { + auto uvf_index = hashes_v_unified.sel->get_index(row_index); + auto hash = hashes[uvf_index]; + hash_t row_salt = aggr_ht_entry_t::ExtractSalt(hash); + salts[row_index] = row_salt; + } } auto pointers_result = FlatVector::GetData(pointers_result_v); @@ -216,26 +224,30 @@ void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_sta const auto row_index = remaining_sel->get_index(i); idx_t &ht_offset = ht_offsets[row_index]; + bool occupied; + aggr_ht_entry_t entry; - hash_t row_salt = salts[row_index]; + if (USE_SALTS) { - bool occupied; - bool salt_match; + hash_t row_salt = salts[row_index]; - aggr_ht_entry_t entry; + // increment the ht_offset of the entry as long as next entry is occupied and salt does not match + while (true) { + entry = entries[ht_offset]; + occupied = entry.IsOccupied(); + bool salt_match = entry.GetSalt() == row_salt; - // increment the ht_offset of the entry as long as next entry is occupied and salt does not match - while (true) { - entry = entries[ht_offset]; - occupied = entry.IsOccupied(); - salt_match = entry.GetSalt() == row_salt; + // condition for incrementing the ht_offset: occupied and row_salt does not match -> move to next + // entry + if (!occupied || salt_match) { + break; + } - // condition for incrementing the ht_offset: occupied and row_salt does not match -> move to next entry - if (!occupied || salt_match) { - break; + IncrementAndWrap(ht_offset, ht->bitmask); } - - IncrementAndWrap(ht_offset, bitmask); + } else { + entry = entries[ht_offset]; + occupied = entry.IsOccupied(); } // the entries we need to process in the next iteration are the ones that are occupied and the row_salt @@ -252,9 +264,9 @@ void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_sta // Perform row comparisons, after function call salt_match_sel will point to the keys that match - idx_t key_match_count = - row_matcher_build.Match(keys, key_state.vector_data, state.salt_match_sel, salt_match_count, layout, - state.row_ptr_insert_to_v, &state.key_no_match_sel, key_no_match_count); + idx_t key_match_count = ht->row_matcher_build.Match(keys, key_state.vector_data, state.salt_match_sel, + salt_match_count, ht->layout, state.row_ptr_insert_to_v, + &state.key_no_match_sel, key_no_match_count); D_ASSERT(key_match_count + key_no_match_count == salt_match_count); @@ -272,7 +284,7 @@ void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_sta const auto row_index = state.key_no_match_sel.get_index(i); auto &ht_offset = ht_offsets[row_index]; - IncrementAndWrap(ht_offset, bitmask); + IncrementAndWrap(ht_offset, ht->bitmask); } } @@ -281,6 +293,26 @@ void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_sta } } +inline bool JoinHashTable::UseSalt() const { + // only compare salts with the ht entries if the capacity is larger than 8192 so + // that it does not fit into the CPU cache and if there is only one equality condition as otherwise + // we potentially need to compare multiple keys + return this->capacity > 8192 && this->equality_predicate_columns.size() == 1; +} + +void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_state, ProbeState &state, Vector &hashes_v, + const SelectionVector &sel, idx_t &count, Vector &pointers_result_v, + SelectionVector &match_sel) { + + if (UseSalt()) { + GetRowPointersInternal(keys, key_state, state, hashes_v, sel, count, this, entries, pointers_result_v, + match_sel); + } else { + GetRowPointersInternal(keys, key_state, state, hashes_v, sel, count, this, entries, pointers_result_v, + match_sel); + } +} + void JoinHashTable::Hash(DataChunk &keys, const SelectionVector &sel, idx_t count, Vector &hashes) { if (count == keys.size()) { // no null values are filtered: use regular hash functions @@ -494,7 +526,6 @@ static void InsertHashesLoop(atomic entries[], Vector row_locat idx_t &ht_offset_and_salt = ht_offsets_and_salt[row_index]; const hash_t salt = aggr_ht_entry_t::ExtractSalt(ht_offset_and_salt); - idx_t increment; idx_t ht_offset; bool occupied; diff --git a/src/include/duckdb/execution/join_hashtable.hpp b/src/include/duckdb/execution/join_hashtable.hpp index 62c1fd3f1cdf..849ece2131ff 100644 --- a/src/include/duckdb/execution/join_hashtable.hpp +++ b/src/include/duckdb/execution/join_hashtable.hpp @@ -265,8 +265,10 @@ class JoinHashTable { const SelectionVector *¤t_sel); void Hash(DataChunk &keys, const SelectionVector &sel, idx_t count, Vector &hashes); + bool UseSalt() const; + //! Gets a pointer to the entry in the HT for each of the hashes_v using linear probing. Will update the match_sel - //! vectorand the count argument to the number and position of the matches + //! vector and the count argument to the number and position of the matches void GetRowPointers(DataChunk &keys, TupleDataChunkState &key_state, ProbeState &state, Vector &hashes_v, const SelectionVector &sel, idx_t &count, Vector &pointers_result_v, SelectionVector &match_sel);