-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[MLIR] Cache symbol tables during OneShotBufferization analyses #138125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir Author: Michele Scuttari (mscuttari) ChangesDuring bufferization, the callee of each This PR aims to partially address this scaling issue by leveraging the Full diff: https://github.com/llvm/llvm-project/pull/138125.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index e8e6226460ac7..b63c0883c6c15 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -69,6 +69,9 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
/// analyzed.
DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
+ /// A collection of cached SymbolTables used for faster function lookup.
+ mutable mlir::SymbolTableCollection symbolTable;
+
/// This function is called right before analyzing the given FuncOp. It
/// initializes the data structures for the FuncOp in this state object.
void startFunctionAnalysis(FuncOp funcOp);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index c45678f1e4b4d..86d15d4f0a607 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -76,13 +76,14 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
}
/// Return the FuncOp called by `callOp`.
-static FuncOp getCalledFunction(CallOpInterface callOp) {
+static FuncOp getCalledFunction(CallOpInterface callOp,
+ mlir::SymbolTableCollection &symbolTable) {
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(
- SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+ symbolTable.lookupNearestSymbolFrom(callOp, sym));
}
/// Get FuncAnalysisState.
@@ -135,14 +136,14 @@ struct CallOpInterface
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+ const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+ FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
// FuncOp not analyzed yet. Assume that OpOperand is read.
return true;
- const FuncAnalysisState &funcState = getFuncAnalysisState(state);
return funcState.readBbArgs.lookup(funcOp).contains(
opOperand.getOperandNumber());
}
@@ -150,14 +151,14 @@ struct CallOpInterface
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+ const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+ FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
// FuncOp not analyzed yet. Assume that OpOperand is written.
return true;
- const FuncAnalysisState &funcState = getFuncAnalysisState(state);
return funcState.writtenBbArgs.lookup(funcOp).contains(
opOperand.getOperandNumber());
}
@@ -165,14 +166,14 @@ struct CallOpInterface
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+ const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+ FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
// FuncOp not analyzed yet. Any OpResult may be aliasing.
return detail::unknownGetAliasingValues(opOperand);
// Get aliasing results from state.
- const FuncAnalysisState &funcState = getFuncAnalysisState(state);
auto aliasingReturnVals =
funcState.aliasingReturnVals.lookup(funcOp).lookup(
opOperand.getOperandNumber());
@@ -199,7 +200,11 @@ struct CallOpInterface
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+
+ // TODO Avoid recomputing the symbol tables every time.
+ mlir::SymbolTableCollection symbolTable;
+
+ FuncOp funcOp = getCalledFunction(callOp, symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
// If the callee was already bufferized, we can directly take the type from
@@ -243,7 +248,11 @@ struct CallOpInterface
// 2. Rewrite tensor operands as memrefs based on type of the already
// bufferized callee.
SmallVector<Value> newOperands;
- FuncOp funcOp = getCalledFunction(callOp);
+
+ // TODO Avoid recomputing the symbol tables every time.
+ mlir::SymbolTableCollection symbolTable;
+
+ FuncOp funcOp = getCalledFunction(callOp, symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
FunctionType funcType = funcOp.getFunctionType();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index edd6bcf84f460..a025da8635135 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -280,13 +280,15 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
}
/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(func::CallOp callOp) {
+static func::FuncOp
+getCalledFunction(func::CallOp callOp,
+ mlir::SymbolTableCollection &symbolTable) {
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<func::FuncOp>(
- SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+ symbolTable.lookupNearestSymbolFrom(callOp, sym));
}
/// Return "true" if the given function signature has tensor semantics.
@@ -314,11 +316,15 @@ static LogicalResult getFuncOpsOrderedByCalls(
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
+
+ // TODO Avoid recomputing the symbol tables every time.
+ mlir::SymbolTableCollection symbolTable;
+
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
- func::FuncOp calledFunction = getCalledFunction(callOp);
+ func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
|
@llvm/pr-subscribers-mlir-bufferization Author: Michele Scuttari (mscuttari) ChangesDuring bufferization, the callee of each This PR aims to partially address this scaling issue by leveraging the Full diff: https://github.com/llvm/llvm-project/pull/138125.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index e8e6226460ac7..b63c0883c6c15 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -69,6 +69,9 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
/// analyzed.
DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
+ /// A collection of cached SymbolTables used for faster function lookup.
+ mutable mlir::SymbolTableCollection symbolTable;
+
/// This function is called right before analyzing the given FuncOp. It
/// initializes the data structures for the FuncOp in this state object.
void startFunctionAnalysis(FuncOp funcOp);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index c45678f1e4b4d..86d15d4f0a607 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -76,13 +76,14 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
}
/// Return the FuncOp called by `callOp`.
-static FuncOp getCalledFunction(CallOpInterface callOp) {
+static FuncOp getCalledFunction(CallOpInterface callOp,
+ mlir::SymbolTableCollection &symbolTable) {
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(
- SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+ symbolTable.lookupNearestSymbolFrom(callOp, sym));
}
/// Get FuncAnalysisState.
@@ -135,14 +136,14 @@ struct CallOpInterface
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+ const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+ FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
// FuncOp not analyzed yet. Assume that OpOperand is read.
return true;
- const FuncAnalysisState &funcState = getFuncAnalysisState(state);
return funcState.readBbArgs.lookup(funcOp).contains(
opOperand.getOperandNumber());
}
@@ -150,14 +151,14 @@ struct CallOpInterface
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+ const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+ FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
// FuncOp not analyzed yet. Assume that OpOperand is written.
return true;
- const FuncAnalysisState &funcState = getFuncAnalysisState(state);
return funcState.writtenBbArgs.lookup(funcOp).contains(
opOperand.getOperandNumber());
}
@@ -165,14 +166,14 @@ struct CallOpInterface
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+ const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+ FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
// FuncOp not analyzed yet. Any OpResult may be aliasing.
return detail::unknownGetAliasingValues(opOperand);
// Get aliasing results from state.
- const FuncAnalysisState &funcState = getFuncAnalysisState(state);
auto aliasingReturnVals =
funcState.aliasingReturnVals.lookup(funcOp).lookup(
opOperand.getOperandNumber());
@@ -199,7 +200,11 @@ struct CallOpInterface
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+
+ // TODO Avoid recomputing the symbol tables every time.
+ mlir::SymbolTableCollection symbolTable;
+
+ FuncOp funcOp = getCalledFunction(callOp, symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
// If the callee was already bufferized, we can directly take the type from
@@ -243,7 +248,11 @@ struct CallOpInterface
// 2. Rewrite tensor operands as memrefs based on type of the already
// bufferized callee.
SmallVector<Value> newOperands;
- FuncOp funcOp = getCalledFunction(callOp);
+
+ // TODO Avoid recomputing the symbol tables every time.
+ mlir::SymbolTableCollection symbolTable;
+
+ FuncOp funcOp = getCalledFunction(callOp, symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
FunctionType funcType = funcOp.getFunctionType();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index edd6bcf84f460..a025da8635135 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -280,13 +280,15 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
}
/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(func::CallOp callOp) {
+static func::FuncOp
+getCalledFunction(func::CallOp callOp,
+ mlir::SymbolTableCollection &symbolTable) {
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<func::FuncOp>(
- SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+ symbolTable.lookupNearestSymbolFrom(callOp, sym));
}
/// Return "true" if the given function signature has tensor semantics.
@@ -314,11 +316,15 @@ static LogicalResult getFuncOpsOrderedByCalls(
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
+
+ // TODO Avoid recomputing the symbol tables every time.
+ mlir::SymbolTableCollection symbolTable;
+
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
- func::FuncOp calledFunction = getCalledFunction(callOp);
+ func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
|
Marking the PR as draft because I'm seeing a few bufferization tests failing due to a missing |
assert(result && "FuncAnalysisState does not exist"); | ||
return *result; | ||
|
||
// Unfortunately, at the moment the BufferizableOpInterface methods do provide |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks a bit hacky. And I think it's not really necessary. When running through analyzeModuleOp
(OneShotModuleBufferize.cpp
), there will always be a FuncAnalysisState
.
When running through the normal One-Shot Bufferize (bufferize-function-boundaries=false
), there is no FuncAnalysisState
. (But this entry point should only be used for functions, not modules. And I think it will ignore function calls anyway.) Implementations like getFuncOpAnalysisState
fall back to the conservative path. We should do the same for getCalledFunction
: If there is no FuncAnalysisState
, use the previous "slow" lookup instead of the symbol table.
During bufferization, the callee of each
func::CallOp
/CallableOpInterface
operation is retrieved by means of a symbol table that is temporarily built for the lookup purpose. The creation of the symbol table requires a linear scan of the operation body (e.g., a linear scan of theModuleOp
body). Considering that functions are typically called at least once, this leads to a scaling behavior that is quadratic with respect to the number of symbols.The problem is described in the following Discourse topic: https://discourse.llvm.org/t/quadratic-scaling-of-bufferization/86122/
This PR aims to partially address this scaling issue by leveraging the
SymbolTableCollection
class, whose instance is added to theFuncAnalysisState
extension. Later modifications are also expected to address the problem in other methods required byBufferizableOpInterface
(e.g.,bufferize
andgetBufferType
), which suffer of the same problem but do not provide access to any bufferization state.