diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConvertToDestinationStyle.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConvertToDestinationStyle.cpp
@@ -13,6 +13,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -134,10 +135,89 @@
   }
 };
 
+/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
+struct PadOpConverter : public OpRewritePattern<PadOp> {
+  using OpRewritePattern<PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    // Only ops with exactly one block are supported.
+    if (!padOp.getBodyRegion().hasOneBlock())
+      return failure();
+
+    Location loc = padOp.getLoc();
+    RankedTensorType resultType = padOp.getResultType();
+    RankedTensorType srcType = padOp.getSourceType();
+
+    auto toValue = [&](OpFoldResult ofr) {
+      if (ofr.is<Value>())
+        return ofr.get<Value>();
+      return rewriter
+          .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr))
+          .getResult();
+    };
+
+    // Compute dynamic result dimensions.
+    SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
+    SmallVector<Value> dynamicSizes;
+    for (int64_t i = 0; i < resultType.getRank(); ++i) {
+      if (!resultType.isDynamicDim(i))
+        continue;
+      Value srcDim = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
+      Value lowPad = toValue(mixedLowPad[i]);
+      Value highPad = toValue(mixedHighPad[i]);
+      AffineExpr s0, s1, s2;
+      bindSymbols(padOp->getContext(), s0, s1, s2);
+      AffineExpr sumExpr = s0 + s1 + s2;
+      Value sum = rewriter.create<AffineApplyOp>(
+          loc, sumExpr, ValueRange{srcDim, lowPad, highPad});
+      dynamicSizes.push_back(sum);
+    }
+
+    // Create tensor.empty.
+    auto emptyOp = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
+
+    // Create linalg.generic.
+    // TODO: Create linalg.fill in case the padded value is a constant.
+    SmallVector<utils::IteratorType> iteratorTypes(
+        resultType.getRank(), utils::IteratorType::parallel);
+    SmallVector<AffineMap> indexingMaps(
+        1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, resultType, /*inputs=*/ValueRange(),
+        /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
+        indexingMaps, iteratorTypes);
+    Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
+                                       resultType.getElementType(), loc);
+    rewriter.setInsertionPointToStart(body);
+    SmallVector<Value> bbArgReplacements;
+    for (int64_t i = 0; i < resultType.getRank(); ++i)
+      bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
+    rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
+
+    // Update terminator.
+    auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
+    rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
+    rewriter.setInsertionPointAfter(genericOp);
+
+    // Create tensor::InsertSliceOp.
+    SmallVector<OpFoldResult> sliceSizes =
+        getMixedSizes(rewriter, loc, padOp.getSource());
+    SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
+                                           rewriter.getIndexAttr(1));
+    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+        padOp, padOp.getSource(), genericOp->getResult(0),
+        /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void tensor::populateConvertToDestinationStylePatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<FromElementsOpConverter, GenerateOpConverter>(
+  patterns.insert<FromElementsOpConverter, GenerateOpConverter, PadOpConverter>(
       patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Tensor/convert-to-destination-style.mlir b/mlir/test/Dialect/Tensor/convert-to-destination-style.mlir
--- a/mlir/test/Dialect/Tensor/convert-to-destination-style.mlir
+++ b/mlir/test/Dialect/Tensor/convert-to-destination-style.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-convert-to-destination-style-patterns %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-convert-to-destination-style-patterns -canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: func @tensor_from_elements_0d(
 // CHECK-SAME: %[[arg0:.*]]: index
@@ -69,3 +69,36 @@
   } : tensor<?x?xindex>
   return %0 : tensor<?x?xindex>
 }
+
+// -----
+
+// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
+// CHECK: #[[$map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 10)>
+// CHECK: #[[$map2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @tensor_pad(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
+// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$map]]()[%[[dim0]], %[[h1]]]
+// CHECK-DAG: %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
+// CHECK: %[[empty:.*]] = tensor.empty(%[[size0]], %[[size1]]) : tensor<?x?xindex>
+// CHECK: %[[generic:.*]] = linalg.generic
+// CHECK-SAME: {indexing_maps = [#[[$map2]]], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: outs(%[[empty]] : tensor<?x?xindex>) {
+// CHECK: %[[i0:.*]] = linalg.index 0
+// CHECK: %[[i1:.*]] = linalg.index 1
+// CHECK: %[[mul:.*]] = arith.muli %[[i0]], %[[i1]]
+// CHECK: linalg.yield %[[mul]]
+// CHECK: }
+// CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
+// CHECK: %[[inserted:.*]] = tensor.insert_slice %[[t1]] into %[[generic]][5, %[[l2]]] [%[[dim0]], 10] [1, 1] : tensor<?x10xindex> into tensor<?x?xindex>
+// CHECK: return %[[inserted]]
+func.func @tensor_pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
+                      %h2: index) -> tensor<?x?xindex> {
+  %0 = tensor.pad %t1 low[5, %l2] high[%h1, %h2] {
+  ^bb0(%arg0: index, %arg1: index):
+    %m = arith.muli %arg0, %arg1 : index
+    tensor.yield %m : index
+  } : tensor<?x10xindex> to tensor<?x?xindex>
+  return %0 : tensor<?x?xindex>
+}
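Note: For reference, a hand-written sketch (not compiler output, and not part of
this patch) of the IR the pattern is expected to produce for a fully static pad
with a constant padding value; %t and %cst are hypothetical values, and the dead
linalg.index ops that the pattern materializes for the replaced block arguments
are omitted (they are folded away by -canonicalize):

  %cst = arith.constant 0.0 : f32
  // Input: pad a 10x20 tensor by 1/3 in dim 0 and 2/4 in dim 1.
  %0 = tensor.pad %t low[1, 2] high[3, 4] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<10x20xf32> to tensor<14x26xf32>

  // After the pattern (result dims: 10+1+3 = 14, 20+2+4 = 26):
  %empty = tensor.empty() : tensor<14x26xf32>
  %filled = linalg.generic
      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
      outs(%empty : tensor<14x26xf32>) {
  ^bb0(%out: f32):
    linalg.yield %cst : f32
  } -> tensor<14x26xf32>
  %1 = tensor.insert_slice %t into %filled[1, 2] [10, 20] [1, 1]
      : tensor<10x20xf32> into tensor<14x26xf32>

This constant-padding case is exactly what the TODO in PadOpConverter refers to:
the linalg.generic that merely broadcasts %cst could be emitted as a
linalg.fill instead.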