diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -129,6 +129,16 @@ Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest); +/// This is a helper function for DestinationStyleOpInterface. If there is a +/// destination operand for the given OpResult, return that operand. Otherwise, +/// return an empty tensor (`tensor.empty`) with the shape of the OpResult. +Value getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult); + +/// Helper for DestinationStyleOpInterface: get or create a destination for +/// each tensor OpResult of `op`, in order. Non-tensor results are skipped. +SmallVector getOrCreateDestinations(OpBuilder &b, Location loc, + Operation *op); + /// Function to control the folding of constant and extract slice using ControlConstantExtractSliceFusionFn = std::function; diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td --- a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td +++ b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td @@ -199,18 +199,18 @@ /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); + assert(opOperand->getOwner() == $_op.getOperation()); return !opOperand->get().getType().template isa(); }] >, InterfaceMethod< - /*desc=*/"Return the result tied to `opOperand`.", + /*desc=*/"Return the OpResult that is tied to the given OpOperand.", /*retTy=*/"OpResult", /*methodName=*/"getTiedOpResult", /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); + assert(opOperand->getOwner() == $_op.getOperation()); auto [start, end] = $_op.getOutputsPositionRange(); 
int64_t resultIndex = opOperand->getOperandNumber() - start; @@ -219,6 +219,17 @@ return $_op->getResult(resultIndex); }] >, + InterfaceMethod< + /*desc=*/"Return the OpOperand that is tied to the given OpResult.", + /*retTy=*/"OpOperand *", + /*methodName=*/"getTiedOpOperand", + /*args=*/(ins "OpResult":$opResult), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(opResult.getDefiningOp() == $_op.getOperation()); + return $_op.getOutputOperand(opResult.getResultNumber()); + }] + >, //===------------------------------------------------------------------===// // Other interface methods. //===------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -24,21 +24,6 @@ }]; let cppNamespace = "::mlir"; let methods = [ - InterfaceMethod< - /*desc=*/[{ - Returns a list of operands into which the result of the - tiled implementation is written into. With `tensor` - operands, this will be used as the initial tensor into which - the tiled results are inserted into. With `memref` operands, - this will be the operand into which the result of the tiled - operation is written into. - }], - /*retType=*/"SmallVector", - /*methodName=*/"getDestinationOperands", - /*args=*/(ins "OpBuilder &":$b), - /*methodBody=*/"", - /*defaultImplementation=*/"return ValueRange{};" - >, InterfaceMethod< /*desc=*/[{ Returns a list of iterator types that describe the number of loops. 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -375,10 +375,12 @@ int64_t resultNumber = pUse->get().cast().getResultNumber(); LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n"); - auto destinationOperands = tileableProducer.getDestinationOperands(rewriter); + // Gather destination tensors (one per tensor result of the producer). + SmallVector<Value> destinationTensors = tensor::getOrCreateDestinations( + rewriter, tileableProducer->getLoc(), tileableProducer); BlockAndValueMapping bvm; - bvm.map(destinationOperands[resultNumber], bbArg); + bvm.map(destinationTensors[resultNumber], bbArg); auto tileableProducerClone = cast(rewriter.clone(*tileableProducer, bvm)); auto scopeGuard = @@ -403,7 +405,7 @@ // Replace the use in containingOp. rewriter.updateRootInPlace(containingOp, [&]() { containingOp->setOperand(pUse->getOperandNumber(), - destinationOperands.front()); + destinationTensors[resultNumber]); }); return fusedOp; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -97,12 +97,15 @@ return {op, TilingInterface()}; } + // Compute destination tensors. + SmallVector destinationTensors = + tensor::getOrCreateDestinations(rewriter, op.getLoc(), op); + // Create the first part. 
SmallVector firstResults; TilingInterface firstPart = createSplitPart( - rewriter, op.getLoc(), op, offsets, sizes, - op.getDestinationOperands(rewriter), dimension, minSplitPoint, - iterationSpace[dimension].offset, firstResults); + rewriter, op.getLoc(), op, offsets, sizes, destinationTensors, dimension, + minSplitPoint, iterationSpace[dimension].offset, firstResults); // Need to pretend that the original op now takes as operands firstResults, // otherwise tiling interface implementation will take the wrong value to diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -235,7 +235,9 @@ auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); }; if (llvm::any_of(loopRanges, hasStrideOne)) return op->emitOpError("only stride-1 supported atm"); - auto dest = op.getDestinationOperands(b); + + // Gather destination tensors. + SmallVector dest = tensor::getOrCreateDestinations(b, loc, op); SmallVector nonZeroNumThreads = llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { @@ -622,11 +624,11 @@ getValueOrCreateConstantIndexOp(builder, loc, tileSizes[i])); } } - // Generate loop nest: One loop per dimension. - SmallVector destOperand = - tilingInterface.getDestinationOperands(builder); + SmallVector destinationTensors = + tensor::getOrCreateDestinations(builder, loc, tilingInterface); + // Generate loop nest: One loop per dimension. loopNest = mlir::scf::buildLoopNest( - builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand), + builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destinationTensors), [&](OpBuilder &b, Location loc, ValueRange localIvs, ValueRange iterArgs) -> scf::ValueVector { // Compute offsets and sizes of ExtractSliceOp. 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -84,11 +84,6 @@ struct LinalgOpTilingInterface : public TilingInterface::ExternalModel, LinalgOpTy> { - /// Return the destination operands. - SmallVector getDestinationOperands(Operation *op, OpBuilder &b) const { - return cast(op).getOutputOperands(); - } - /// Return the loop iterator type. SmallVector getLoopIteratorTypes(Operation *op) const { LinalgOpTy concreteOp = cast(op); diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -378,17 +378,21 @@ } } + SmallVector dest = + tensor::getOrCreateDestinations(rewriter, op.getLoc(), op); FailureOr> replacementOr = - yieldTiledValues(rewriter, op.getDestinationOperands(rewriter), - tilingResult.tiledOp->getResults(), resultOffsetsList, - resultSizesList, tilingResult.loops); + yieldTiledValues(rewriter, dest, tilingResult.tiledOp->getResults(), + resultOffsetsList, resultSizesList, tilingResult.loops); if (failed(replacementOr)) return rewriter.notifyMatchFailure(op, "failed to yield replacement"); + if (auto tiledInterfaceOp = dyn_cast(tilingResult.tiledOp)) { auto innerMostLoop = tilingResult.loops.back(); - updateDestinationOperandsForTiledOp( - rewriter, tiledInterfaceOp.getDestinationOperands(rewriter), - innerMostLoop.getRegionIterArgs()); + // Gather the destination tensors of the tiled op (not the untiled op's). 
+ SmallVector<Value> tiledOpDest = tensor::getOrCreateDestinations( + rewriter, tiledInterfaceOp.getLoc(), tiledInterfaceOp); + updateDestinationOperandsForTiledOp(rewriter, tiledOpDest, + innerMostLoop.getRegionIterArgs()); } tilingResult.replacements = replacementOr.value(); @@ -569,16 +573,17 @@ int64_t resultNumber = fusableProducer.getResultNumber(); if (auto producerOp = dyn_cast(fusableProducer.getOwner())) { - SmallVector destination = - producerOp.getDestinationOperands(rewriter); + SmallVector destination = tensor::getOrCreateDestinations( + rewriter, producerOp.getLoc(), producerOp); outerMostLoop.setIterArg(iterArgNumber.value(), destination[resultNumber]); } if (auto tiledAndFusedInterfaceOp = fusedProducerValue.value().getDefiningOp()) { scf::ForOp innerMostLoop = tileAndFuseResult.loops.back(); - SmallVector destination = - tiledAndFusedInterfaceOp.getDestinationOperands(rewriter); + SmallVector destination = tensor::getOrCreateDestinations( + rewriter, tiledAndFusedInterfaceOp.getLoc(), + tiledAndFusedInterfaceOp); updateDestinationOperandsForTiledOp( rewriter, destination[resultNumber], innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]); diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -24,6 +24,7 @@ MLIRArithUtils MLIRCastInterfaces MLIRComplexDialect + MLIRDestinationStyleOpInterface MLIRDialectUtils MLIRIR MLIRInferTypeOpInterface diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" @@ -54,6 
+55,56 @@ return result; } +Value tensor::getOrCreateDestination(OpBuilder &b, Location loc, + OpResult opResult) { + auto tensorType = opResult.getType().dyn_cast<TensorType>(); + assert(tensorType && "expected tensor type"); + + // If the op has a destination, it implements DestinationStyleOpInterface and + // we can query the destination operand from that interface. + auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>(); + if (destOp) + return destOp.getTiedOpOperand(opResult)->get(); + + // Otherwise, create a new destination tensor with the same shape. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(opResult.getDefiningOp()); + + SmallVector<OpFoldResult> mixedSizes; + if (tensorType.hasStaticShape()) { + // Static shape: the sizes are known from the type, no need to reify. + for (int64_t size : tensorType.getShape()) + mixedSizes.push_back(b.getIndexAttr(size)); + } else { + // Dynamic shape: query ReifyRankedShapedTypeOpInterface. Use + // report_fatal_error rather than assert/llvm_unreachable so that a + // failure is diagnosed deterministically in release builds too. + auto reifyShapedTypeInterface = + dyn_cast<ReifyRankedShapedTypeOpInterface>(opResult.getDefiningOp()); + if (!reifyShapedTypeInterface) + llvm::report_fatal_error( + "creating a destination requires reifying the shape"); + ReifiedRankedShapedTypeDims reifiedShapes; + if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes))) + llvm::report_fatal_error("failed to reify result shapes"); + mixedSizes = getAsOpFoldResult(reifiedShapes[opResult.getResultNumber()]); + } + + return b.create<tensor::EmptyOp>(loc, mixedSizes, + tensorType.getElementType()); +} + +SmallVector<Value> tensor::getOrCreateDestinations(OpBuilder &b, Location loc, + Operation *op) { + // Non-tensor results are skipped: the i-th entry belongs to the i-th tensor + // result, which is the i-th OpResult only if all results are tensors. 
+ SmallVector<Value> result; + for (OpResult opResult : op->getResults()) + if (opResult.getType().isa<TensorType>()) + result.push_back(getOrCreateDestination(b, loc, opResult)); + return result; +} + //===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -21,21 +21,6 @@ struct PadOpTiling : public TilingInterface::ExternalModel { - SmallVector
getDestinationOperands(Operation *op, OpBuilder &b) const { - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); - ReifiedRankedShapedTypeDims reifiedShapes; - ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = - dyn_cast(op); - (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes); - - auto padOp = cast(op); - SmallVector mixedSizes = getAsOpFoldResult(reifiedShapes[0]); - Value emptyTensor = b.create( - op->getLoc(), mixedSizes, padOp.getResultType().getElementType()); - return {emptyTensor}; - } - SmallVector getLoopIteratorTypes(Operation *op) const { auto padOp = cast(op); SmallVector iteratorTypes( diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5143,6 +5143,7 @@ ":CastOpInterfaces", ":ComplexDialect", ":ControlFlowInterfaces", + ":DestinationStyleOpInterface", ":DialectUtils", ":IR", ":InferTypeOpInterface",