Skip to content

Commit

Permalink
improvement: offsets and salts in the same vector
Browse files Browse the repository at this point in the history
  • Loading branch information
gropaul committed Mar 11, 2024
1 parent f2ea8b2 commit 9fe76a4
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 47 deletions.
85 changes: 40 additions & 45 deletions src/execution/join_hashtable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ JoinHashTable::ProbeState::ProbeState()
key_no_match_sel(STANDARD_VECTOR_SIZE), salt_match_sel(STANDARD_VECTOR_SIZE) {
}

JoinHashTable::InsertState::InsertState() : hash_salts_v(LogicalType::UBIGINT), remaining_sel(STANDARD_VECTOR_SIZE) {
JoinHashTable::InsertState::InsertState() : remaining_sel(STANDARD_VECTOR_SIZE) {
}

JoinHashTable::JoinHashTable(BufferManager &buffer_manager_p, const vector<JoinCondition> &conditions_p,
Expand Down Expand Up @@ -116,43 +116,35 @@ void JoinHashTable::Merge(JoinHashTable &other) {
sink_collection->Combine(*other.sink_collection);
}

static void ApplyBitmaskAndGetSaltBuild(Vector &hashes_v, const idx_t &count, const idx_t &bitmask,
Vector &hash_salts_v) {

auto hash_salts = FlatVector::GetData<idx_t>(hash_salts_v);
static void ApplyBitmaskAndGetSaltBuild(Vector &hashes_v, const idx_t &count, const idx_t &bitmask) {

if (hashes_v.GetVectorType() == VectorType::CONSTANT_VECTOR) {

D_ASSERT(!ConstantVector::IsNull(hashes_v));

auto indices = ConstantVector::GetData<hash_t>(hashes_v);
hash_t salt = aggr_ht_entry_t::ExtractSalt(*indices);
hash_t salt = aggr_ht_entry_t::ExtractSaltWithNull(*indices);
idx_t offset = *indices & bitmask;
*indices = offset;
*indices = offset | salt;
hashes_v.Flatten(count);
for (idx_t i = 0; i < count; i++) {
hash_salts[i] = salt;
}

} else {
hashes_v.Flatten(count);
auto hashes = FlatVector::GetData<hash_t>(hashes_v);

for (idx_t i = 0; i < count; i++) {
hash_salts[i] = aggr_ht_entry_t::ExtractSalt(hashes[i]);
hashes[i] = hashes[i] & bitmask;
idx_t salt = aggr_ht_entry_t::ExtractSaltWithNull(hashes[i]);
idx_t offset = hashes[i] & bitmask;
hashes[i] = offset | salt;
}
}
}

// uses an AND operation to apply the bitmask instead of an in condition
inline void IncrementAndWrap(idx_t &value, idx_t increment, uint64_t bitmask) {
// uses an AND operation to apply the capacity_mask instead of an in condition
inline void IncrementAndWrap(idx_t &value, const idx_t &increment, const uint64_t &capacity_mask) {
value += increment;
value &= bitmask;
}

inline void IncrementAndWrap(idx_t &value, uint64_t bitmask) {
IncrementAndWrap(value, 1, bitmask);
// leave the salt bits unchanged
value &= capacity_mask | 0xFFFF000000000000;
}

void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_state, ProbeState &state, Vector &hashes_v,
Expand Down Expand Up @@ -262,7 +254,7 @@ void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_sta
const auto uvf_index = hashes_v_unified.sel->get_index(row_index);
auto &ht_offset = ht_offsets[uvf_index];

IncrementAndWrap(ht_offset, bitmask);
IncrementAndWrap(ht_offset, 1, bitmask);
}

remaining_sel = &state.key_no_match_sel;
Expand Down Expand Up @@ -452,11 +444,10 @@ static void InsertHashesLoop(atomic<aggr_ht_entry_t> entries[], Vector row_locat

D_ASSERT(hashes_v.GetType().id() == LogicalType::HASH);

ApplyBitmaskAndGetSaltBuild(hashes_v, count, ht->bitmask, state.hash_salts_v);
ApplyBitmaskAndGetSaltBuild(hashes_v, count, ht->bitmask);

// the offset for each row to insert
auto ht_offsets = FlatVector::GetData<idx_t>(hashes_v);
auto hash_salts = FlatVector::GetData<idx_t>(state.hash_salts_v);
auto ht_offsets_and_salt = FlatVector::GetData<idx_t>(hashes_v);
auto row_ptr_insert_to = FlatVector::GetData<data_ptr_t>(state.row_ptr_insert_to_v);
auto row_ptrs_to_insert = FlatVector::GetData<data_ptr_t>(row_locations);

Expand All @@ -471,14 +462,14 @@ static void InsertHashesLoop(atomic<aggr_ht_entry_t> entries[], Vector row_locat
// a new list
for (idx_t i = 0; i < remaining_count; i++) {
const auto row_index = remaining_sel->get_index(i);
auto &ht_offset = ht_offsets[row_index];
auto salt = hash_salts[row_index];
auto &ht_offset_with_salt = ht_offsets_and_salt[row_index];
auto salt = aggr_ht_entry_t::ExtractSalt(ht_offset_with_salt);

idx_t increment;

// increment the ht_offset of the entry as long as next entry is occupied and salt does not match
// increment the ht_offset_with_salt of the entry as long as next entry is occupied and salt does not match
do {
atomic<aggr_ht_entry_t> &atomic_entry = entries[ht_offset];
atomic<aggr_ht_entry_t> &atomic_entry = entries[ht_offset_with_salt & JoinHashTable::POINTER_MASK];
aggr_ht_entry_t entry = atomic_entry.load();

bool occupied = entry.IsOccupied();
Expand All @@ -489,24 +480,30 @@ static void InsertHashesLoop(atomic<aggr_ht_entry_t> entries[], Vector row_locat
bool successful_insertion =
InsertRowToEntry<PARALLEL, true>(atomic_entry, row_ptr, salt, ht->pointer_offset);

// if the insertion was successful, we can stop the loop for this entry
if (successful_insertion) {
if (PARALLEL) {
// if the insertion was successful, we can stop the loop for this entry
if (successful_insertion) {
break;
}
// if the insertion was not successful, the entry was occupied in the meantime
else {
occupied = true;
entry = atomic_entry.load();
}
} else {
// if we are not in parallel mode, we can just stop the loop for this entry
break;
}
// if the insertion was not successful, the entry was occupied in the meantime
else {
occupied = true;
entry = atomic_entry.load();
}
}

bool salt_match = entry.GetSalt() == hash_salts[row_index];
bool salt_match = entry.GetSalt() == salt;
state.salt_match_sel.set_index(salt_match_count, row_index);
salt_match_count += salt_match;

// condition for incrementing the ht_offset: occupied and salt does not match -> move to next entry
// condition for incrementing the ht_offset_with_salt: occupied and salt does not match -> move to next
// entry
increment = !salt_match;
IncrementAndWrap(ht_offset, increment, ht->bitmask);
IncrementAndWrap(ht_offset_with_salt, increment, ht->bitmask);
} while (increment);
}

Expand Down Expand Up @@ -539,7 +536,7 @@ static void InsertHashesLoop(atomic<aggr_ht_entry_t> entries[], Vector row_locat
// Get the pointers to the rows that need to be compared
for (idx_t need_compare_idx = 0; need_compare_idx < salt_match_count; need_compare_idx++) {
const auto entry_index = state.salt_match_sel.get_index(need_compare_idx);
const auto &entry = entries[ht_offsets[entry_index]];
const auto &entry = entries[ht_offsets_and_salt[entry_index] & JoinHashTable::POINTER_MASK];
row_ptr_insert_to[need_compare_idx] = entry.load().GetPointer();
}

Expand All @@ -562,9 +559,10 @@ static void InsertHashesLoop(atomic<aggr_ht_entry_t> entries[], Vector row_locat
const auto need_compare_idx = match_sel.get_index(i);
const auto entry_index = state.salt_match_sel.get_index(need_compare_idx);

auto &entry = entries[ht_offsets[entry_index]];
idx_t offset_and_salt = ht_offsets_and_salt[entry_index];
auto &entry = entries[offset_and_salt & JoinHashTable::POINTER_MASK];
data_ptr_t row_ptr = row_ptrs_to_insert[entry_index];
auto salt = hash_salts[entry_index];
auto salt = offset_and_salt & JoinHashTable::SALT_MASK;
InsertRowToEntry<PARALLEL, false>(entry, row_ptr, salt, ht->pointer_offset);

ht->chains_longer_than_one = true;
Expand All @@ -576,12 +574,9 @@ static void InsertHashesLoop(atomic<aggr_ht_entry_t> entries[], Vector row_locat
const auto need_compare_idx = state.key_no_match_sel.get_index(i);
const auto entry_index = state.salt_match_sel.get_index(need_compare_idx);

auto &ht_offset = ht_offsets[entry_index];
auto &ht_offset = ht_offsets_and_salt[entry_index];

ht_offset++;
if (ht_offset >= ht->capacity) {
ht_offset = 0;
}
IncrementAndWrap(ht_offset, 1, ht->bitmask);

state.remaining_sel.set_index(i, entry_index);
}
Expand Down
6 changes: 6 additions & 0 deletions src/include/duckdb/execution/aggregate_hashtable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ struct aggr_ht_entry_t {
// Leaves upper bits intact, sets lower bits to all 1's
return hash | POINTER_MASK;
}

static inline hash_t ExtractSaltWithNull(const hash_t &hash) {
// Leaves upper bits intact, sets lower bits to all 0's
return hash & SALT_MASK;
}

inline hash_t GetSalt() const {
return ExtractSalt(value);
}
Expand Down
7 changes: 5 additions & 2 deletions src/include/duckdb/execution/join_hashtable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,15 @@ class JoinHashTable {

struct InsertState : ProbeState {
InsertState();
Vector hash_salts_v;

/// Because of the index hick up
SelectionVector remaining_sel;
};

//! Upper 16 bits are salt
static constexpr const hash_t SALT_MASK = 0xFFFF000000000000;
//! Lower 48 bits are the pointer
static constexpr const hash_t POINTER_MASK = 0x0000FFFFFFFFFFFF;

JoinHashTable(BufferManager &buffer_manager, const vector<JoinCondition> &conditions,
vector<LogicalType> build_types, JoinType type, const vector<idx_t> &output_columns);
~JoinHashTable();
Expand Down

0 comments on commit 9fe76a4

Please sign in to comment.