diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -650,6 +650,11 @@
 // Misc. vectorization patterns.
 //----------------------------------------------------------------------------//
 
+/// Return the integer held by `attr`, which must be an IntegerAttr.
+static int64_t getIntFromAttr(Attribute attr) {
+  return attr.cast<IntegerAttr>().getInt();
+}
+
 /// Given a block, return the Value that the block yields if that Value is
 /// constant. In this context, "constant" means "defined outside of the block".
 /// Should not be called on blocks that yield more than one value.
@@ -796,12 +801,122 @@
   }
 };
 
+/// Rewrite use of PadTensorOp result in SubtensorInsertOp. E.g.:
+/// ```
+/// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
+/// %r = subtensor_insert %0 into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
+///     : tensor<17x5xf32> into tensor<?x?x17x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %0 = vector.transfer_read %src[%c0, %c0], %padding
+///     : tensor<?x?xf32>, vector<17x5xf32>
+/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
+///     {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
+/// ```
+///
+/// This rewrite is possible if:
+/// - Low padding is static 0.
+/// - `padOp` result shape is static.
+/// - The entire padded tensor is inserted.
+///   (Implies that sizes of `insertOp` are all static.)
+/// - Only unit strides in `insertOp`.
+/// - Single, scalar padding value.
+struct PadTensorOpVectorizationWithSubTensorInsertPattern
+    : public VectorizePadTensorOpUserPattern<SubTensorInsertOp> {
+  using VectorizePadTensorOpUserPattern<
+      SubTensorInsertOp>::VectorizePadTensorOpUserPattern;
+
+  /// Try to replace `insertOp`, a SubTensorInsertOp that uses the result of
+  /// `padOp`, with a TransferReadOp of `padOp`'s source followed by a
+  /// TransferWriteOp into `insertOp`'s destination. Returns failure() without
+  /// creating any ops if a precondition is not met, success() otherwise.
+  LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
+                            SubTensorInsertOp insertOp) const override {
+    // Return true if `ofr` is guaranteed to be the integer `val`.
+    auto isStaticInt = [&](OpFoldResult ofr, int64_t val) {
+      return isEqualConstantIntOrValue(ofr, rewriter.getIndexAttr(val));
+    };
+    // Return true if `ofr` is guaranteed to be zero.
+    auto isZeroInt = [&](OpFoldResult ofr) { return isStaticInt(ofr, 0); };
+
+    // The pad result must be the tensor being inserted, not the destination:
+    // an insert *into* the pad result is also a user of `padOp`, but
+    // rewriting it would replace an unrelated source with the padded data.
+    if (insertOp.dest() == padOp.result())
+      return failure();
+    // Low padding must be static 0.
+    if (!llvm::all_of(padOp.getMixedLowPad(), isZeroInt))
+      return failure();
+    // Pad value must be a constant.
+    auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
+    if (!padValue)
+      return failure();
+    // Dynamic shapes not supported.
+    if (!padOp.result().getType().cast<ShapedType>().hasStaticShape())
+      return failure();
+
+    auto vecType = VectorType::get(padOp.getType().getShape(),
+                                   padOp.getType().getElementType());
+    unsigned vecRank = vecType.getRank();
+    unsigned tensorRank = insertOp.getType().getRank();
+    // Guard the unsigned `tensorRank - vecRank` computations below.
+    if (vecRank > tensorRank)
+      return failure();
+
+    // Only unit stride supported.
+    if (!llvm::all_of(insertOp.getMixedStrides(),
+                      [&](OpFoldResult s) { return isStaticInt(s, 1); }))
+      return failure();
+
+    // Check if sizes match: Insert the entire tensor into most minor dims.
+    // Leading dims must have size 1, trailing dims must equal the vector shape.
+    auto sizes = insertOp.getMixedSizes();
+    unsigned numLeadingDims = tensorRank - vecRank;
+    for (unsigned i = 0; i < tensorRank; ++i) {
+      int64_t expectedSize =
+          i < numLeadingDims ? 1 : vecType.getDimSize(i - numLeadingDims);
+      if (!isStaticInt(sizes[i], expectedSize))
+        return failure();
+    }
+
+    // Create the new ops right before `insertOp`. Every operand used below
+    // (pad source, pad value, insert destination and offsets) is available at
+    // that point, whereas the destination and offsets are only required to be
+    // defined before `insertOp`, not before `padOp`.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(insertOp);
+
+    // Read is out-of-bounds and will be padded.
+    SmallVector<bool> readInBounds(vecRank, false);
+    auto readMap = AffineMapAttr::get(rewriter.getMultiDimIdentityMap(vecRank));
+    // Low padding is zero, so the read starts at index 0 in every dimension.
+    // One index per dimension of the pad source is needed, i.e. `vecRank`
+    // indices, which may be fewer than `tensorRank`.
+    Value zero = rewriter.create<ConstantOp>(
+        padOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(0));
+    SmallVector<Value> readIndices(vecRank, zero);
+    auto read = rewriter.create<vector::TransferReadOp>(
+        padOp.getLoc(), vecType, padOp.source(), readIndices, readMap, padValue,
+        /*mask=*/Value(), rewriter.getBoolArrayAttr(readInBounds));
+
+    // Compute indices of TransferWriteOp from the offsets of `insertOp`.
+    SmallVector<Value> writeIndices;
+    for (OpFoldResult offset : insertOp.getMixedOffsets()) {
+      if (offset.is<Value>()) {
+        writeIndices.push_back(offset.get<Value>());
+        continue;
+      }
+      // Static offset: convert the int64 attr into an index constant.
+      auto indexAttr =
+          rewriter.getIndexAttr(getIntFromAttr(offset.get<Attribute>()));
+      writeIndices.push_back(rewriter.create<ConstantOp>(
+          padOp.getLoc(), rewriter.getIndexType(), indexAttr));
+    }
+
+    // Write is fully in-bounds.
+    SmallVector<bool> writeInBounds(vecRank, true);
+    // Write to the most minor dimensions of the tensor.
+    auto writeMap = AffineMapAttr::get(AffineMap::getMinorIdentityMap(
+        tensorRank, vecRank, rewriter.getContext()));
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        insertOp, insertOp.getType(), read.getResult(), insertOp.dest(),
+        writeIndices, writeMap, /*mask=*/Value(),
+        rewriter.getBoolArrayAttr(writeInBounds));
+
+    return success();
+  }
+};
+
 void mlir::linalg::populatePadTensorOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
   patterns.add(
       patterns.getContext(), baseBenefit);
   // Try these specialized patterns first before resorting to the generic one.
-  patterns.add(
+  patterns.add(
       patterns.getContext(), baseBenefit.getBenefit() + 1);
 }
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -580,6 +580,28 @@
 
 // -----
 
+// CHECK-LABEL: func @pad_and_subtensor_insert
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: tensor<12x13xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C5:.*]] = constant 5.0
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
+// CHECK: return %[[WRITE]]
+func @pad_and_subtensor_insert(
+    %arg0: tensor<5x6xf32>, %arg1: tensor<12x13xf32>) -> tensor<12x13xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[2, 3] {
+    ^bb0(%arg2: index, %arg3: index):
+      linalg.yield %c5 : f32
+  } : tensor<5x6xf32> to tensor<7x9xf32>
+  %r = subtensor_insert %0 into %arg1[0, 0][7, 9][1, 1] : tensor<7x9xf32> into tensor<12x13xf32>
+  return %r : tensor<12x13xf32>
+}
+
+// -----
+
 // CHECK-DAG: #[[$M0:.*]] = affine_map<(d0, d1) -> (d0, d1, 0)>
 // CHECK-LABEL: func @sum_exp