diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -20,9 +20,11 @@ namespace { -/// Generic conversion for any LinalgOp on tensors. -static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op, - const BufferizationOptions &options) { +/// Generic conversion for any DestinationStyleOpInterface on tensors. +static LogicalResult +bufferizeDestinationStyleOpInterface(RewriterBase &rewriter, + DestinationStyleOpInterface op, + const BufferizationOptions &options) { // Take a guard before anything else. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(op); @@ -71,7 +73,7 @@ // new op. Since the new op does not have any tensor results, it does not // return anything. assert(op->getNumRegions() == 1 && "expected that op has 1 region"); - auto newOp = cast(op.cloneWithoutRegions( + auto newOp = cast(op.cloneWithoutRegions( rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands)); rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0), newOp->getRegion(0).begin()); @@ -105,7 +107,7 @@ SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, const AnalysisState &state) const { - auto genericOp = cast(op); + auto genericOp = cast(op); // The i-th OpResult may alias with the i-th "out" tensor. return {genericOp.getOutputOperand(opResult.getResultNumber())}; @@ -113,7 +115,7 @@ SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - auto genericOp = cast(op); + auto genericOp = cast(op); // The i-th "out" tensor may alias with the i-th OpResult. if (genericOp.isOutputTensor(&opOperand)) @@ -128,7 +130,8 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { - return bufferizeLinalgOp(rewriter, cast(op), options); + return bufferizeDestinationStyleOpInterface( + rewriter, cast(op), options); } };