Skip to content

Commit

Permalink
Add PrintOp to the Interpreter dialect (#2687)
Browse files Browse the repository at this point in the history
While writing an MLIR pass, I realized there's no easy way to inspect
the intermediate states of tensors while debugging them. ProbeOp does
something similar, but this does not provide a human readable way of
inspecting them given arbitrary SSA value.
  • Loading branch information
ghpvnist authored Jan 17, 2025
1 parent c125b32 commit 23d7f60
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 3 deletions.
8 changes: 8 additions & 0 deletions stablehlo/reference/Api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ class DefaultInterpreterFallback : public InterpreterFallback {
Process *process) final {
llvm::StringRef funcName = op.getParentOfType<func::FuncOp>().getSymName();

if (auto printOp = dyn_cast<stablehlo::interpreter::PrintOp>(op)) {
auto operand =
stablehlo::InterpreterValue(scope.findTensor(printOp.getOperand()));
auto status = stablehlo::interpreter::evalPrintOp(printOp, operand);
return wrapFallbackStatus(std::move(status), funcName,
"interpreter.print");
}

if (auto probeOp = dyn_cast<stablehlo::interpreter::ProbeOp>(op)) {
auto input =
stablehlo::InterpreterValue(scope.findTensor(probeOp.getOperand()));
Expand Down
14 changes: 14 additions & 0 deletions stablehlo/reference/InterpreterOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,20 @@ SmallVector<InterpreterValue> evalRunParallelOp(
return results;
}

llvm::Error evalPrintOp(PrintOp &op, InterpreterValue operand) {
std::string ssaValueStr;
llvm::raw_string_ostream stream(ssaValueStr);
stream << op.getOperand();

// Get the SSA name and print it like: `%0 = `
llvm::outs() << ssaValueStr.substr(0, ssaValueStr.find("=") + 2);

// Prints the tensor value
operand.getTensor().print(llvm::outs());
llvm::outs() << "\n";
return llvm::Error::success();
}

// `serializedProbeFileId` should be a unique positive integer which can be used
// to unambiguously derive a serialized filename for a given `probeId`.
llvm::Error evalProbeOp(InterpreterValue input, StringRef probeId,
Expand Down
12 changes: 9 additions & 3 deletions stablehlo/reference/InterpreterOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "stablehlo/reference/Value.h"

#define GET_OP_CLASSES
#include "stablehlo/reference/InterpreterOps.h.inc"

namespace mlir {
namespace stablehlo {
namespace interpreter {
Expand All @@ -39,6 +42,12 @@ SmallVector<InterpreterValue> evalRunParallelOp(
ArrayRef<InterpreterValue> inputs, std::queue<StringAttr> &infeed,
SmallVector<SmallVector<StringAttr>> programs, SymbolTable &symbolTable);

// Print the SSA name followed by its type and value like:
// >>> %0 = tensor<i1> {
// ... [true]
// ... }
llvm::Error evalPrintOp(PrintOp &op, InterpreterValue operand);

llvm::Error evalProbeOp(InterpreterValue input, StringRef probeId,
StringRef probeOutputDir,
int64_t serializedProbeFileId);
Expand All @@ -47,7 +56,4 @@ llvm::Error evalProbeOp(InterpreterValue input, StringRef probeId,
} // namespace stablehlo
} // namespace mlir

#define GET_OP_CLASSES
#include "stablehlo/reference/InterpreterOps.h.inc"

#endif // STABLEHLO_REFERENCE_INTERPRETEROPS_H
21 changes: 21 additions & 0 deletions stablehlo/reference/InterpreterOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,27 @@ def Interpreter_RunParallelOp : Op<Interpreter_Dialect, "run_parallel", []> {
let hasVerifier = 1;
}

def Interpreter_PrintOp : Op<Interpreter_Dialect, "print"> {
let summary = "Print operation";
let arguments = (ins
HLO_Tensor:$operand
);
let description = [{
Print the value to stdout.

This is useful to print intermediate states of the tensors while debugging.
This should only be used to debug small tensors since every instance of this
op and its contents are printed to stdout. To gather information in bulk for
larger tensors, prefer using ProbeOp.

Example:
```mlir
interpreter.print %operand : tensor<i1>
```
}];
let assemblyFormat = "$operand attr-dict `:` type($operand)";
}

def Interpreter_ProbeOp : Op<Interpreter_Dialect, "probe",
[SameOperandsAndResultType]> {
let arguments = (ins
Expand Down

0 comments on commit 23d7f60

Please sign in to comment.