diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -696,10 +696,80 @@
   }
 };
 
+/// Base pattern that visits every user of a PadTensorOp that has type OpTy and
+/// calls rewriteUser() on it; matches iff at least one such rewrite succeeds.
+template <typename OpTy>
+struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const final {
+    bool changed = false;
+    // Snapshot the users first: rewriteUser may replace/remove them as we go.
+    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
+      if (auto op = dyn_cast<OpTy>(user))
+        changed |= rewriteUser(rewriter, padOp, op).succeeded();
+    return success(changed);
+  }
+
+protected:
+  virtual LogicalResult rewriteUser(
+      PatternRewriter &rewriter, PadTensorOp padOp, OpTy op) const = 0;
+};
+
+/// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
+/// ```
+///    %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
+///    %r = vector.transfer_read %0[%c0, %c0], %cst
+///        {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+///    %r = vector.transfer_read %src[%c0, %c0], %padding
+///        {in_bounds = [false, false]}
+///        : tensor<?x?xf32>, vector<17x5xf32>
+/// ```
+/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
+/// sure that the original padding value %cst was never used.
+///
+/// This rewrite is possible if:
+/// - `xferOp` has no out-of-bounds dims or mask.
+/// - Low padding is static 0.
+/// - Single, scalar padding value.
+struct PadTensorOpVectorizationWithTransferReadPattern
+    : public VectorizePadTensorOpUserPattern<vector::TransferReadOp> {
+  using VectorizePadTensorOpUserPattern<vector::TransferReadOp>
+      ::VectorizePadTensorOpUserPattern;
+
+  LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
+                            vector::TransferReadOp xferOp) const override {
+    // Low padding must be static 0.
+    if (!padOp.hasZeroLowPad()) return failure();
+    // Pad value must be a constant.
+    auto padValue = padOp.getConstantPaddingValue();
+    if (!padValue) return failure();
+    // Only in-bounds, unmasked reads: their own padding value is never used.
+    if (xferOp.hasOutOfBoundsDim() || xferOp.mask()) return failure();
+    // The read now targets the pad's source, so stop claiming in-bounds.
+    rewriter.updateRootInPlace(xferOp, [&]() {
+      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
+      xferOp->setAttr(xferOp.getInBoundsAttrName(),
+                      rewriter.getBoolArrayAttr(inBounds));
+      xferOp.sourceMutable().assign(padOp.source());
+      xferOp.paddingMutable().assign(padValue);
+    });
+
+    return success();
+  }
+};
+
 void mlir::linalg::populatePadTensorOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
   patterns.add<GenericPadTensorOpVectorizationPattern>(
       patterns.getContext(), baseBenefit);
+  // Try these specialized patterns first before resorting to the generic one.
+  patterns.add<PadTensorOpVectorizationWithTransferReadPattern>(
+      patterns.getContext(), baseBenefit.getBenefit() + 1);
 }
 
 // TODO: cleanup all the convolution vectorization patterns.
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -558,6 +558,28 @@
 
 // -----
 
+// CHECK-LABEL: func @pad_and_transfer_read
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C5:.*]] = constant 5.0
+// CHECK: %[[RESULT:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
+// CHECK: return %[[RESULT]]
+func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %c6 = constant 6.0 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[5, 7] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %c5 : f32
+  } : tensor<5x6xf32> to tensor<10x13xf32>
+  %1 = vector.transfer_read %0[%c0, %c0], %c6
+      : tensor<10x13xf32>, vector<7x9xf32>
+  return %1 : vector<7x9xf32>
+}
+
+// -----
+
 // CHECK-DAG: #[[$M0:.*]] = affine_map<(d0, d1) -> (d0, d1, 0)>
 // CHECK-LABEL: func @sum_exp