diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -26,18 +26,43 @@ let methods = [ InterfaceMethod< /*desc=*/[{ - Returns a list of operands into which the result of the - tiled implementation is written into. With `tensor` - operands, this will be used as the initial tensor into which - the tiled results are inserted into. With `memref` operands, - this will be the operand into which the result of the tiled - operation is written into. + Returns all destination (output) OpOperands of the operation. + + TODO: Use DestinationStyleOpInterface in the future. + }], + /*retType=*/"SmallVector", + /*methodName=*/"getDestinationOpOperands", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + llvm_unreachable("getDestinationOpOperands not implemented"); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Returns a list of values into which the result of the tiled + implementation is written. + + With `tensor` values, these will be used as the initial tensors into + which the tiled results are inserted. With `memref` values, these + will be the buffers into which the results of the tiled operation are + written. + + By default, this method returns the destination operands of the + operation. Operations that do not have destination operands can + override this method to provide custom destination values. 
}], /*retType=*/"SmallVector", - /*methodName=*/"getDestinationOperands", + /*methodName=*/"getDestination", /*args=*/(ins "OpBuilder &":$b), /*methodBody=*/"", - /*defaultImplementation=*/"return ValueRange{};" + /*defaultImplementation=*/[{ + auto tileableOp = cast($_op.getOperation()); + SmallVector opOperands = + tileableOp.getDestinationOpOperands(); + return llvm::to_vector(llvm::map_range( + opOperands, [](OpOperand *operand) { return operand->get(); })); + }] >, InterfaceMethod< /*desc=*/[{ diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -100,9 +100,8 @@ // Create the first part. SmallVector firstResults; TilingInterface firstPart = createSplitPart( - rewriter, op.getLoc(), op, offsets, sizes, - op.getDestinationOperands(rewriter), dimension, minSplitPoint, - iterationSpace[dimension].offset, firstResults); + rewriter, op.getLoc(), op, offsets, sizes, op.getDestination(rewriter), + dimension, minSplitPoint, iterationSpace[dimension].offset, firstResults); // Need to pretend that the original op now takes as operands firstResults, // otherwise tiling interface implementation will take the wrong value to diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -235,9 +235,9 @@ if (llvm::any_of(loopRanges, hasStrideOne)) return op->emitOpError("only stride-1 supported atm"); // TODO: support `getTiledImplementation` with >1 produced tiled ops. 
- auto destOperands = op.getDestinationOperands(b); - if (destOperands.size() != 1) - return op->emitOpError("only single dest operand supported atm"); + auto dest = op.getDestination(b); + if (dest.size() != 1) + return op->emitOpError("only single dest supported atm"); SmallVector nonZeroNumThreads = llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { @@ -325,9 +325,8 @@ auto tilingInterfaceOp = dyn_cast(tiledOp); assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface"); OpBuilder::InsertPoint insertPt = b.saveInsertionPoint(); - for (auto it : - llvm::zip(llvm::seq(unsigned(0), unsigned(destOperands.size())), - tilingInterfaceOp->getResults(), destOperands)) { + for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())), + tilingInterfaceOp->getResults(), dest)) { b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint()); SmallVector resultOffsets, resultSizes; if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets, @@ -608,10 +607,9 @@ } } // Generate loop nest: One loop per dimension. - SmallVector destOperand = - tilingInterface.getDestinationOperands(builder); + SmallVector dest = tilingInterface.getDestination(builder); loopNest = mlir::scf::buildLoopNest( - builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand), + builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(dest), [&](OpBuilder &b, Location loc, ValueRange localIvs, ValueRange iterArgs) -> scf::ValueVector { // Compute offsets and sizes of ExtractSliceOp. diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -85,7 +85,7 @@ : public TilingInterface::ExternalModel, LinalgOpTy> { /// Return the destination operands. 
- SmallVector getDestinationOperands(Operation *op, OpBuilder &b) const { + SmallVector getDestinationOpOperands(Operation *op) const { return cast(op).getOutputOperands(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -334,8 +334,7 @@ return yieldedValues; }; SmallVector newLoops = replaceLoopNestWithNewYields( - rewriter, tilingResult.loops, op.getDestinationOperands(rewriter), - yieldValueFn); + rewriter, tilingResult.loops, op.getDestination(rewriter), yieldValueFn); for (const auto &loop : llvm::enumerate(tilingResult.loops)) { rewriter.eraseOp(loop.value()); tilingResult.loops[loop.index()] = newLoops[loop.index()]; @@ -500,7 +499,7 @@ cast(fusableProducer->getOwner()); scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front(); SmallVector unfusedProducerOpDestValues = - unfusedProducerOp.getDestinationOperands(rewriter); + unfusedProducerOp.getDestination(rewriter); for (OpOperand &uses : unfusedProducerOp->getUses()) { if (uses.getOwner() == outerMostTiledLoop.getOperation()) { unsigned resultNumber = uses.get().cast().getResultNumber(); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -20,8 +20,11 @@ namespace { struct PadOpTiling : public TilingInterface::ExternalModel { + SmallVector getDestinationOpOperands(Operation *op) const { + return {}; + } - SmallVector getDestinationOperands(Operation *op, OpBuilder &b) const { + SmallVector getDestination(Operation *op, OpBuilder &b) const { OpBuilder::InsertionGuard g(b); b.setInsertionPoint(op); ReifiedRankedShapedTypeDims reifiedShapes;