
Commit

factorization working for the edgiest case of all edge cases
gropaul committed Apr 16, 2024
1 parent 53115a4 commit fcd9ff7
Showing 12 changed files with 137 additions and 35 deletions.
7 changes: 7 additions & 0 deletions src/common/types.cpp
@@ -128,6 +128,8 @@ PhysicalType LogicalType::GetInternalType() {
return PhysicalType::LIST;
case LogicalTypeId::ARRAY:
return PhysicalType::ARRAY;
case LogicalTypeId::FACTORIZED:
return PhysicalType::FACTORIZED;
case LogicalTypeId::POINTER:
// LCOV_EXCL_START
if (sizeof(uintptr_t) == sizeof(uint32_t)) {
@@ -288,13 +290,16 @@ string TypeIdToString(PhysicalType type) {
return "LIST";
case PhysicalType::ARRAY:
return "ARRAY";
case PhysicalType::FACTORIZED:
return "FACTORIZED";
case PhysicalType::INVALID:
return "INVALID";
case PhysicalType::BIT:
return "BIT";
case PhysicalType::UNKNOWN:
return "UNKNOWN";
}

return "INVALID";
}
// LCOV_EXCL_STOP
@@ -338,6 +343,8 @@ idx_t GetTypeIdSize(PhysicalType type) {
return 0; // no own payload
case PhysicalType::LIST:
return sizeof(list_entry_t); // offset + len
case PhysicalType::FACTORIZED:
return sizeof(fact_entry_t);
default:
throw InternalException("Invalid PhysicalType for GetTypeIdSize");
}
1 change: 1 addition & 0 deletions src/common/types/data_chunk.cpp
@@ -344,6 +344,7 @@ void DataChunk::Verify() {
#ifdef DEBUG
D_ASSERT(size() <= capacity);


// verify that all vectors in this chunk have the chunk selection vector
for (idx_t i = 0; i < ColumnCount(); i++) {
data[i].Verify(size());
5 changes: 5 additions & 0 deletions src/common/vector_operations/vector_copy.cpp
@@ -29,6 +29,8 @@ static const ValidityMask &CopyValidityMask(const Vector &v) {
return FlatVector::Validity(v);
case VectorType::FSST_VECTOR:
return FSSTVector::Validity(v);
case VectorType::FACTORIZED_VECTOR:
return FactorizedVector::Validity(v);
default:
throw InternalException("Unsupported vector type in vector copy");
}
@@ -76,6 +78,9 @@ void VectorOperations::Copy(const Vector &source_p, Vector &target, const Select
case VectorType::FLAT_VECTOR:
finished = true;
break;
case VectorType::FACTORIZED_VECTOR:
finished = true;
break;
default:
throw NotImplementedException("FIXME unimplemented vector type for VectorOperations::Copy");
}
3 changes: 3 additions & 0 deletions src/execution/aggregate_hashtable.cpp
@@ -10,6 +10,7 @@
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/execution/ht_entry.hpp"
#include "duckdb/execution/join_hashtable.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"

namespace duckdb {
@@ -245,6 +246,8 @@ idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, Vector &group_hashe
D_ASSERT(groups.GetTypes()[i] == layout.GetTypes()[i]);
}
#endif
// todo: magic number 9 (pointer offset)
JoinHashTable::GetChainLengths(payload.data[0], groups.size(), 9);

const auto new_group_count = FindOrCreateGroups(groups, group_hashes, state.addresses, state.new_groups);
VectorOperations::AddInPlace(state.addresses, layout.GetAggrOffset(), payload.size());
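
The GetChainLengths call above counts, for each row pointer in the given vector, how many hash-table rows hang off the same chain by repeatedly loading a next-row pointer stored at a fixed byte offset inside each row; the hard-coded 9 is, per the todo, an assumed pointer offset in this commit's row layout. A standalone sketch of that mechanism (toy offsets and row widths, not DuckDB code):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Hypothetical layout: each row keeps its next-row pointer POINTER_OFFSET bytes in,
// mirroring Load<data_ptr_t>(ptr + pointer_offset) in GetChainLengths.
static constexpr std::size_t POINTER_OFFSET = 9; // assumed, like the todo's magic number
static constexpr std::size_t ROW_WIDTH = POINTER_OFFSET + sizeof(uint8_t *);

static std::size_t ChainLength(const uint8_t *row) {
	std::size_t length = 0;
	while (row != nullptr) {
		const uint8_t *next = nullptr;
		std::memcpy(&next, row + POINTER_OFFSET, sizeof(next)); // unaligned load of the next pointer
		row = next;
		++length;
	}
	return length;
}

int main() {
	// build three rows chained r0 -> r1 -> r2 -> null inside one flat allocation
	std::vector<uint8_t> heap(3 * ROW_WIDTH, 0);
	uint8_t *r0 = heap.data();
	uint8_t *r1 = r0 + ROW_WIDTH;
	uint8_t *r2 = r1 + ROW_WIDTH;
	const uint8_t *null_next = nullptr;
	std::memcpy(r0 + POINTER_OFFSET, &r1, sizeof(r1));
	std::memcpy(r1 + POINTER_OFFSET, &r2, sizeof(r2));
	std::memcpy(r2 + POINTER_OFFSET, &null_next, sizeof(null_next));
	std::printf("chain length from r0: %zu\n", ChainLength(r0)); // prints 3
	return 0;
}
```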
58 changes: 49 additions & 9 deletions src/execution/join_hashtable.cpp
@@ -870,7 +870,7 @@ idx_t ScanStructure::ScanInnerJoin(DataChunk &keys, SelectionVector &result_vect

void ScanStructure::AdvancePointers(const SelectionVector &sel, idx_t sel_count) {

if (!ht.chains_longer_than_one) {
if (!ht.chains_longer_than_one || EmitFactVectors()) {
this->count = 0;
return;
}
@@ -914,6 +914,7 @@ void ScanStructure::NextInnerJoin(DataChunk &keys, DataChunk &left, DataChunk &r
idx_t result_count = ScanInnerJoin(keys, chain_match_sel_vector);

if (result_count > 0) {

if (PropagatesBuildSide(ht.join_type)) {
// full/right outer join: mark join matches as FOUND in the HT
auto ptrs = FlatVector::GetData<data_ptr_t>(pointers);
@@ -925,19 +926,42 @@ void ScanStructure::NextInnerJoin(DataChunk &keys, DataChunk &left, DataChunk &r
Store<bool>(true, ptrs[idx] + ht.tuple_size);
}
}

// for right semi join, just mark the entry as found and move on. Propagation happens later
if (ht.join_type != JoinType::RIGHT_SEMI && ht.join_type != JoinType::RIGHT_ANTI) {

// matches were found
// construct the result
// on the LHS, we create a slice using the result vector
result.Slice(left, chain_match_sel_vector, result_count);

// on the RHS, we need to fetch the data from the hash table
for (idx_t i = 0; i < ht.output_columns.size(); i++) {
auto &vector = result.data[left.ColumnCount() + i];
const auto output_col_idx = ht.output_columns[i];
D_ASSERT(vector.GetType() == ht.layout.GetTypes()[output_col_idx]);
GatherResult(vector, chain_match_sel_vector, result_count, output_col_idx);
result.Slice(left, chain_match_sel_vector, result_count, 0);

if (EmitFactVectors()) {
// in our very special case, the aggregate keys are the first vector and the key to be grouped by is
// the second vector

// set the first vector in the result to be the fact vector
auto &fact_vector = result.data[1];
fact_vector.SetVectorType(VectorType::FLAT_VECTOR);
// fact_vector.SetVectorType(VectorType::FACTORIZED_VECTOR);
auto fact_vector_pointer = FactorizedVector::GetData(fact_vector);

auto ptrs = FlatVector::GetData<data_ptr_t>(pointers);

for (idx_t j = 0; j < result_count; j++) {
auto idx = chain_match_sel_vector.get_index(j);
data_ptr_t ptr = ptrs[idx];
fact_vector_pointer[idx] = fact_entry_t(ptr);
}

} else {

// on the RHS, we need to fetch the data from the hash table
for (idx_t i = 0; i < ht.output_columns.size(); i++) {
auto &vector = result.data[left.ColumnCount() + i];
const auto output_col_idx = ht.output_columns[i];
D_ASSERT(vector.GetType() == ht.layout.GetTypes()[output_col_idx]);
GatherResult(vector, chain_match_sel_vector, result_count, output_col_idx);
}
}
}
AdvancePointers();
@@ -1473,6 +1497,22 @@ unique_ptr<ScanStructure> JoinHashTable::ProbeAndSpill(DataChunk &keys, TupleDat

return ss;
}
void JoinHashTable::GetChainLengths(Vector &row_pointer_v, idx_t count, idx_t pointer_offset) {

row_pointer_v.Flatten(count);
auto row_pointer = FlatVector::GetData<data_ptr_t>(row_pointer_v);

for (idx_t i = 0; i < count; i++) {
auto next_ptr = row_pointer[i];
idx_t chain_length = 0;
while (next_ptr) {
next_ptr = Load<data_ptr_t>(next_ptr + pointer_offset);
chain_length++;
}
// set the chain length in the row pointer
row_pointer[i] = reinterpret_cast<data_ptr_t>(chain_length);
}
}

ProbeSpill::ProbeSpill(JoinHashTable &ht, ClientContext &context, const vector<LogicalType> &probe_types)
: ht(ht), context(context), probe_types(probe_types) {
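
With EmitFactVectors() enabled, the NextInnerJoin change above stops gathering build-side columns per match: it writes one chain-head row pointer per probe match into the result's fact vector and lets AdvancePointers finish after a single pass. A toy sketch of the two result shapes (hypothetical BuildRow/FactEntry types, not DuckDB's API):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

struct BuildRow {         // hypothetical build-side row in the hash table
	long key;
	long payload;
	const BuildRow *next; // next row in the same chain
};

struct FactEntry {        // plays the role of fact_entry_t: one pointer per match
	const BuildRow *row_ptr;
};

// regular inner-join result: every chain is expanded into one output row per build match
static std::size_t ExpandedRowCount(const std::vector<const BuildRow *> &chain_heads) {
	std::size_t rows = 0;
	for (const BuildRow *row : chain_heads) {
		for (; row != nullptr; row = row->next) {
			rows++; // here the build-side columns would be gathered into the result chunk
		}
	}
	return rows;
}

// factorized result: only the chain head is recorded; expansion is deferred downstream
static std::vector<FactEntry> FactorizedResult(const std::vector<const BuildRow *> &chain_heads) {
	std::vector<FactEntry> entries;
	entries.reserve(chain_heads.size());
	for (const BuildRow *row : chain_heads) {
		if (row != nullptr) {
			entries.push_back(FactEntry {row});
		}
	}
	return entries;
}

int main() {
	BuildRow b2 {1, 30, nullptr}, b1 {1, 20, &b2}, b0 {1, 10, &b1}; // one chain of length 3
	std::vector<const BuildRow *> heads = {&b0, &b0};               // two probe rows hit it
	std::printf("expanded rows: %zu, factorized entries: %zu\n",
	            ExpandedRowCount(heads), FactorizedResult(heads).size()); // 6 vs 2
	return 0;
}
```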
3 changes: 2 additions & 1 deletion src/include/duckdb/common/enums/vector_type.hpp
@@ -17,7 +17,8 @@ enum class VectorType : uint8_t {
FSST_VECTOR, // Contains string data compressed with FSST
CONSTANT_VECTOR, // Constant vector represents a single constant
DICTIONARY_VECTOR, // Dictionary vector represents a selection vector on top of another vector
SEQUENCE_VECTOR // Sequence vector represents a sequence with a start point and an increment
SEQUENCE_VECTOR, // Sequence vector represents a sequence with a start point and an increment
FACTORIZED_VECTOR, // Factorized vector represents a set of tuples as the cartesian product of a list of tuples
};

string VectorTypeToString(VectorType type);
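
The new enum comment describes a factorized vector as standing for the Cartesian product of tuple lists rather than the product itself. A toy calculation (assumed factor sizes, unrelated to any specific query) shows why that pays off: the factorized form stores the sum of the list sizes while representing their product:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
	// hypothetical per-factor tuple counts, e.g. matches contributed by successive joins
	std::vector<std::size_t> factor_sizes = {3, 4, 2};
	std::size_t stored_tuples = 0;    // what a factorized representation keeps around
	std::size_t represented_rows = 1; // what a flat representation would materialize
	for (std::size_t n : factor_sizes) {
		stored_tuples += n;
		represented_rows *= n;
	}
	std::printf("factorized: %zu tuples stored, flat: %zu rows represented\n",
	            stored_tuples, represented_rows); // 9 vs 24
	return 0;
}
```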
62 changes: 38 additions & 24 deletions src/include/duckdb/common/types.hpp
@@ -52,6 +52,13 @@ struct list_entry_t { // NOLINT: mimic std casing
uint64_t length;
};

// used for the FactorizedVector
struct fact_entry_t { // NOLINT: mimic std casing
fact_entry_t(data_ptr_t data_ptr) : row_ptr(data_ptr) {
}
data_ptr_t row_ptr;
};

using union_tag_t = uint8_t;

//===--------------------------------------------------------------------===//
@@ -167,10 +174,13 @@ enum class PhysicalType : uint8_t {
///// Like LIST, but with 64-bit offsets
// LARGE_LIST = 33,

/// Factorized representation of multiple rows
FACTORIZED = 34,

/// DuckDB Extensions
VARCHAR = 200, // our own string representation, different from STRING and LARGE_STRING above
UINT128 = 203, // 128-bit unsigned integers
INT128 = 204, // 128-bit integers
INT128 = 204, // 128-bit integers
UNKNOWN = 205, // Unknown physical type of user defined types
/// Boolean as 1 bit, LSB bit-packed ordering
BIT = 206,
@@ -212,8 +222,8 @@ enum class LogicalTypeId : uint8_t {
TIMESTAMP_TZ = 32,
TIME_TZ = 34,
BIT = 36,
STRING_LITERAL = 37, /* string literals, used for constant strings - only exists while binding */
INTEGER_LITERAL = 38,/* integer literals, used for constant integers - only exists while binding */
STRING_LITERAL = 37, /* string literals, used for constant strings - only exists while binding */
INTEGER_LITERAL = 38, /* integer literals, used for constant integers - only exists while binding */

UHUGEINT = 49,
HUGEINT = 50,
@@ -229,7 +239,8 @@ enum class LogicalTypeId : uint8_t {
AGGREGATE_STATE = 105,
LAMBDA = 106,
UNION = 107,
ARRAY = 108
ARRAY = 108,
FACTORIZED = 109
};

struct ExtraTypeInfo;
@@ -319,29 +330,32 @@ struct LogicalType {
DUCKDB_API bool HasAlias() const;
DUCKDB_API string GetAlias() const;

//! Returns the maximum logical type when combining the two types - or throws an exception if combining is not possible
DUCKDB_API static LogicalType MaxLogicalType(ClientContext &context, const LogicalType &left, const LogicalType &right);
DUCKDB_API static bool TryGetMaxLogicalType(ClientContext &context, const LogicalType &left, const LogicalType &right, LogicalType &result);
//! Forcibly returns a maximum logical type - similar to MaxLogicalType but never throws. As a fallback either left or right are returned.
//! Returns the maximum logical type when combining the two types - or throws an exception if combining is not
//! possible
DUCKDB_API static LogicalType MaxLogicalType(ClientContext &context, const LogicalType &left,
const LogicalType &right);
DUCKDB_API static bool TryGetMaxLogicalType(ClientContext &context, const LogicalType &left,
const LogicalType &right, LogicalType &result);
//! Forcibly returns a maximum logical type - similar to MaxLogicalType but never throws. As a fallback either left
//! or right are returned.
DUCKDB_API static LogicalType ForceMaxLogicalType(const LogicalType &left, const LogicalType &right);
//! Normalize a type - removing literals
DUCKDB_API static LogicalType NormalizeType(const LogicalType &type);


//! Gets the decimal properties of a numeric type. Fails if the type is not numeric.
//! Gets the decimal properties of a numeric type. Fails if the type is not numeric.
DUCKDB_API bool GetDecimalProperties(uint8_t &width, uint8_t &scale) const;

DUCKDB_API void Verify() const;

DUCKDB_API bool IsValid() const;

template<class F>
template <class F>
bool Contains(F &&predicate) const;
bool Contains(LogicalTypeId type_id) const;

private:
LogicalTypeId id_; // NOLINT: allow this naming for legacy reasons
PhysicalType physical_type_; // NOLINT: allow this naming for legacy reasons
LogicalTypeId id_; // NOLINT: allow this naming for legacy reasons
PhysicalType physical_type_; // NOLINT: allow this naming for legacy reasons
shared_ptr<ExtraTypeInfo> type_info_; // NOLINT: allow this naming for legacy reasons

private:
@@ -383,9 +397,10 @@ struct LogicalType {
static constexpr const LogicalTypeId LAMBDA = LogicalTypeId::LAMBDA;
static constexpr const LogicalTypeId INVALID = LogicalTypeId::INVALID;
static constexpr const LogicalTypeId ROW_TYPE = LogicalTypeId::BIGINT;
static constexpr const LogicalTypeId FACTORIZED = LogicalTypeId::FACTORIZED;

// explicitly allowing these functions to be capitalized to be in-line with the remaining functions
DUCKDB_API static LogicalType DECIMAL(uint8_t width, uint8_t scale); // NOLINT
DUCKDB_API static LogicalType DECIMAL(uint8_t width, uint8_t scale); // NOLINT
DUCKDB_API static LogicalType VARCHAR_COLLATION(string collation); // NOLINT
DUCKDB_API static LogicalType LIST(const LogicalType &child); // NOLINT
DUCKDB_API static LogicalType STRUCT(child_list_t<LogicalType> children); // NOLINT
@@ -400,7 +415,7 @@ struct LogicalType {
// ANY but with special rules (default is LogicalType::ANY, 5)
DUCKDB_API static LogicalType ANY_PARAMS(LogicalType target, idx_t cast_score = 5); // NOLINT
//! Integer literal of the specified value
DUCKDB_API static LogicalType INTEGER_LITERAL(const Value &constant); // NOLINT
DUCKDB_API static LogicalType INTEGER_LITERAL(const Value &constant); // NOLINT
// DEPRECATED - provided for backwards compatibility
DUCKDB_API static LogicalType ENUM(const string &enum_name, Vector &ordered_data, idx_t size); // NOLINT
DUCKDB_API static LogicalType USER(const string &user_type_name); // NOLINT
@@ -534,27 +549,26 @@ struct aggregate_state_t {
vector<LogicalType> bound_argument_types;
};

template<class F>
template <class F>
bool LogicalType::Contains(F &&predicate) const {
if(predicate(*this)) {
if (predicate(*this)) {
return true;
}
switch(id()) {
switch (id()) {
case LogicalTypeId::STRUCT: {
for(const auto &child : StructType::GetChildTypes(*this)) {
if(child.second.Contains(predicate)) {
for (const auto &child : StructType::GetChildTypes(*this)) {
if (child.second.Contains(predicate)) {
return true;
}
}
}
break;
} break;
case LogicalTypeId::LIST:
return ListType::GetChildType(*this).Contains(predicate);
case LogicalTypeId::MAP:
return MapType::KeyType(*this).Contains(predicate) || MapType::ValueType(*this).Contains(predicate);
case LogicalTypeId::UNION:
for(const auto &child : UnionType::CopyMemberTypes(*this)) {
if(child.second.Contains(predicate)) {
for (const auto &child : UnionType::CopyMemberTypes(*this)) {
if (child.second.Contains(predicate)) {
return true;
}
}
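
For contrast with list_entry_t defined near the top of this header, fact_entry_t carries only a row pointer, so a factorized vector entry refers to rows that live elsewhere (here, in the join hash table) instead of to a child vector. A self-contained sketch of the two entry shapes (data_ptr_t assumed to be uint8_t *):

```cpp
#include <cstdint>

struct list_entry_t { // as in the surrounding header: locates a run of child values
	uint64_t offset;
	uint64_t length;
};

struct fact_entry_t { // as added by this commit, with data_ptr_t written out as uint8_t *
	fact_entry_t(uint8_t *data_ptr) : row_ptr(data_ptr) {}
	uint8_t *row_ptr; // head of a row chain living in the hash table
};

static_assert(sizeof(fact_entry_t) == sizeof(uint8_t *), "one pointer per factorized entry");

int main() {
	uint8_t dummy_row[8] = {};
	fact_entry_t entry(dummy_row); // each slot only remembers where its rows live
	list_entry_t run {0, 4};       // versus an offset + length into a child vector
	return (entry.row_ptr == dummy_row && run.length == 4) ? 0 : 1;
}
```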
13 changes: 13 additions & 0 deletions src/include/duckdb/common/types/vector.hpp
@@ -75,6 +75,7 @@ class Vector {
friend struct UnionVector;
friend struct SequenceVector;
friend struct ArrayVector;
friend struct FactorizedVector;

friend class DataChunk;
friend class VectorCacheBuffer;
@@ -563,4 +564,16 @@ struct SequenceVector {
}
};

struct FactorizedVector {

static inline const ValidityMask &Validity(const Vector &vector) {
D_ASSERT(vector.GetVectorType() == VectorType::FACTORIZED_VECTOR);
return vector.validity;
}

static inline fact_entry_t *GetData(Vector &vector) {
return reinterpret_cast<fact_entry_t *>(vector.data);
}
};

} // namespace duckdb
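
FactorizedVector above follows the same accessor idiom as FlatVector: reinterpret the vector's untyped data buffer as an array of fixed-width entries. A minimal standalone imitation of that idiom (ToyVector/ToyFactorizedVector are stand-ins, not the real classes):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

struct fact_entry_t {
	uint8_t *row_ptr;
};

struct ToyVector { // stands in for duckdb::Vector's untyped data buffer
	std::vector<uint8_t> data;
	explicit ToyVector(std::size_t count) : data(count * sizeof(fact_entry_t)) {}
};

struct ToyFactorizedVector { // mirrors FactorizedVector::GetData above
	static fact_entry_t *GetData(ToyVector &vector) {
		return reinterpret_cast<fact_entry_t *>(vector.data.data());
	}
};

int main() {
	ToyVector vec(4); // room for four factorized entries
	fact_entry_t *entries = ToyFactorizedVector::GetData(vec);
	uint8_t dummy_row[16] = {};
	entries[0].row_ptr = dummy_row; // e.g. the chain head written by the join
	return entries[0].row_ptr == dummy_row ? 0 : 1;
}
```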
2 changes: 2 additions & 0 deletions src/include/duckdb/common/types/vector_buffer.hpp
@@ -302,4 +302,6 @@ class ManagedVectorBuffer : public VectorBuffer {
BufferHandle handle;
};

class VectorFactBuffer : public VectorBuffer {};

} // namespace duckdb
7 changes: 7 additions & 0 deletions src/include/duckdb/execution/join_hashtable.hpp
@@ -378,6 +378,13 @@ class JoinHashTable {
return PointerTableCapacity(count) * sizeof(data_ptr_t);
}

//! Whether or not to emit fact vectors from the HT
static bool EmitFactVectors(){
return true;
}

static void GetChainLengths(Vector &row_pointer_v, idx_t count, idx_t pointer_offset);

//! Get total size of HT if all partitions would be built
idx_t GetTotalSize(vector<unique_ptr<JoinHashTable>> &local_hts, idx_t &max_partition_size,
idx_t &max_partition_count) const;
3 changes: 2 additions & 1 deletion src/include/duckdb/execution/physical_operator.hpp
@@ -220,7 +220,8 @@ class CachingOperatorState : public OperatorState {
//! inherit their state class from the CachingOperatorState.
class CachingPhysicalOperator : public PhysicalOperator {
public:
static constexpr const idx_t CACHE_THRESHOLD = 64;
// todo: Reduced CACHE THRESHOLD to 2 for testing purposes
static constexpr const idx_t CACHE_THRESHOLD = 2;
CachingPhysicalOperator(PhysicalOperatorType type, vector<LogicalType> types, idx_t estimated_cardinality);

bool caching_supported;