Skip to content

Commit

Permalink
added very basic algorithm to find factorized consumer-producer pairs
Browse files Browse the repository at this point in the history
  • Loading branch information
gropaul committed Jan 9, 2025
1 parent d31aace commit 4adfea2
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 35 deletions.
57 changes: 44 additions & 13 deletions src/include/duckdb/optimizer/factorization_optimizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@ struct FactorConsumer {
explicit FactorConsumer(column_binding_set_t flat_columns, const LogicalOperator &op)
: flat_columns(std::move(flat_columns)), op(op) {
}

void Print() const {
Printer::Print("FactorConsumer");
for (auto &column : flat_columns) {
Printer::Print(
StringUtil::Format("ColumnIndex: %llu, TableIndex: %llu", column.column_index, column.table_index));
}
}

//! The columns that need to be flat, e.g. aggregate keys can be inside the factor
column_binding_set_t flat_columns;
const LogicalOperator &op;
Expand All @@ -35,24 +33,29 @@ struct FactorProducer {
explicit FactorProducer(column_binding_set_t factor_columns, const LogicalOperator &op)
: factor_columns(std::move(factor_columns)), op(op) {
}

void Print() const {
Printer::Print("FactorProducer");
for (auto &column : factor_columns) {
Printer::Print(
StringUtil::Format("ColumnIndex: %llu, TableIndex: %llu", column.column_index, column.table_index));
}
}

//! The columns that are factorized
column_binding_set_t factor_columns;
const LogicalOperator &op;
};

class ColumnBindingAccumulator : public LogicalOperatorVisitor {
struct FactorOperatorMatch {
explicit FactorOperatorMatch(FactorConsumer &consumer, FactorProducer &producer)
: consumer(consumer), producer(producer) {
}
FactorConsumer &consumer;
FactorProducer &producer;
};

class ColumnBindingCollector final : public LogicalOperatorVisitor {
public:
explicit ColumnBindingAccumulator();
explicit ColumnBindingCollector();
column_binding_set_t GetColumnReferences() {
return column_references;
}
Expand All @@ -64,26 +67,54 @@ class ColumnBindingAccumulator : public LogicalOperatorVisitor {
column_binding_set_t column_references;
};


//! Todo: Add a description
class FactorizationOptimizer : public LogicalOperatorVisitor {
class FactorizedOperatorCollector final : public LogicalOperatorVisitor {
public:
explicit FactorizationOptimizer();
FactorizedOperatorCollector() : consumers() {
}

explicit FactorizedOperatorCollector(const vector<FactorConsumer> &consumers) : consumers(consumers) {
}
vector<FactorOperatorMatch> GetPotentialMatches() {
return matches;
}

public:
void VisitOperator(LogicalOperator &op) override;

protected:
private:
vector<FactorConsumer> consumers;
vector<FactorOperatorMatch> matches;

private:
static bool Match(const FactorConsumer &consumer, const FactorProducer &producer) {
// todo: check if this producer can work with the consumers to create a match
return true;
}

static void AddFactorizedPreAggregate(LogicalAggregate &aggregate);
static bool HindersConsumption(LogicalOperator &op, FactorConsumer &consumer) {
// todo: filter out current sources that are not factorisable because of this operator
return false;
}

static bool CanProduceFactors(LogicalOperator &op);
FactorProducer GetFactorProducer(LogicalOperator &op);
static FactorProducer GetFactorProducer(LogicalOperator &op);

static bool CanConsumeFactors(LogicalOperator &op);
FactorConsumer GetFactorConsumer(LogicalOperator &op);
static FactorConsumer GetFactorConsumer(LogicalOperator &op);
};

//! Todo: Add a description
class FactorizationOptimizer : public LogicalOperatorVisitor {
public:
explicit FactorizationOptimizer();

public:
void VisitOperator(LogicalOperator &op) override;

private:
private:
static void AddFactorizedPreAggregate(LogicalAggregate &aggregate);
};

} // namespace duckdb
68 changes: 46 additions & 22 deletions src/optimizer/factorization_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ static column_binding_set_t VectorToSet(const std::vector<ColumnBinding> &vector
return set;
}

ColumnBindingAccumulator::ColumnBindingAccumulator() {
ColumnBindingCollector::ColumnBindingCollector() {

}

unique_ptr<Expression> ColumnBindingAccumulator::VisitReplace(BoundColumnRefExpression &expr,
unique_ptr<Expression> ColumnBindingCollector::VisitReplace(BoundColumnRefExpression &expr,
unique_ptr<Expression> *expr_ptr) {
column_references.insert(expr.binding);
return nullptr;
Expand All @@ -24,23 +25,12 @@ FactorizationOptimizer::FactorizationOptimizer() {
}

void FactorizationOptimizer::VisitOperator(LogicalOperator &op) {
auto column_bindings = op.GetColumnBindings();

if (CanConsumeFactors(op)) {
const auto consumer = GetFactorConsumer(op);
}
FactorizedOperatorCollector collector;
collector.VisitOperator(op);

if (CanProduceFactors(op)) {
const auto producer = GetFactorProducer(op);
}
const auto matches = collector.GetPotentialMatches();

if (op.type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) {
AddFactorizedPreAggregate(op.Cast<LogicalAggregate>());
}

for (auto &child : op.children) {
VisitOperator(*child);
}
printf("Matches: %lu\n", matches.size());
}

void FactorizationOptimizer::AddFactorizedPreAggregate(LogicalAggregate &aggregate) {
Expand All @@ -52,7 +42,7 @@ void FactorizationOptimizer::AddFactorizedPreAggregate(LogicalAggregate &aggrega
aggregate.children[0] = std::move(pre_aggregate);
}

bool FactorizationOptimizer::CanProduceFactors(LogicalOperator &op) {
bool FactorizedOperatorCollector::CanProduceFactors(LogicalOperator &op) {
switch (op.type) {
case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: {
const JoinType join_type = op.Cast<LogicalComparisonJoin>().join_type;
Expand All @@ -72,7 +62,7 @@ bool FactorizationOptimizer::CanProduceFactors(LogicalOperator &op) {
}
}

FactorProducer FactorizationOptimizer::GetFactorProducer(LogicalOperator &op) {
FactorProducer FactorizedOperatorCollector::GetFactorProducer(LogicalOperator &op) {
switch (op.type) {
case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: {
const auto &comparison_join = op.Cast<LogicalComparisonJoin>();
Expand All @@ -85,7 +75,7 @@ FactorProducer FactorizationOptimizer::GetFactorProducer(LogicalOperator &op) {
}
}

bool FactorizationOptimizer::CanConsumeFactors(LogicalOperator &op) {
bool FactorizedOperatorCollector::CanConsumeFactors(LogicalOperator &op) {
switch (op.type) {
case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY:
return true;
Expand All @@ -94,11 +84,11 @@ bool FactorizationOptimizer::CanConsumeFactors(LogicalOperator &op) {
}
}

FactorConsumer FactorizationOptimizer::GetFactorConsumer(LogicalOperator &op) {
FactorConsumer FactorizedOperatorCollector::GetFactorConsumer(LogicalOperator &op) {
switch (op.type) {
case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: {

ColumnBindingAccumulator accumulator;
ColumnBindingCollector accumulator;
for (auto &group_by_expression : op.Cast<LogicalAggregate>().groups) {
accumulator.VisitExpression(&group_by_expression);
}
Expand All @@ -111,4 +101,38 @@ FactorConsumer FactorizationOptimizer::GetFactorConsumer(LogicalOperator &op) {
}
}

void FactorizedOperatorCollector::VisitOperator(LogicalOperator &op) {

if (CanProduceFactors(op)) {
auto producer = GetFactorProducer(op);

for (auto &consumer : consumers) {
if (Match(consumer, producer)) {
matches.push_back(FactorOperatorMatch(consumer, producer));
}
}
}

vector<FactorConsumer> consumer_for_children;
for (auto &consumers : consumers) {
if (!HindersConsumption(op, consumers)) {
consumer_for_children.push_back(consumers);
}
}

if (CanConsumeFactors(op)) {
const auto consumer = GetFactorConsumer(op);
consumer_for_children.push_back(consumer);
}

for (auto &child : op.children) {
FactorizedOperatorCollector collector(consumer_for_children);
collector.VisitOperator(*child);

for (auto &match : collector.GetPotentialMatches()) {
matches.push_back(match);
}
}
}

} // namespace duckdb

0 comments on commit 4adfea2

Please sign in to comment.