Skip to content

[MLIR] Cache symbol tables during OneShotBufferization analyses #138125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

mscuttari
Copy link
Member

During bufferization, the callee of each func::CallOp / CallableOpInterface operation is retrieved by means of a symbol table that is temporarily built for the lookup purpose. The creation of the symbol table requires a linear scan of the operation body (e.g., a linear scan of the ModuleOp body). Considering that functions are typically called at least once, this leads to a scaling behavior that is quadratic with respect to the number of symbols.
The problem is described in the following Discourse topic: https://discourse.llvm.org/t/quadratic-scaling-of-bufferization/86122/

This PR aims to partially address this scaling issue by leveraging the SymbolTableCollection class, whose instance is added to the FuncAnalysisState extension. Later modifications are also expected to address the problem in other methods required by BufferizableOpInterface (e.g., bufferize and getBufferType), which suffer of the same problem but do not provide access to any bufferization state.

@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels May 1, 2025
@llvmbot
Copy link
Member

llvmbot commented May 1, 2025

@llvm/pr-subscribers-mlir

Author: Michele Scuttari (mscuttari)

Changes

During bufferization, the callee of each func::CallOp / CallableOpInterface operation is retrieved by means of a symbol table that is temporarily built for the lookup purpose. The creation of the symbol table requires a linear scan of the operation body (e.g., a linear scan of the ModuleOp body). Considering that functions are typically called at least once, this leads to a scaling behavior that is quadratic with respect to the number of symbols.
The problem is described in the following Discourse topic: https://discourse.llvm.org/t/quadratic-scaling-of-bufferization/86122/

This PR aims to partially address this scaling issue by leveraging the SymbolTableCollection class, whose instance is added to the FuncAnalysisState extension. Later modifications are also expected to address the problem in other methods required by BufferizableOpInterface (e.g., bufferize and getBufferType), which suffer of the same problem but do not provide access to any bufferization state.


Full diff: https://github.com/llvm/llvm-project/pull/138125.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h (+3)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+19-10)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+9-3)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index e8e6226460ac7..b63c0883c6c15 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -69,6 +69,9 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
   /// analyzed.
   DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
 
+  /// A collection of cached SymbolTables used for faster function lookup.
+  mutable mlir::SymbolTableCollection symbolTable;
+
   /// This function is called right before analyzing the given FuncOp. It
   /// initializes the data structures for the FuncOp in this state object.
   void startFunctionAnalysis(FuncOp funcOp);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index c45678f1e4b4d..86d15d4f0a607 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -76,13 +76,14 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
 }
 
 /// Return the FuncOp called by `callOp`.
-static FuncOp getCalledFunction(CallOpInterface callOp) {
+static FuncOp getCalledFunction(CallOpInterface callOp,
+                                mlir::SymbolTableCollection &symbolTable) {
   SymbolRefAttr sym =
       llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<FuncOp>(
-      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+      symbolTable.lookupNearestSymbolFrom(callOp, sym));
 }
 
 /// Get FuncAnalysisState.
@@ -135,14 +136,14 @@ struct CallOpInterface
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
     func::CallOp callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+    FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
 
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
       // FuncOp not analyzed yet. Assume that OpOperand is read.
       return true;
 
-    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     return funcState.readBbArgs.lookup(funcOp).contains(
         opOperand.getOperandNumber());
   }
@@ -150,14 +151,14 @@ struct CallOpInterface
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const AnalysisState &state) const {
     func::CallOp callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+    FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
 
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
       // FuncOp not analyzed yet. Assume that OpOperand is written.
       return true;
 
-    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     return funcState.writtenBbArgs.lookup(funcOp).contains(
         opOperand.getOperandNumber());
   }
@@ -165,14 +166,14 @@ struct CallOpInterface
   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                       const AnalysisState &state) const {
     func::CallOp callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+    FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
       // FuncOp not analyzed yet. Any OpResult may be aliasing.
       return detail::unknownGetAliasingValues(opOperand);
 
     // Get aliasing results from state.
-    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     auto aliasingReturnVals =
         funcState.aliasingReturnVals.lookup(funcOp).lookup(
             opOperand.getOperandNumber());
@@ -199,7 +200,11 @@ struct CallOpInterface
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     auto callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+
+    // TODO Avoid recomputing the symbol tables every time.
+    mlir::SymbolTableCollection symbolTable;
+
+    FuncOp funcOp = getCalledFunction(callOp, symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
 
     // If the callee was already bufferized, we can directly take the type from
@@ -243,7 +248,11 @@ struct CallOpInterface
     // 2. Rewrite tensor operands as memrefs based on type of the already
     //    bufferized callee.
     SmallVector<Value> newOperands;
-    FuncOp funcOp = getCalledFunction(callOp);
+
+    // TODO Avoid recomputing the symbol tables every time.
+    mlir::SymbolTableCollection symbolTable;
+
+    FuncOp funcOp = getCalledFunction(callOp, symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
     FunctionType funcType = funcOp.getFunctionType();
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index edd6bcf84f460..a025da8635135 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -280,13 +280,15 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
 }
 
 /// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(func::CallOp callOp) {
+static func::FuncOp
+getCalledFunction(func::CallOp callOp,
+                  mlir::SymbolTableCollection &symbolTable) {
   SymbolRefAttr sym =
       llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<func::FuncOp>(
-      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+      symbolTable.lookupNearestSymbolFrom(callOp, sym));
 }
 
 /// Return "true" if the given function signature has tensor semantics.
@@ -314,11 +316,15 @@ static LogicalResult getFuncOpsOrderedByCalls(
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
+
+  // TODO Avoid recomputing the symbol tables every time.
+  mlir::SymbolTableCollection symbolTable;
+
   for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
-      func::FuncOp calledFunction = getCalledFunction(callOp);
+      func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable);
       assert(calledFunction && "could not retrieved called func::FuncOp");
       // If the called function does not have any tensors in its signature, then
       // it is not necessary to bufferize the callee before the caller.

@llvmbot
Copy link
Member

llvmbot commented May 1, 2025

@llvm/pr-subscribers-mlir-bufferization

Author: Michele Scuttari (mscuttari)

Changes

During bufferization, the callee of each func::CallOp / CallableOpInterface operation is retrieved by means of a symbol table that is temporarily built for the lookup purpose. The creation of the symbol table requires a linear scan of the operation body (e.g., a linear scan of the ModuleOp body). Considering that functions are typically called at least once, this leads to a scaling behavior that is quadratic with respect to the number of symbols.
The problem is described in the following Discourse topic: https://discourse.llvm.org/t/quadratic-scaling-of-bufferization/86122/

This PR aims to partially address this scaling issue by leveraging the SymbolTableCollection class, whose instance is added to the FuncAnalysisState extension. Later modifications are also expected to address the problem in other methods required by BufferizableOpInterface (e.g., bufferize and getBufferType), which suffer of the same problem but do not provide access to any bufferization state.


Full diff: https://github.com/llvm/llvm-project/pull/138125.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h (+3)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+19-10)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+9-3)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index e8e6226460ac7..b63c0883c6c15 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -69,6 +69,9 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
   /// analyzed.
   DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
 
+  /// A collection of cached SymbolTables used for faster function lookup.
+  mutable mlir::SymbolTableCollection symbolTable;
+
   /// This function is called right before analyzing the given FuncOp. It
   /// initializes the data structures for the FuncOp in this state object.
   void startFunctionAnalysis(FuncOp funcOp);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index c45678f1e4b4d..86d15d4f0a607 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -76,13 +76,14 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
 }
 
 /// Return the FuncOp called by `callOp`.
-static FuncOp getCalledFunction(CallOpInterface callOp) {
+static FuncOp getCalledFunction(CallOpInterface callOp,
+                                mlir::SymbolTableCollection &symbolTable) {
   SymbolRefAttr sym =
       llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<FuncOp>(
-      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+      symbolTable.lookupNearestSymbolFrom(callOp, sym));
 }
 
 /// Get FuncAnalysisState.
@@ -135,14 +136,14 @@ struct CallOpInterface
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
     func::CallOp callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+    FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
 
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
       // FuncOp not analyzed yet. Assume that OpOperand is read.
       return true;
 
-    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     return funcState.readBbArgs.lookup(funcOp).contains(
         opOperand.getOperandNumber());
   }
@@ -150,14 +151,14 @@ struct CallOpInterface
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const AnalysisState &state) const {
     func::CallOp callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+    FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
 
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
       // FuncOp not analyzed yet. Assume that OpOperand is written.
       return true;
 
-    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     return funcState.writtenBbArgs.lookup(funcOp).contains(
         opOperand.getOperandNumber());
   }
@@ -165,14 +166,14 @@ struct CallOpInterface
   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                       const AnalysisState &state) const {
     func::CallOp callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+    FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
       // FuncOp not analyzed yet. Any OpResult may be aliasing.
       return detail::unknownGetAliasingValues(opOperand);
 
     // Get aliasing results from state.
-    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
     auto aliasingReturnVals =
         funcState.aliasingReturnVals.lookup(funcOp).lookup(
             opOperand.getOperandNumber());
@@ -199,7 +200,11 @@ struct CallOpInterface
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     auto callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+
+    // TODO Avoid recomputing the symbol tables every time.
+    mlir::SymbolTableCollection symbolTable;
+
+    FuncOp funcOp = getCalledFunction(callOp, symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
 
     // If the callee was already bufferized, we can directly take the type from
@@ -243,7 +248,11 @@ struct CallOpInterface
     // 2. Rewrite tensor operands as memrefs based on type of the already
     //    bufferized callee.
     SmallVector<Value> newOperands;
-    FuncOp funcOp = getCalledFunction(callOp);
+
+    // TODO Avoid recomputing the symbol tables every time.
+    mlir::SymbolTableCollection symbolTable;
+
+    FuncOp funcOp = getCalledFunction(callOp, symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
     FunctionType funcType = funcOp.getFunctionType();
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index edd6bcf84f460..a025da8635135 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -280,13 +280,15 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
 }
 
 /// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(func::CallOp callOp) {
+static func::FuncOp
+getCalledFunction(func::CallOp callOp,
+                  mlir::SymbolTableCollection &symbolTable) {
   SymbolRefAttr sym =
       llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<func::FuncOp>(
-      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+      symbolTable.lookupNearestSymbolFrom(callOp, sym));
 }
 
 /// Return "true" if the given function signature has tensor semantics.
@@ -314,11 +316,15 @@ static LogicalResult getFuncOpsOrderedByCalls(
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
+
+  // TODO Avoid recomputing the symbol tables every time.
+  mlir::SymbolTableCollection symbolTable;
+
   for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
-      func::FuncOp calledFunction = getCalledFunction(callOp);
+      func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable);
       assert(calledFunction && "could not retrieved called func::FuncOp");
       // If the called function does not have any tensors in its signature, then
       // it is not necessary to bufferize the callee before the caller.

@mscuttari mscuttari marked this pull request as draft May 1, 2025 12:29
@mscuttari
Copy link
Member Author

Marking the PR as draft because I'm seeing a few bufferization tests failing due to a missing FuncAnalysisState extension. Unfortunately it is also quite inconvenient to instantiate it on the need, as the AnalaysisState class is passed as a const reference to the interface methods.

@mscuttari mscuttari marked this pull request as ready for review May 1, 2025 13:04
@mscuttari mscuttari changed the title Cache symbol tables during OneShotBufferization analyses [MLIR] Cache symbol tables during OneShotBufferization analyses May 1, 2025
assert(result && "FuncAnalysisState does not exist");
return *result;

// Unfortunately, at the moment the BufferizableOpInterface methods do provide
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks a bit hacky. And I think it's not really necessary. When running through analyzeModuleOp (OneShotModuleBufferize.cpp), there will always be a FuncAnalysisState.

When running through the normal One-Shot Bufferize (bufferize-function-boundaries=false), there is no FuncAnalysisState. (But this entry point should only be used for functions, not modules. And I think it will ignore function calls anyway.) Implementations like getFuncOpAnalysisState fall back to the conservative path. We should do the same for getCalledFunction: If there is no FuncAnalysisState, use the previous "slow" lookup instead of the symbol table.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants