diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -129,6 +129,18 @@
 Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc,
                                                Value tensor, Value dest);
 
+/// This is a helper function for DestinationStyleOpInterface. If there is a
+/// destination operand for the given OpResult, return that operand. Otherwise,
+/// return an empty tensor (`tensor.empty`) with the shape of the OpResult.
+/// Dynamic dimensions are queried via ReifyRankedShapedTypeOpInterface.
+FailureOr<Value> getOrCreateDestination(OpBuilder &b, Location loc,
+                                        OpResult opResult);
+
+/// This is a helper function for DestinationStyleOpInterface. Get or create
+/// destinations for every tensor OpResult of the given op.
+LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
+                                      SmallVector<Value> &result);
+
 /// Function to control the folding of constant and extract slice
 using ControlConstantExtractSliceFusionFn =
     std::function<bool(ExtractSliceOp op)>;
diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
--- a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
+++ b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
@@ -199,18 +199,18 @@
       /*args=*/(ins "OpOperand *":$opOperand),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        assert(opOperand->getOwner() == this->getOperation());
+        assert(opOperand->getOwner() == $_op.getOperation());
         return !opOperand->get().getType().template isa<ShapedType>();
       }]
     >,
     InterfaceMethod<
-      /*desc=*/"Return the result tied to `opOperand`.",
+      /*desc=*/"Return the OpResult that is tied to the given OpOperand.",
       /*retTy=*/"OpResult",
       /*methodName=*/"getTiedOpResult",
       /*args=*/(ins "OpOperand *":$opOperand),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        assert(opOperand->getOwner() == this->getOperation());
+        assert(opOperand->getOwner() == $_op.getOperation());
         auto [start, end] = $_op.getOutputsPositionRange();
         int64_t resultIndex = opOperand->getOperandNumber() - start;
@@ -219,6 +219,17 @@
         return $_op->getResult(resultIndex);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Return the OpOperand that is tied to the given OpResult.",
+      /*retTy=*/"OpOperand *",
+      /*methodName=*/"getTiedOpOperand",
+      /*args=*/(ins "OpResult":$opResult),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(opResult.getDefiningOp() == $_op.getOperation());
+        return $_op.getOutputOperand(opResult.getResultNumber());
+      }]
+    >,
     //===------------------------------------------------------------------===//
     // Other interface methods.
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -24,21 +24,6 @@
   }];
   let cppNamespace = "::mlir";
   let methods = [
-      InterfaceMethod<
-        /*desc=*/[{
-          Returns a list of operands into which the result of the
-          tiled implementation is written into. With `tensor`
-          operands, this will be used as the initial tensor into which
-          the tiled results are inserted into. With `memref` operands,
-          this will be the operand into which the result of the tiled
-          operation is written into.
-        }],
-        /*retType=*/"SmallVector<Value>",
-        /*methodName=*/"getDestinationOperands",
-        /*args=*/(ins "OpBuilder &":$b),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/"return ValueRange{};"
-      >,
       InterfaceMethod<
         /*desc=*/[{
           Returns a list of iterator types that describe the number of loops.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -375,10 +375,18 @@
   int64_t resultNumber = pUse->get().cast<OpResult>().getResultNumber();
   LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
 
-  auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
+  // Gather destination tensors.
+  SmallVector<Value> destinationTensors;
+  if (failed(tensor::getOrCreateDestinations(
+          rewriter, tileableProducer->getLoc(), tileableProducer,
+          destinationTensors))) {
+    diag.attachNote(tileableProducer->getLoc())
+        << "failed to get destination tensors for: " << *tileableProducer;
+    return nullptr;
+  }
 
   BlockAndValueMapping bvm;
-  bvm.map(destinationOperands[resultNumber], bbArg);
+  bvm.map(destinationTensors[resultNumber], bbArg);
   auto tileableProducerClone =
       cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
   auto scopeGuard =
@@ -403,7 +411,7 @@
   // Replace the use in containingOp.
   rewriter.updateRootInPlace(containingOp, [&]() {
     containingOp->setOperand(pUse->getOperandNumber(),
-                             destinationOperands.front());
+                             destinationTensors.front());
   });
 
   return fusedOp;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -97,12 +97,18 @@
     return {op, TilingInterface()};
   }
 
+  // Compute destination tensors.
+  SmallVector<Value> destinationTensors;
+  LogicalResult destStatus = tensor::getOrCreateDestinations(
+      rewriter, op.getLoc(), op, destinationTensors);
+  (void)destStatus;
+  assert(succeeded(destStatus) && "failed to get destination tensors");
+
   // Create the first part.
   SmallVector<Value> firstResults;
   TilingInterface firstPart = createSplitPart(
-      rewriter, op.getLoc(), op, offsets, sizes,
-      op.getDestinationOperands(rewriter), dimension, minSplitPoint,
-      iterationSpace[dimension].offset, firstResults);
+      rewriter, op.getLoc(), op, offsets, sizes, destinationTensors, dimension,
+      minSplitPoint, iterationSpace[dimension].offset, firstResults);
 
   // Need to pretend that the original op now takes as operands firstResults,
   // otherwise tiling interface implementation will take the wrong value to
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -235,7 +235,11 @@
   auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
   if (llvm::any_of(loopRanges, hasStrideOne))
     return op->emitOpError("only stride-1 supported atm");
-  auto dest = op.getDestinationOperands(b);
+
+  // Gather destination tensors.
+  SmallVector<Value> dest;
+  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
+    return op->emitOpError("failed to get destination tensors");
 
   SmallVector<OpFoldResult> nonZeroNumThreads =
       llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
@@ -622,11 +626,13 @@
           getValueOrCreateConstantIndexOp(builder, loc, tileSizes[i]));
     }
   }
-  // Generate loop nest: One loop per dimension.
-  SmallVector<Value> destOperand =
-      tilingInterface.getDestinationOperands(builder);
+  SmallVector<Value> destinationTensors;
+  if (failed(tensor::getOrCreateDestinations(builder, loc, tilingInterface,
+                                             destinationTensors)))
+    return failure();
+
   loopNest = mlir::scf::buildLoopNest(
-      builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand),
+      builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destinationTensors),
       [&](OpBuilder &b, Location loc, ValueRange localIvs,
           ValueRange iterArgs) -> scf::ValueVector {
         // Compute offsets and sizes of ExtractSliceOp.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -84,11 +84,6 @@
 struct LinalgOpTilingInterface
     : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
                                             LinalgOpTy> {
-  /// Return the destination operands.
-  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
-    return cast(op).getOutputOperands();
-  }
-
   /// Return the loop iterator type.
   SmallVector getLoopIteratorTypes(Operation *op) const {
     LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@
   MLIRArithDialect
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
+  MLIRDestinationStyleOpInterface
   MLIRDialectUtils
   MLIRIR
   MLIRMemRefDialect
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "llvm/Support/Debug.h"
 
@@ -274,6 +275,12 @@
         op, "missing tile size computation function");
   }
 
+  // Get destination tensors.
+  SmallVector<Value> destinationTensors;
+  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
+                                             destinationTensors)))
+    return rewriter.notifyMatchFailure(op, "failed to get destinations");
+
   // 1. Get the range of the loops that are represented by the operation.
   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
   size_t numLoops = iterationDomain.size();
@@ -378,17 +385,21 @@
     }
   }
 
-  FailureOr<SmallVector<Value>> replacementOr =
-      yieldTiledValues(rewriter, op.getDestinationOperands(rewriter),
-                       tilingResult.tiledOp->getResults(), resultOffsetsList,
-                       resultSizesList, tilingResult.loops);
+  FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
+      rewriter, destinationTensors, tilingResult.tiledOp->getResults(),
+      resultOffsetsList, resultSizesList, tilingResult.loops);
 
   if (failed(replacementOr))
     return rewriter.notifyMatchFailure(op, "failed to yield replacement");
-  if (auto tiledInterfaceOp = dyn_cast<TilingInterface>(tilingResult.tiledOp)) {
+
+  if (auto dstOp =
+          dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOp)) {
     auto innerMostLoop = tilingResult.loops.back();
-    updateDestinationOperandsForTiledOp(
-        rewriter, tiledInterfaceOp.getDestinationOperands(rewriter),
-        innerMostLoop.getRegionIterArgs());
+    SmallVector<Value> destinationTensors = dstOp.getOutputOperands();
+    assert(destinationTensors.size() ==
+               innerMostLoop.getRegionIterArgs().size() &&
+           "unexpected number of outputs");
+    updateDestinationOperandsForTiledOp(rewriter, destinationTensors,
+                                        innerMostLoop.getRegionIterArgs());
   }
 
   tilingResult.replacements = replacementOr.value();
@@ -567,20 +578,17 @@
     }
     if (iterArgNumber) {
       int64_t resultNumber = fusableProducer.getResultNumber();
-      if (auto producerOp =
-              dyn_cast<TilingInterface>(fusableProducer.getOwner())) {
-        SmallVector<Value> destination =
-            producerOp.getDestinationOperands(rewriter);
-        outerMostLoop.setIterArg(iterArgNumber.value(),
-                                 destination[resultNumber]);
+      if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(
+              fusableProducer.getOwner())) {
+        outerMostLoop.setIterArg(
+            iterArgNumber.value(),
+            dstOp.getTiedOpOperand(fusableProducer)->get());
       }
-      if (auto tiledAndFusedInterfaceOp =
-              fusedProducerValue.value().getDefiningOp<TilingInterface>()) {
+      if (auto dstOp = fusedProducerValue.value()
+                           .getDefiningOp<DestinationStyleOpInterface>()) {
         scf::ForOp innerMostLoop = tileAndFuseResult.loops.back();
-        SmallVector<Value> destination =
-            tiledAndFusedInterfaceOp.getDestinationOperands(rewriter);
         updateDestinationOperandsForTiledOp(
-            rewriter, destination[resultNumber],
+            rewriter, dstOp.getOutputOperand(resultNumber)->get(),
             innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
       }
     }
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -24,6 +24,7 @@
   MLIRArithUtils
   MLIRCastInterfaces
   MLIRComplexDialect
+  MLIRDestinationStyleOpInterface
   MLIRDialectUtils
   MLIRIR
   MLIRInferTypeOpInterface
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -54,6 +55,59 @@
   return result;
 }
 
+FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
+                                                OpResult opResult) {
+  auto tensorType = opResult.getType().dyn_cast<TensorType>();
+  assert(tensorType && "expected tensor type");
+
+  // If the op has a destination, it implements DestinationStyleOpInterface and
+  // we can query the destination operand from that interface.
+  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
+  if (destOp)
+    return destOp.getTiedOpOperand(opResult)->get();
+
+  // Otherwise, create a new destination tensor with the same shape.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(opResult.getDefiningOp());
+
+  // Compute sizes.
+  SmallVector<OpFoldResult> mixedSizes;
+  if (!tensorType.hasStaticShape()) {
+    // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
+    ReifiedRankedShapedTypeDims reifiedShapes;
+    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
+        dyn_cast<ReifyRankedShapedTypeOpInterface>(opResult.getDefiningOp());
+    if (!reifyShapedTypeInterface)
+      return failure();
+    if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes)))
+      return failure();
+    mixedSizes = getAsOpFoldResult(reifiedShapes[opResult.getResultNumber()]);
+  } else {
+    // Static shape: Take static sizes directly.
+    for (int64_t sz : tensorType.getShape())
+      mixedSizes.push_back(b.getIndexAttr(sz));
+  }
+
+  // Create empty tensor.
+  Value emptyTensor =
+      b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
+  return emptyTensor;
+}
+
+LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
+                                              Operation *op,
+                                              SmallVector<Value> &result) {
+  for (OpResult opResult : op->getResults()) {
+    if (opResult.getType().isa<TensorType>()) {
+      FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
+      if (failed(destination))
+        return failure();
+      result.push_back(*destination);
+    }
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -21,21 +21,6 @@
 struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
 
-  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(op);
-    ReifiedRankedShapedTypeDims reifiedShapes;
-    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
-        dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
-    (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes);
-
-    auto padOp = cast<PadOp>(op);
-    SmallVector<OpFoldResult> mixedSizes = getAsOpFoldResult(reifiedShapes[0]);
-    Value emptyTensor = b.create<tensor::EmptyOp>(
-        op->getLoc(), mixedSizes, padOp.getResultType().getElementType());
-    return {emptyTensor};
-  }
-
   SmallVector getLoopIteratorTypes(Operation *op) const {
     auto padOp = cast<PadOp>(op);
     SmallVector iteratorTypes(
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1902,6 +1902,7 @@
         ":ArithUtils",
         ":BufferizationDialect",
         ":BufferizationTransforms",
+        ":DestinationStyleOpInterface",
        ":DialectUtils",
        ":FuncDialect",
        ":IR",
@@ -5189,6 +5190,7 @@
        ":CastOpInterfaces",
        ":ComplexDialect",
        ":ControlFlowInterfaces",
+        ":DestinationStyleOpInterface",
        ":DialectUtils",
        ":IR",
        ":InferTypeOpInterface",
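
Usage sketch (not part of the patch): a minimal example of how a transform that previously relied on TilingInterface::getDestinationOperands can gather its loop init values through the new tensor::getOrCreateDestinations helper, mirroring the calls added in TileUsingInterface.cpp above. The wrapper name computeLoopInitValues is hypothetical; only the helper's signature comes from this patch.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

/// Hypothetical helper: collect the init values for a tiled loop nest around
/// `op`. Ops implementing DestinationStyleOpInterface return their tied
/// destination operands; for other ops, a tensor.empty of the right shape is
/// created (dynamic sizes are reified via ReifyRankedShapedTypeOpInterface).
static FailureOr<SmallVector<Value>> computeLoopInitValues(OpBuilder &b,
                                                           TilingInterface op) {
  SmallVector<Value> destinationTensors;
  if (failed(tensor::getOrCreateDestinations(b, op.getLoc(), op,
                                             destinationTensors)))
    return failure();
  return destinationTensors;
}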