Skip to content

Commit

Permalink
add tensor_ext.sum
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Jan 15, 2025
1 parent f3fb651 commit 0e791be
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
36 changes: 36 additions & 0 deletions lib/Dialect/TensorExt/IR/TensorExtOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,42 @@ LogicalResult RotateOp::verify() {
return success();
}

LogicalResult SumOp::verify() {
auto inputTensor = cast<RankedTensorType>(getTensor().getType());
auto outputTensor = cast<RankedTensorType>(getOutput().getType());

if (inputTensor.getElementType() != outputTensor.getElementType()) {
return emitOpError()
<< "requires input and output tensors to have the same "
"element type, but found "
<< inputTensor.getElementType() << " and "
<< outputTensor.getElementType();
}

// The input and output must have the same shape when removing the index
// given by the operand dim.
unsigned int dim = getDim().getZExtValue();
SmallVector<int64_t, 4> inputShape;
for (int i = 0; i < inputTensor.getRank(); i++) {
if (i == dim) continue;
inputShape.push_back(inputTensor.getShape()[i]);
}

ArrayRef<int64_t> outputShape = outputTensor.getShape();

if (llvm::any_of(llvm::zip(inputShape, outputShape), [](auto pair) {
return std::get<0>(pair) != std::get<1>(pair);
})) {
return emitOpError()
<< "requires input and output tensors to have the same shape, but "
"after summing along dimension "
<< dim << " the input shape becomes " << inputTensor.getShape()
<< " but the output shape is " << outputTensor.getShape();
}

return success();
}

} // namespace tensor_ext
} // namespace heir
} // namespace mlir
16 changes: 14 additions & 2 deletions lib/Dialect/TensorExt/IR/TensorExtOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
class TensorExt_Op<string mnemonic, list<Trait> traits = []> :
Op<TensorExt_Dialect, mnemonic, traits> {
let cppNamespace = "::mlir::heir::tensor_ext";
let assemblyFormat = "operands attr-dict `:` type($output)";
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}

def TensorExt_RotateOp : TensorExt_Op<"rotate", [Pure, AllTypesMatch<["tensor", "output"]>]> {
Expand Down Expand Up @@ -54,9 +54,21 @@ def TensorExt_ConvertLayoutOp : TensorExt_Op<"convert_layout", [Pure, AllTypesMa
let description = [{
// FIXME: add descr
}];

let assemblyFormat = "operands attr-dict `:` type($output)";
let arguments = (ins AnyTensor:$tensor, Builtin_AffineMapAttr:$from_layout, Builtin_AffineMapAttr:$to_layout);
let results = (outs AnyTensor:$output);
}

def TensorExt_SumOp : TensorExt_Op<"sum", [Pure]> {
let summary = "Sum along one dimension of a tensor.";
let description = [{
This op can be thought of as a special case of `linglg.reduce`,
but more concise.
}];
let arguments = (ins AnyTensor:$tensor, IndexAttr:$dim);
let results = (outs AnyTensor:$output);
let assemblyFormat = "$tensor `,` $dim attr-dict `:` type($tensor) `->` type($output)";
let hasVerifier = 1;
}

#endif // LIB_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_

0 comments on commit 0e791be

Please sign in to comment.