diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "llvm/Support/Debug.h"
 
@@ -70,9 +71,106 @@
   }
 };
 
+/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
+struct PadOpConverter : public OpRewritePattern<PadOp> {
+  using OpRewritePattern<PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    // Only ops with exactly one block are supported.
+    if (!padOp.getBodyRegion().hasOneBlock())
+      return failure();
+
+    // Create tensor.empty.
+    Location loc = padOp.getLoc();
+    RankedTensorType resultType = padOp.getResultType();
+    ReifiedRankedShapedTypeDims reifiedShape;
+    if (failed(cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
+                   .reifyResultShapes(rewriter, reifiedShape)))
+      return rewriter.notifyMatchFailure(
+          padOp, "failed to reify tensor.pad op result shape");
+    SmallVector<Value> dynamicSizes;
+    for (int64_t i = 0; i < resultType.getRank(); ++i)
+      if (resultType.isDynamicDim(i))
+        dynamicSizes.push_back(reifiedShape[0][i]);
+    auto emptyOp = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
+
+    // Examine the yielded value to decide if a linalg.generic is needed or a
+    // linalg.fill is sufficient.
+    Value filled;
+    Value yieldedValue =
+        cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
+    Attribute constYieldedValue;
+    // Is the yielded value a bbArg defined outside of the PadOp?
+    bool outsideBbArg =
+        yieldedValue.isa<BlockArgument>() &&
+        yieldedValue.cast<BlockArgument>().getOwner()->getParentOp() !=
+            padOp.getOperation();
+    // Is the yielded value an OpResult defined outside of the PadOp?
+    bool outsideOpResult =
+        yieldedValue.isa<OpResult>() &&
+        yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
+    bool invariantYieldedValue = outsideBbArg || outsideOpResult;
+    if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
+      // Padding with a constant: Create linalg.fill.
+      Dialect *arithDialect =
+          rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
+      Value fillValue = arithDialect
+                            ->materializeConstant(rewriter, constYieldedValue,
+                                                  yieldedValue.getType(),
+                                                  yieldedValue.getLoc())
+                            ->getResult(0);
+      auto fillOp = rewriter.create<linalg::FillOp>(
+          loc, ValueRange(fillValue), ValueRange(emptyOp.getResult()));
+      rewriter.setInsertionPointAfter(fillOp);
+      filled = fillOp.getResult(0);
+    } else if (invariantYieldedValue) {
+      // Padding with an invariant value.
+      auto fillOp = rewriter.create<linalg::FillOp>(
+          loc, ValueRange(yieldedValue), ValueRange(emptyOp.getResult()));
+      rewriter.setInsertionPointAfter(fillOp);
+      filled = fillOp.getResult(0);
+    } else {
+      // Create linalg.generic.
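+      // The pad body computes a value per output element, so a fill is not
+      // enough. Build a rank-many parallel linalg.generic over the output
+      // tensor; the linalg.index ops created below replace the pad body's
+      // block arguments, which carried the current output coordinates.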
+      SmallVector<utils::IteratorType> iteratorTypes(
+          resultType.getRank(), utils::IteratorType::parallel);
+      SmallVector<AffineMap> indexingMaps(
+          1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
+      auto genericOp = rewriter.create<linalg::GenericOp>(
+          loc, resultType, /*inputs=*/ValueRange(),
+          /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
+          indexingMaps, iteratorTypes);
+      Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
+                                         resultType.getElementType(), loc);
+      rewriter.setInsertionPointToStart(body);
+      SmallVector<Value> bbArgReplacements;
+      for (int64_t i = 0; i < resultType.getRank(); ++i)
+        bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
+      rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
+
+      // Update terminator.
+      auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
+      rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
+      rewriter.setInsertionPointAfter(genericOp);
+      filled = genericOp->getResult(0);
+    }
+
+    // Create tensor::InsertSliceOp.
+    SmallVector<OpFoldResult> sliceSizes =
+        getMixedSizes(rewriter, loc, padOp.getSource());
+    SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
+                                           rewriter.getIndexAttr(1));
+    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+        padOp, padOp.getSource(), filled,
+        /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void linalg::populateConvertToDestinationStylePatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<GenerateOpConverter>(patterns.getContext());
+  patterns.insert<GenerateOpConverter, PadOpConverter>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir b/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
--- a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
+++ b/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-convert-to-destination-style-patterns %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-convert-to-destination-style-patterns -canonicalize %s | FileCheck %s
 
 // CHECK: #[[$map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @tensor_generate(
@@ -21,3 +21,86 @@
   } : tensor<?x?xindex>
   return %0 : tensor<?x?xindex>
 }
+
+// -----
+
+// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
+// CHECK: #[[$map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 10)>
+// CHECK: #[[$map2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @tensor_pad(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
+// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
+// CHECK-DAG: %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
+// CHECK: %[[empty:.*]] = tensor.empty(%[[size0]], %[[size1]]) : tensor<?x?xindex>
+// CHECK: %[[generic:.*]] = linalg.generic
+// CHECK-SAME: {indexing_maps = [#[[$map2]]], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: outs(%[[empty]] : tensor<?x?xindex>) {
+// CHECK: %[[i0:.*]] = linalg.index 0
+// CHECK: %[[i1:.*]] = linalg.index 1
+// CHECK: %[[mul:.*]] = arith.muli %[[i0]], %[[i1]]
+// CHECK: linalg.yield %[[mul]]
+// CHECK: }
+// CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
+// CHECK: %[[inserted:.*]] = tensor.insert_slice %[[t1]] into %[[generic]][5, %[[l2]]] [%[[dim0]], 10] [1, 1] : tensor<?x10xindex> into tensor<?x?xindex>
+// CHECK: return %[[inserted]]
+func.func @tensor_pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
+                      %h2: index) -> tensor<?x?xindex> {
+  %0 = tensor.pad %t1 low[5, %l2] high[%h1, %h2] {
+  ^bb0(%arg0: index, %arg1: index):
+    %m = arith.muli %arg0, %arg1 : index
+    tensor.yield %m : index
+  } : tensor<?x10xindex> to tensor<?x?xindex>
+  return %0 : tensor<?x?xindex>
+}
+
+// -----
+
+// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
+// CHECK: #[[$map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 10)>
+// CHECK-LABEL: func @tensor_pad_constant(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c50:.*]] = arith.constant 50 : index
+// CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
+// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
+// CHECK-DAG: %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
+// CHECK: %[[empty:.*]] = tensor.empty(%[[size0]], %[[size1]]) : tensor<?x?xindex>
+// CHECK: %[[filled:.*]] = linalg.fill ins(%[[c50]] : index) outs(%[[empty]] : tensor<?x?xindex>)
+// CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
+// CHECK: %[[inserted:.*]] = tensor.insert_slice %[[t1]] into %[[filled]][5, %[[l2]]] [%[[dim0]], 10] [1, 1] : tensor<?x10xindex> into tensor<?x?xindex>
+// CHECK: return %[[inserted]]
+func.func @tensor_pad_constant(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
+                               %h2: index) -> tensor<?x?xindex> {
+  %0 = tensor.pad %t1 low[5, %l2] high[%h1, %h2] {
+  ^bb0(%arg0: index, %arg1: index):
+    %c = arith.constant 50 : index
+    tensor.yield %c : index
+  } : tensor<?x10xindex> to tensor<?x?xindex>
+  return %0 : tensor<?x?xindex>
+}
+
+// -----
+
+// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
+// CHECK: #[[$map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 10)>
+// CHECK-LABEL: func @tensor_pad_invariant(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index, %[[padding:.*]]: index
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
+// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
+// CHECK-DAG: %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
+// CHECK: %[[empty:.*]] = tensor.empty(%[[size0]], %[[size1]]) : tensor<?x?xindex>
+// CHECK: %[[filled:.*]] = linalg.fill ins(%[[padding]] : index) outs(%[[empty]] : tensor<?x?xindex>)
+// CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
+// CHECK: %[[inserted:.*]] = tensor.insert_slice %[[t1]] into %[[filled]][5, %[[l2]]] [%[[dim0]], 10] [1, 1] : tensor<?x10xindex> into tensor<?x?xindex>
+// CHECK: return %[[inserted]]
+func.func @tensor_pad_invariant(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
+                                %h2: index, %padding: index) -> tensor<?x?xindex> {
+  %0 = tensor.pad %t1 low[5, %l2] high[%h1, %h2] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %padding : index
+  } : tensor<?x10xindex> to tensor<?x?xindex>
+  return %0 : tensor<?x?xindex>
+}
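
Usage note (not part of the patch): a minimal sketch of how downstream code
might drive these patterns with MLIR's greedy rewriter. The pass name
ConvertToDestinationStylePass is hypothetical; only the populate call comes
from this patch. The in-tree test instead goes through
-test-linalg-transform-patterns as in the RUN line above, relying on
-canonicalize to fold the reified size computations.

  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/Pass/Pass.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  using namespace mlir;

  // Hypothetical pass wrapper around populateConvertToDestinationStylePatterns.
  struct ConvertToDestinationStylePass
      : public PassWrapper<ConvertToDestinationStylePass,
                           OperationPass<func::FuncOp>> {
    void runOnOperation() override {
      RewritePatternSet patterns(&getContext());
      linalg::populateConvertToDestinationStylePatterns(patterns);
      // Apply greedily; tensor.pad ops whose body has more than one block are
      // left untouched (PadOpConverter returns failure for them).
      if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                              std::move(patterns))))
        signalPassFailure();
    }
  };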