diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -44,6 +44,24 @@
 //===----------------------------------------------------------------------===//
 
 using LinalgLoops = SmallVector<Operation *, 4>;
 
+/// Materialize a buffer allocation for the given tensor.pad op and lower the
+/// op to linalg.fill/linalg.generic + memref.tensor_store. E.g.:
+///
+/// %0 = tensor.pad low[%l] high[%h] %t ...
+///
+/// is lowered to:
+///
+/// %alloc = memref.alloc
+/// linalg.fill ... outs(%alloc)
+/// %subview = memref.subview %alloc [%l] [...] [1]
+/// memref.tensor_store %t, %subview
+/// %0 = bufferization.to_tensor %alloc restrict writable
+///
+/// In addition to rewriting the IR as shown above, the result of the
+/// bufferization.to_tensor op is returned.
+Value bufferizePadOp(RewriterBase &rewriter, tensor::PadOp padOp,
+                     Attribute memorySpace);
+
 void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
                                      const LinalgTilingOptions &options);
 
@@ -60,6 +78,11 @@
 /// style ops.
 void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns);
 
+/// Populate patterns that bufferize and materialize allocations for
+/// non-destination-style ops.
+void populateBufferizeNonDpsOpsPatterns(RewritePatternSet &patterns,
+                                        Attribute memorySpace = {});
+
 /// Populate patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops, it assume high-D convolution ops
 /// were decomposed previously.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -14,6 +14,8 @@
 //===----------------------------------------------------------------------===//
 //
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -134,7 +136,115 @@
     return success();
   }
 };
+} // namespace
 
+static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
+                                               Location loc, PadOp padOp,
+                                               Value dest) {
+  OpBuilder::InsertionGuard g(rewriter);
+  RankedTensorType resultType = padOp.getResultType();
+
+  // Examine the yielded value to decide if a linalg.generic is needed or a
+  // linalg.fill is sufficient.
+  Value yieldedValue =
+      cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
+  Attribute constYieldedValue;
+  // Is the yielded value a bbArg defined outside of the PadOp?
+  bool outsideBbArg =
+      yieldedValue.isa<BlockArgument>() &&
+      yieldedValue.cast<BlockArgument>().getOwner()->getParentOp() !=
+          padOp.getOperation();
+  // Is the yielded value an OpResult defined outside of the PadOp?
+  bool outsideOpResult =
+      yieldedValue.isa<OpResult>() &&
+      yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
+  bool invariantYieldedValue = outsideBbArg || outsideOpResult;
+  if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
+    // Padding with a constant: Create linalg.fill.
+    Dialect *arithDialect =
+        rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
+    Value fillValue =
+        arithDialect
+            ->materializeConstant(rewriter, constYieldedValue,
+                                  yieldedValue.getType(), yieldedValue.getLoc())
+            ->getResult(0);
+    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(fillValue),
+                                                  ValueRange(dest));
+    return fillOp;
+  }
+
+  if (invariantYieldedValue) {
+    // Padding with an invariant value.
+    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(yieldedValue),
+                                                  ValueRange(dest));
+    return fillOp;
+  }
+
+  // Create linalg.generic.
+  SmallVector<utils::IteratorType> iteratorTypes(resultType.getRank(),
+                                                 utils::IteratorType::parallel);
+  SmallVector<AffineMap> indexingMaps(
+      1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, resultType, /*inputs=*/ValueRange(),
+      /*outputs=*/ValueRange{dest}, /*indexingMaps=*/
+      indexingMaps, iteratorTypes);
+  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
+                                     resultType.getElementType(), loc);
+  rewriter.setInsertionPointToStart(body);
+  SmallVector<Value> bbArgReplacements;
+  for (int64_t i = 0; i < resultType.getRank(); ++i)
+    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
+  rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
+
+  // Update terminator.
+  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
+  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
+  return genericOp;
+}
+
+Value linalg::bufferizePadOp(RewriterBase &rewriter, PadOp padOp,
+                             Attribute memorySpace) {
+  OpBuilder::InsertionGuard g(rewriter);
+  Location loc = padOp.getLoc();
+  RankedTensorType resultType = padOp.getResultType();
+
+  // Create buffer allocation.
+  auto memrefType = bufferization::getMemRefTypeWithStaticIdentityLayout(
+                        resultType, memorySpace)
+                        .cast<MemRefType>();
+  ReifiedRankedShapedTypeDims reifiedShape;
+  if (failed(cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
+                 .reifyResultShapes(rewriter, reifiedShape)))
+    llvm_unreachable("failed to reify tensor.pad op result shape");
+  SmallVector<Value> dynamicSizes;
+  for (int64_t i = 0; i < resultType.getRank(); ++i)
+    if (resultType.isDynamicDim(i))
+      dynamicSizes.push_back(reifiedShape[0][i]);
+  Value alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes);
+
+  // Create linalg.fill or linalg.generic.
+  Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, alloc);
+  rewriter.setInsertionPointAfter(fillOp);
+
+  // Create memref.tensor_store.
+  SmallVector<OpFoldResult> sizes =
+      getMixedSizes(rewriter, loc, padOp.getSource());
+  SmallVector<OpFoldResult> strides(resultType.getRank(),
+                                    rewriter.getIndexAttr(1));
+  Value subview = rewriter.create<memref::SubViewOp>(
+      loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides);
+  rewriter.create<memref::TensorStoreOp>(loc, padOp.getSource(), subview);
+
+  // Create bufferization.to_tensor with "restrict" and "writable". The returned
+  // tensor is a new buffer allocation, so it does not alias with any buffer.
+  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
+      loc, alloc, /*restrict=*/true, /*writable=*/true);
+  rewriter.replaceOp(padOp, toTensorOp);
+  return toTensorOp;
+}
+
+namespace {
 /// Lower tensor.pad to linalg.generic + tensor.insert_slice.
 struct PadOpConverter : public OpRewritePattern<PadOp> {
   using OpRewritePattern<PadOp>::OpRewritePattern;
@@ -159,65 +269,10 @@
         dynamicSizes.push_back(reifiedShape[0][i]);
     auto emptyOp = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
 
-    // Examine the yielded value to decide if a linalg.generic is neede or a
-    // linalg.fill is sufficient.
-    Value filled;
-    Value yieldedValue =
-        cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
-    Attribute constYieldedValue;
-    // Is the yielded value a bbArg defined outside of the PadOp?
-    bool outsideBbArg =
-        yieldedValue.isa<BlockArgument>() &&
-        yieldedValue.cast<BlockArgument>().getOwner()->getParentOp() !=
-            padOp.getOperation();
-    // Is the yielded value an OpResult defined outside of the PadOp?
-    bool outsideOpResult =
-        yieldedValue.isa<OpResult>() &&
-        yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
-    bool invariantYieldedValue = outsideBbArg || outsideOpResult;
-    if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
-      // Padding with a constant: Create linalg.fill.
-      Dialect *arithDialect =
-          rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
-      Value fillValue = arithDialect
-                            ->materializeConstant(rewriter, constYieldedValue,
-                                                  yieldedValue.getType(),
-                                                  yieldedValue.getLoc())
-                            ->getResult(0);
-      auto fillOp = rewriter.create<linalg::FillOp>(
-          loc, ValueRange(fillValue), ValueRange(emptyOp.getResult()));
-      rewriter.setInsertionPointAfter(fillOp);
-      filled = fillOp.getResult(0);
-    } else if (invariantYieldedValue) {
-      // Padding with an invariant value.
-      auto fillOp = rewriter.create<linalg::FillOp>(
-          loc, ValueRange(yieldedValue), ValueRange(emptyOp.getResult()));
-      rewriter.setInsertionPointAfter(fillOp);
-      filled = fillOp.getResult(0);
-    } else {
-      // Create linalg.generic.
-      SmallVector<utils::IteratorType> iteratorTypes(
-          resultType.getRank(), utils::IteratorType::parallel);
-      SmallVector<AffineMap> indexingMaps(
-          1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
-      auto genericOp = rewriter.create<linalg::GenericOp>(
-          loc, resultType, /*inputs=*/ValueRange(),
-          /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
-          indexingMaps, iteratorTypes);
-      Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
-                                         resultType.getElementType(), loc);
-      rewriter.setInsertionPointToStart(body);
-      SmallVector<Value> bbArgReplacements;
-      for (int64_t i = 0; i < resultType.getRank(); ++i)
-        bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
-      rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
-
-      // Update terminator.
-      auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
-      rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
-      rewriter.setInsertionPointAfter(genericOp);
-      filled = genericOp->getResult(0);
-    }
+    // Create linalg.fill or linalg.generic.
+    Operation *fillOp =
+        movePaddingToFillOrGenericOp(rewriter, loc, padOp, emptyOp.getResult());
+    rewriter.setInsertionPointAfter(fillOp);
 
     // Create tensor::InsertSliceOp.
     SmallVector<OpFoldResult> sliceSizes =
@@ -225,15 +280,37 @@
         getMixedSizes(rewriter, loc, padOp.getSource());
     SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
                                            rewriter.getIndexAttr(1));
     rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
-        padOp, padOp.getSource(), filled,
+        padOp, padOp.getSource(), fillOp->getResult(0),
         /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
 
     return success();
   }
 };
 
+struct BufferizePadOp : public OpRewritePattern<PadOp> {
+  BufferizePadOp(MLIRContext *ctx, Attribute memorySpace = {},
+                 PatternBenefit benefit = 1)
+      : OpRewritePattern<PadOp>(ctx, benefit), memorySpace(memorySpace) {}
+
+  LogicalResult matchAndRewrite(PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    // Only ops with exactly one block are supported.
+    if (!padOp.getBodyRegion().hasOneBlock())
+      return failure();
+    (void)linalg::bufferizePadOp(rewriter, padOp, memorySpace);
+    return success();
+  }
+
+private:
+  const Attribute memorySpace;
+};
+
 } // namespace
 
+void linalg::populateBufferizeNonDpsOpsPatterns(RewritePatternSet &patterns,
+                                                Attribute memorySpace) {
+  patterns.insert<BufferizePadOp>(patterns.getContext(), memorySpace);
+}
+
 void linalg::populateConvertToDestinationStylePatterns(
     RewritePatternSet &patterns) {
   patterns.insert<FromElementsOpConverter, GenerateOpConverter, PadOpConverter>(
diff --git a/mlir/test/Dialect/Linalg/bufferize-non-dps-ops.mlir b/mlir/test/Dialect/Linalg/bufferize-non-dps-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/bufferize-non-dps-ops.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt -split-input-file \
+// RUN:   -test-linalg-transform-patterns=test-bufferize-non-dps-ops-patterns \
+// RUN:   -canonicalize %s | FileCheck %s
+
+// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
+// CHECK: #[[$map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 10)>
+// CHECK-LABEL: func @tensor_pad_constant(
+//  CHECK-SAME: %[[t:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index
+//   CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[c50:.*]] = arith.constant 50 : index
+//   CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t]], %[[c0]]
+//   CHECK-DAG: %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
+//   CHECK-DAG: %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
+//       CHECK: %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xindex>
+//       CHECK: linalg.fill ins(%[[c50]] : index) outs(%[[alloc]] : memref<?x?xindex>)
+//       CHECK: %[[dim0:.*]] = tensor.dim %[[t]], %[[c0]]
+//       CHECK: %[[subview:.*]] = memref.subview %[[alloc]][5, %[[l2]]] [%[[dim0]], 10] [1, 1]
+//       CHECK: memref.tensor_store %[[t]], %[[subview]]
+//       CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] restrict writable : memref<?x?xindex>
+//       CHECK: return %[[r]]
+func.func @tensor_pad_constant(%t: tensor<?x10xindex>, %l2: index, %h1: index,
+                               %h2: index) -> tensor<?x?xindex> {
+  %0 = tensor.pad %t low[5, %l2] high[%h1, %h2] {
+  ^bb0(%arg0: index, %arg1: index):
+    %c = arith.constant 50 : index
+    tensor.yield %c : index
+  } : tensor<?x10xindex> to tensor<?x?xindex>
+  return %0 : tensor<?x?xindex>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -132,6 +132,10 @@
       *this, "test-convert-to-destination-style-patterns",
       llvm::cl::desc("Test patterns that convert ops to destination style"),
       llvm::cl::init(false)};
+  Option<bool> testBufferizeNonDpsOpsPatterns{
+      *this, "test-bufferize-non-dps-ops-patterns",
+      llvm::cl::desc("Test patterns that bufferize non-DPS ops"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -228,6 +232,12 @@
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+static void applyBufferizeNonDpsOpsPatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  populateBufferizeNonDpsOpsPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -256,6 +266,8 @@
     return applyEraseUnnecessaryInputs(getOperation());
   if (testConvertToDestinationStylePatterns)
     applyConvertToDestinationStylePatterns(getOperation());
+  if (testBufferizeNonDpsOpsPatterns)
+    applyBufferizeNonDpsOpsPatterns(getOperation());
 }
 
 namespace mlir {
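
A minimal usage sketch of the new entry point (not part of the patch): the test driver above exercises the pattern through populateBufferizeNonDpsOpsPatterns, but linalg::bufferizePadOp can also be called directly, e.g. to place the padding buffer in a non-default memory space. Everything below except linalg::bufferizePadOp and its signature is an illustrative assumption; in particular, the helper name bufferizeAllPadOps and the memory space value 3 are made up for the example.

    // Hypothetical driver, assuming the patch above is applied. The pad ops are
    // collected first because bufferizePadOp replaces (erases) each op, so
    // rewriting while walking would invalidate the traversal.
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/Dialect/Tensor/IR/Tensor.h"
    #include "mlir/IR/PatternMatch.h"

    using namespace mlir;

    static void bufferizeAllPadOps(Operation *rootOp) {
      IRRewriter rewriter(rootOp->getContext());
      SmallVector<tensor::PadOp> padOps;
      rootOp->walk([&](tensor::PadOp padOp) { padOps.push_back(padOp); });
      for (tensor::PadOp padOp : padOps) {
        // Mirror the guard in BufferizePadOp: multi-block bodies are
        // unsupported.
        if (!padOp.getBodyRegion().hasOneBlock())
          continue;
        rewriter.setInsertionPoint(padOp);
        // Memory space 3 is an arbitrary example value (e.g., a GPU
        // shared-memory space).
        (void)linalg::bufferizePadOp(rewriter, padOp,
                                     rewriter.getI64IntegerAttr(3));
      }
    }

Passing a null Attribute instead keeps the allocation in the default memory space, which is what populateBufferizeNonDpsOpsPatterns does when its memorySpace argument is left empty.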