diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -817,18 +817,6 @@ .hasTensorSemantics(); } - Operation *clone(OpBuilder & b, Location loc, TypeRange resultTypes, - ValueRange operands) { - return cast(*this->getOperation()) - .clone(b, loc, resultTypes, operands); - } - - Operation *cloneWithoutRegions(OpBuilder & b, Location loc, - TypeRange resultTypes, ValueRange operands) { - return cast(*this->getOperation()) - .cloneWithoutRegions(b, loc, resultTypes, operands); - } - //========================================================================// // Helper functions to mutate the `operand_segment_sizes` attribute. // These are useful when cloning and changing operand types. diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -28,6 +28,8 @@ namespace mlir { class OpBuilder; +class TypeRange; +class ValueRange; /// Tests whether the given maps describe a row major matmul. The test is /// permutation-invariant. Note that this only checks the affine maps from an @@ -108,6 +110,16 @@ Operation *op; }; +/// Clone `op` with `newResultTypes` and `newOperands`, keeping its location, +/// attributes and (deep-copied) regions. Abstracts away region creation. +Operation *clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, + ValueRange newOperands); + +/// Like `clone`, but the new op gets the same number of empty regions.
+Operation *cloneWithoutRegions(OpBuilder &b, Operation *op, + TypeRange newResultTypes, + ValueRange newOperands); + } // namespace mlir #endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td --- a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td +++ b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td @@ -262,51 +262,6 @@ opOperand.get().getType().template isa(); }); }] - >, - //===------------------------------------------------------------------===// - // Other static interface methods. - //===------------------------------------------------------------------===// - InterfaceMethod< - /*desc=*/[{ - Clone the current operation with the given location and operands. This - is used to abstract away the optional underlying region creation. This - does not change the balance between input, init_buffer and - init_tensors operands. - }], - /*retTy=*/"Operation *", - /*methodName=*/"clone", - (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, - "ValueRange":$operands), - [{ - BlockAndValueMapping bvm; - OperationState state( - loc, ConcreteOp::getOperationName(), operands, resultTypes, - $_op->getAttrs()); - for (Region &r : $_op->getRegions()) - r.cloneInto(state.addRegion(), bvm); - return b.create(state); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Clone the current operation with the given location, operands - and BlockAndValueMapping but leave the regions empty. This is - used to abstract away the optional underlying region creation. - This does not change the balance between input, init_buffer - and init_tensors operands.
- }], - /*retTy=*/"Operation *", - /*methodName=*/"cloneWithoutRegions", - (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, - "ValueRange":$operands), - [{ - OperationState state( - loc, ConcreteOp::getOperationName(), operands, resultTypes, - $_op->getAttrs()); - for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt) - state.addRegion(); - return b.create(state); - }] > ]; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1793,8 +1793,7 @@ newResultTypes.push_back(newOperands.back().getType()); } // Clone op. - Operation *newOp = - op.clone(rewriter, op->getLoc(), newResultTypes, newOperands); + Operation *newOp = clone(rewriter, op, newResultTypes, newOperands); SmallVector replacements; replacements.reserve(newOp->getNumResults()); for (auto result : llvm::zip(op->getResults(), newOp->getResults())) { @@ -1856,7 +1855,7 @@ SmallVector resultTypes(linalgOp->result_type_begin(), linalgOp->result_type_end()); resultTypes[resultNumber] = resultType; - Operation *newOp = linalgOp.clone(rewriter, loc, resultTypes, newOperands); + Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands); // Create a tensor.cast operation back to the original type. Value castBack = rewriter.create( @@ -2006,8 +2005,7 @@ return failure(); // Clone op.
- Operation *newOp = - linalgOp.clone(rewriter, linalgOp->getLoc(), resultTypes, newOperands); + Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands); SmallVector replacements; replacements.reserve(newOp->getNumResults()); for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp @@ -124,7 +124,7 @@ tiledOperands[opOperand->getOperandNumber()].getType()); Operation *newOp = - linalgOp.clone(rewriter, linalgLoc, resultTensorTypes, tiledOperands); + clone(rewriter, linalgOp, resultTensorTypes, tiledOperands); rewriter.replaceOp(sliceOp, newOp->getResults()); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -74,8 +74,8 @@ // new op. Since the new op does not have any tensor results, it does not // return anything.
assert(op->getNumRegions() == 1 && "expected that op has 1 region"); - auto newOp = cast(op.cloneWithoutRegions( - rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands)); + auto newOp = cast(cloneWithoutRegions( + rewriter, op, /*resultTypes=*/TypeRange{}, newOperands)); rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0), newOp->getRegion(0).begin()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -165,7 +165,7 @@ staticStridesVector)); } - Operation *clonedOp = producer.clone(b, loc, resultTypes, clonedShapes); + Operation *clonedOp = clone(b, producer, resultTypes, clonedShapes); // Shift all IndexOp results by the tile offset. SmallVector allIvs = llvm::to_vector( diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -182,7 +182,7 @@ TypeRange resultTypes = ValueRange(tiledOperands) .take_back(producerOp.getNumDpsInits()) .getTypes(); - LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands); + LinalgOp clonedOp = clone(b, producerOp, resultTypes, tiledOperands); // Shift all IndexOp results by the tile offset.
offsetIndices(b, clonedOp, allIvs); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -671,7 +671,7 @@ SmallVector resultTensorTypes = getTensorOutputTypes(op, tiledOperands); - res = op.clone(b, loc, resultTensorTypes, tiledOperands); + res = clone(b, op, resultTensorTypes, tiledOperands); tensorResults = insertSlicesBack(builder, loc, op, tiledOperands, res->getResults()); return scf::ValueVector(tensorResults.begin(), tensorResults.end()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -125,8 +125,7 @@ SmallVector resultTensorTypes = getTensorOutputTypes(linalgOp, tiledOperands); - Operation *tiledOp = - linalgOp.clone(b, loc, resultTensorTypes, tiledOperands); + Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands); offsetIndices(b, cast(tiledOp), offsets); return {tiledOp}; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -195,7 +195,7 @@ // Clone `opToPad` to operate on the statically padded shapes. auto resultTensorTypes = ValueRange(newOperands).take_back(opToPad.getNumDpsInits()).getTypes(); - paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands); + paddedOp = clone(b, opToPad, resultTensorTypes, newOperands); // Recover the slice out of the new static results. This keeps the original // linalg op around because it uses the dims of the original results.
diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
--- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
@@ -8,6 +8,8 @@
 
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 
 #include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc"
@@ -92,3 +94,23 @@
   auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
   return indexingMaps == maps;
 }
+
+Operation *mlir::clone(OpBuilder &b, Operation *op, TypeRange newResultTypes,
+                       ValueRange newOperands) {
+  BlockAndValueMapping bvm;
+  OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
+                       op->getAttrs());
+  for (Region &r : op->getRegions())
+    r.cloneInto(state.addRegion(), bvm);
+  return b.create(state);
+}
+
+Operation *mlir::cloneWithoutRegions(OpBuilder &b, Operation *op,
+                                     TypeRange newResultTypes,
+                                     ValueRange newOperands) {
+  OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
+                       op->getAttrs());
+  for (size_t cnt = 0, e = op->getNumRegions(); cnt < e; ++cnt)
+    state.addRegion();
+  return b.create(state);
+}