diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -650,6 +650,11 @@
 // Misc. vectorization patterns.
 //----------------------------------------------------------------------------//
 
+/// Helper function that retrieves the value of an IntegerAttr.
+static int64_t getIntFromAttr(Attribute attr) {
+  return attr.cast<IntegerAttr>().getInt();
+}
+
 /// Given a block, return the Value that the block yields if that Value is
 /// constant. In this context, "constant" means "defined outside of the block".
 /// Should not be called on blocks that yield more than one value.
@@ -675,6 +680,97 @@
   return result;
 }
 
+/// Check if `beforePadding` and `afterTrimming` have the same tensor size,
+/// i.e., same dimensions.
+///
+/// Dimensions may be static, dynamic or a mix of both. However, if a dimension
+/// is static (resp. dynamic) in `beforePadding`, the corresponding dimension
+/// in `afterTrimming` must also be static (resp. dynamic); otherwise `false`
+/// is returned.
+///
+/// This is a conservative analysis. In case equal tensor sizes cannot be
+/// proven statically, this analysis returns `false` even though the tensor
+/// sizes may turn out to be equal at runtime.
+static bool hasSameTensorSize(Value beforePadding, SubTensorOp afterTrimming) {
+  // Input to PadTensorOp may be a CastOp. Try with both the CastOp result
+  // and the CastOp operand.
+  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
+    if (hasSameTensorSize(castOp.source(), afterTrimming))
+      return true;
+
+  auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
+  auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
+  // Only RankedTensorType supported.
+  if (!t1 || !t2)
+    return false;
+  // Rank of both values must be the same.
+  if (t1.getRank() != t2.getRank())
+    return false;
+
+  // All static dimensions must be the same. Mixed cases (e.g., dimension
+  // static in `t1` but dynamic in `t2`) are not supported.
+  for (unsigned i = 0; i < t1.getRank(); ++i) {
+    if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
+      return false;
+    if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
+      return false;
+  }
+
+  // Nothing more to check if all dimensions are static.
+  if (t1.getNumDynamicDims() == 0)
+    return true;
+
+  // All dynamic sizes must be the same. This is more difficult to check and
+  // not all cases are supported. The only supported case at the moment is
+  // when `beforePadding` is a SubTensorOp (or a cast thereof).
+
+  // Apart from CastOp, only SubTensorOp is supported.
+  auto beforeSubtensor = beforePadding.getDefiningOp<SubTensorOp>();
+  if (!beforeSubtensor)
+    return false;
+
+  assert(t1.getRank() == beforeSubtensor.getMixedSizes().size());
+  assert(t2.getRank() == afterTrimming.getMixedSizes().size());
+
+  for (unsigned i = 0; i < t1.getRank(); ++i) {
+    // Skip static dimensions.
+    if (!t1.isDynamicDim(i))
+      continue;
+    auto dim1 = beforeSubtensor.getMixedSizes()[i];
+    auto dim2 = afterTrimming.getMixedSizes()[i];
+
+    if (auto v1 = dim1.dyn_cast<Value>()) {
+      // Compare dynamic sizes.
+      auto v2 = dim2.dyn_cast<Value>();
+      if (!v2)
+        return false; // dim1 dynamic, but dim2 static
+      // Case 1: Values are identical.
+      if (v1 == v2)
+        continue;
+      // Case 2: Both values are identical AffineMinOps. (Should not happen if
+      // CSE is run.)
+      auto minOp1 = v1.getDefiningOp<AffineMinOp>();
+      auto minOp2 = v2.getDefiningOp<AffineMinOp>();
+      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
+          minOp1.operands() == minOp2.operands())
+        continue;
+      // Add additional cases as needed.
+    } else {
+      // Compare static sizes.
+      auto s1 = getIntFromAttr(dim1.get<Attribute>());
+      auto a2 = dim2.dyn_cast<Attribute>();
+      if (!a2)
+        return false; // dim1 static, but dim2 dynamic
+      auto s2 = getIntFromAttr(a2);
+      if (s1 != s2)
+        return false;
+    }
+  }
+
+  // All tests passed.
+  return true;
+}
+
 /// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
 /// TransferWriteOp. For now, this only applies when all low and high paddings
 /// are determined to be zero.
@@ -725,10 +821,271 @@
   }
 };
 
+/// Base pattern for rewriting PadTensorOps whose result is consumed by a
+/// given operation type OpTy.
+template <typename OpTy>
+struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const final {
+    bool changed = false;
+    // Insert users into a vector first, because some users may be
+    // replaced/removed while rewriting.
+    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
+      if (auto op = dyn_cast<OpTy>(user))
+        changed |= rewriteUser(rewriter, padOp, op).succeeded();
+    return success(changed);
+  }
+
+protected:
+  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
+                                    PadTensorOp padOp, OpTy op) const = 0;
+};
+
+/// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
+/// ```
+/// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
+/// %r = vector.transfer_read %0[%c0, %c0], %cst
+///     {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %r = vector.transfer_read %src[%c0, %c0], %padding
+///     {in_bounds = [true, true]}
+///     : tensor<?x?xf32>, vector<17x5xf32>
+/// ```
+/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
+/// sure that the original padding value %cst was never used.
+///
+/// This rewrite is possible if:
+/// - `xferOp` has no out-of-bounds dims or mask.
+/// - Low padding is static 0.
+/// - Single, scalar padding value.
+struct PadTensorOpVectorizationWithTransferReadPattern
+    : public VectorizePadTensorOpUserPattern<vector::TransferReadOp> {
+  using VectorizePadTensorOpUserPattern<
+      vector::TransferReadOp>::VectorizePadTensorOpUserPattern;
+
+  LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
+                            vector::TransferReadOp xferOp) const override {
+    /// Given an OpFoldResult, return true if its value is guaranteed to be a
+    /// zero integer.
+    auto isZeroInt = [&](OpFoldResult ofr) {
+      return isEqualConstantIntOrValue(ofr, rewriter.getIndexAttr(0));
+    };
+    // Low padding must be static 0.
+    if (!llvm::all_of(padOp.getMixedLowPad(), isZeroInt))
+      return failure();
+    // Pad value must be a constant.
+    auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
+    if (!padValue)
+      return failure();
+    // Padding value of existing `xferOp` is unused.
+    if (xferOp.hasOutOfBoundsDim() || xferOp.mask())
+      return failure();
+
+    rewriter.updateRootInPlace(xferOp, [&]() {
+      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
+      xferOp->setAttr(xferOp.getInBoundsAttrName(),
+                      rewriter.getBoolArrayAttr(inBounds));
+      xferOp.sourceMutable().assign(padOp.source());
+      xferOp.paddingMutable().assign(padValue);
+    });
+
+    return success();
+  }
+};
+
+/// Rewrite use of PadTensorOp result in TransferWriteOp.
+/// This pattern rewrites TransferWriteOps that write to a padded tensor
+/// value, where the same amount of padding is immediately removed again after
+/// the write. In such cases, the TransferWriteOp can write to the non-padded
+/// tensor value and apply out-of-bounds masking. E.g.:
+/// ```
+/// %0 = subtensor ...[...] [%s0, %s1] [1, 1] : tensor<...> to tensor<?x?xf32>
+/// %1 = linalg.pad_tensor %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
+/// %2 = vector.transfer_write %vec, %1[...]
+///     : vector<17x5xf32>, tensor<17x5xf32>
+/// %r = subtensor %2[0, 0] [%s0, %s1] [1, 1]
+///     : tensor<17x5xf32> to tensor<?x?xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %0 = subtensor ...[...] [%s0, %s1] [1, 1] : tensor<...> to tensor<?x?xf32>
+/// %r = vector.transfer_write %vec, %0[...]
+///     : vector<17x5xf32>, tensor<?x?xf32>
+/// ```
+/// Note: It is important that the SubTensorOp %r resizes the result of the
+/// TransferWriteOp to the same size as the input of the PadTensorOp (or to an
+/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
+/// from %r's old dimensions.
+///
+/// This rewrite is possible if:
+/// - Low padding is static 0.
+/// - `xferOp` has exactly one use, which is a SubTensorOp. This SubTensorOp
+///   trims the same amount of padding that was added beforehand.
+/// - Single, scalar padding value.
+struct PadTensorOpVectorizationWithTransferWritePattern
+    : public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
+  using VectorizePadTensorOpUserPattern<
+      vector::TransferWriteOp>::VectorizePadTensorOpUserPattern;
+
+  LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
+                            vector::TransferWriteOp xferOp) const override {
+    /// Given an OpFoldResult, return true if its value is guaranteed to be a
+    /// zero integer.
+    auto isZeroInt = [&](OpFoldResult ofr) {
+      return isEqualConstantIntOrValue(ofr, rewriter.getIndexAttr(0));
+    };
+    // Low padding must be static 0.
+    if (!llvm::all_of(padOp.getMixedLowPad(), isZeroInt))
+      return failure();
+    // Pad value must be a constant.
+    auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
+    if (!padValue)
+      return failure();
+    // TransferWriteOp result must be directly consumed by a SubTensorOp.
+    if (!xferOp->hasOneUse())
+      return failure();
+    auto trimPadding = dyn_cast<SubTensorOp>(*xferOp->user_begin());
+    if (!trimPadding)
+      return failure();
+    // Only static zero offsets supported when trimming padding.
+    if (!llvm::all_of(trimPadding.getMixedOffsets(), isZeroInt))
+      return failure();
+    // trimPadding must remove the amount of padding that was added earlier.
+    if (!hasSameTensorSize(padOp.source(), trimPadding))
+      return failure();
+
+    rewriter.setInsertionPoint(xferOp);
+    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
+    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        xferOp, padOp.source().getType(), xferOp.vector(), padOp.source(),
+        xferOp.indices(), xferOp.permutation_mapAttr(), xferOp.mask(),
+        rewriter.getBoolArrayAttr(inBounds));
+    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
+
+    return success();
+  }
+};
+
+/// Rewrite use of PadTensorOp result in SubTensorInsertOp. E.g.:
+/// ```
+/// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
+/// %r = subtensor_insert %0 into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
+///     : tensor<17x5xf32> into tensor<?x?x17x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %0 = vector.transfer_read %src[%c0, %c0], %padding
+///     : tensor<?x?xf32>, vector<17x5xf32>
+/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
+///     {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
+/// ```
+///
+/// This rewrite is possible if:
+/// - Low padding is static 0.
+/// - `padOp` result shape is static.
+/// - The entire padded tensor is inserted.
+///   (Implies that sizes of `insertOp` are all static.)
+/// - Only unit strides in `insertOp`.
+/// - Single, scalar padding value.
+struct PadTensorOpVectorizationWithSubTensorInsertPattern
+    : public VectorizePadTensorOpUserPattern<SubTensorInsertOp> {
+  using VectorizePadTensorOpUserPattern<
+      SubTensorInsertOp>::VectorizePadTensorOpUserPattern;
+
+  LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
+                            SubTensorInsertOp insertOp) const override {
+    /// Given an OpFoldResult, return true if its value is guaranteed to be a
+    /// certain integer value.
+    auto isStaticInt = [&](OpFoldResult ofr, int64_t val) {
+      return isEqualConstantIntOrValue(ofr, rewriter.getIndexAttr(val));
+    };
+    /// Given an OpFoldResult, return true if its value is guaranteed to be a
+    /// zero integer.
+    auto isZeroInt = [&](OpFoldResult ofr) { return isStaticInt(ofr, 0); };
+
+    // Low padding must be static 0.
+    if (!llvm::all_of(padOp.getMixedLowPad(), isZeroInt))
+      return failure();
+    // Pad value must be a constant.
+    auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
+    if (!padValue)
+      return failure();
+    // Dynamic shapes not supported.
+    if (!padOp.result().getType().cast<ShapedType>().hasStaticShape())
+      return failure();
+
+    auto vecType = VectorType::get(padOp.getType().getShape(),
+                                   padOp.getType().getElementType());
+    unsigned vecRank = vecType.getRank();
+    unsigned tensorRank = insertOp.getType().getRank();
+
+    // Only unit stride supported.
+    if (!llvm::all_of(insertOp.getMixedStrides(),
+                      [&](auto s) { return isStaticInt(s, 1); }))
+      return failure();
+
+    // Check if sizes match: Insert the entire tensor into most minor dims.
+    auto sizes = insertOp.getMixedSizes();
+    for (unsigned i = 0; i < tensorRank - vecRank; ++i) {
+      if (!isStaticInt(sizes[i], 1))
+        return failure();
+    }
+    for (unsigned i = tensorRank - vecRank; i < tensorRank; ++i) {
+      if (!isStaticInt(sizes[i], vecType.getDimSize(i + vecRank - tensorRank)))
+        return failure();
+    }
+
+    // Read is out-of-bounds and will be padded. The source of the read has
+    // the same rank as the vector, so one zero index per vector dimension.
+    SmallVector<bool> outOfBounds(vecRank, false);
+    auto readMap =
+        AffineMapAttr::get(rewriter.getMultiDimIdentityMap(vecRank));
+    // Assuming that low indices of PadTensorOp are all zero.
+    SmallVector<Value> readIndices(
+        vecRank,
+        rewriter.create<ConstantOp>(padOp.getLoc(), rewriter.getIndexType(),
+                                    rewriter.getIndexAttr(0)));
+    auto read = rewriter.create<vector::TransferReadOp>(
+        padOp.getLoc(), vecType, padOp.source(), readIndices, readMap,
+        padValue, /*mask=*/Value(), rewriter.getBoolArrayAttr(outOfBounds));
+
+    // Compute indices of TransferWriteOp.
+    SmallVector<Value> writeIndices;
+    llvm::for_each(insertOp.getMixedOffsets(), [&](auto o) {
+      if (o.template is<Value>()) {
+        writeIndices.push_back(o.template get<Value>());
+      } else {
+        // Convert int64 attr to index attr.
+        auto intAttr =
+            rewriter.getIndexAttr(getIntFromAttr(o.template get<Attribute>()));
+        writeIndices.push_back(rewriter.create<ConstantOp>(
+            padOp.getLoc(), rewriter.getIndexType(), intAttr));
+      }
+    });
+
+    // Write is fully in-bounds.
+    SmallVector<bool> inBounds(vecRank, true);
+    // Write to the most minor dimensions of the tensor.
+    auto writeMap = AffineMapAttr::get(AffineMap::getMinorIdentityMap(
+        tensorRank, vecRank, rewriter.getContext()));
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        insertOp, insertOp.getType(), read.getResult(), insertOp.dest(),
+        writeIndices, writeMap, /*mask=*/Value(),
+        rewriter.getBoolArrayAttr(inBounds));
+
+    return success();
+  }
+};
+
 void mlir::linalg::populatePadTensorOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
   patterns.add<PadTensorOpVectorizationPattern>(patterns.getContext(),
                                                 baseBenefit);
+  // Try these specialized patterns first before resorting to the generic one.
+  patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
+               PadTensorOpVectorizationWithTransferWritePattern,
+               PadTensorOpVectorizationWithSubTensorInsertPattern>(
+      patterns.getContext(), baseBenefit.getBenefit() + 1);
 }
 
 // TODO: cleanup all the convolution vectorization patterns.
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -558,6 +558,97 @@
 
 // -----
 
+// CHECK-LABEL: func @pad_and_transfer_read
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C5:.*]] = constant 5.0
+// CHECK: %[[RESULT:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
+// CHECK: return %[[RESULT]]
+func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %c6 = constant 6.0 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[5, 7] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %c5 : f32
+  } : tensor<5x6xf32> to tensor<10x13xf32>
+  %1 = vector.transfer_read %0[%c0, %c0], %c6
+      : tensor<10x13xf32>, vector<7x9xf32>
+  return %1 : vector<7x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pad_and_transfer_write_static
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: vector<7x9xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32>
+// CHECK: return %[[RESULT]]
+func @pad_and_transfer_write_static(
+    %arg0: tensor<5x6xf32>, %arg1: vector<7x9xf32>) -> tensor<5x6xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[5, 7] {
+    ^bb0(%arg2: index, %arg3: index):
+      linalg.yield %c5 : f32
+  } : tensor<5x6xf32> to tensor<10x13xf32>
+  %1 = vector.transfer_write %arg1, %0[%c0, %c0]
+      : vector<7x9xf32>, tensor<10x13xf32>
+  %2 = subtensor %1[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
+  return %2 : tensor<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pad_and_transfer_write_dynamic_static
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: vector<7x9xf32>, %[[SIZE:.*]]: index, %[[PADDING:.*]]: index
+// CHECK-NOT: linalg.pad_tensor
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[SUB:.*]] = subtensor %[[ARG0]][0, 0] [%[[SIZE]], 6] [1, 1] : tensor<?x?xf32> to tensor<?x6xf32>
+// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[ARG1]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
+// CHECK: return %[[RESULT]]
+func @pad_and_transfer_write_dynamic_static(
+    %arg0: tensor<?x?xf32>, %arg1: vector<7x9xf32>, %size: index,
+    %padding: index) -> tensor<?x6xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %s = subtensor %arg0[0, 0] [%size, 6] [1, 1]
+      : tensor<?x?xf32> to tensor<?x6xf32>
+  %0 = linalg.pad_tensor %s low[0, 0] high[%padding, 7] {
+    ^bb0(%arg2: index, %arg3: index):
+      linalg.yield %c5 : f32
+  } : tensor<?x6xf32> to tensor<?x13xf32>
+  %1 = vector.transfer_write %arg1, %0[%c0, %c0]
+      : vector<7x9xf32>, tensor<?x13xf32>
+  %2 = subtensor %1[0, 0] [%size, 6] [1, 1] : tensor<?x13xf32> to tensor<?x6xf32>
+  return %2 : tensor<?x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pad_and_subtensor_insert
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: tensor<12x13xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C5:.*]] = constant 5.0
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
+// CHECK: return %[[WRITE]]
+func @pad_and_subtensor_insert(
+    %arg0: tensor<5x6xf32>, %arg1: tensor<12x13xf32>) -> tensor<12x13xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[2, 3] {
+    ^bb0(%arg2: index, %arg3: index):
+      linalg.yield %c5 : f32
+  } : tensor<5x6xf32> to tensor<7x9xf32>
+  %r = subtensor_insert %0 into %arg1[0, 0] [7, 9] [1, 1]
+      : tensor<7x9xf32> into tensor<12x13xf32>
+  return %r : tensor<12x13xf32>
+}
+
+// -----
+
 // CHECK-DAG: #[[$M0:.*]] = affine_map<(d0, d1) -> (d0, d1, 0)>
 // CHECK-LABEL: func @sum_exp