diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -878,6 +878,29 @@
 
   LogicalResult matchAndRewrite(PadTensorOp padOp,
                                 PatternRewriter &rewriter) const override;
+
+private:
+  /// Return true if `beforePadding` and `afterTrimming` can be proven to have
+  /// the same tensor size. Conservative: may return false for equal sizes.
+  bool hasSameTensorSize(Value beforePadding, SubTensorOp afterTrimming) const;
+
+  /// Try to rewrite a SubTensorInsertOp that inserts the result of `padOp`.
+  LogicalResult tryVectorizeUser(PatternRewriter &rewriter,
+                                 linalg::PadTensorOp padOp,
+                                 SubTensorInsertOp insertOp,
+                                 Value padValue) const;
+
+  /// Try to rewrite an in-bounds TransferReadOp reading the result of `padOp`.
+  LogicalResult tryVectorizeUser(PatternRewriter &rewriter,
+                                 linalg::PadTensorOp padOp,
+                                 vector::TransferReadOp xferOp,
+                                 Value padValue) const;
+
+  /// Try to rewrite a TransferWriteOp into the result of `padOp` whose result
+  /// is immediately trimmed back to the unpadded size by a SubTensorOp.
+  LogicalResult tryVectorizeUser(PatternRewriter &rewriter,
+                                 linalg::PadTensorOp padOp,
+                                 vector::TransferWriteOp xferOp) const;
 };
 
 /// Match and rewrite for the pattern:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -669,17 +669,9 @@
     return true;
   };
 
-  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
-  // Bail on non-static shapes.
-  if (!resultShapedType.hasStaticShape())
-    return failure();
-
   // If any pad_low is not a static 0, needs a mask. Bail for now.
   if (llvm::any_of(padOp.getMixedLowPad(), isNotZeroIndex))
     return failure();
-  VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
-  if (!vectorType)
-    return failure();
 
   // Only support padding with a constant for now, i.e. either:
   //   1. A BBarg from a different block.
@@ -694,6 +686,33 @@
   if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
     return failure();
 
+  // Shortcut: Try to avoid creating an InitTensorOp.
+  bool changed = false;
+  // Insert users in vector, because some users may be replaced/removed.
+  for (auto *user : llvm::to_vector<4>(padOp->getUsers())) {
+    if (auto insertOp = dyn_cast<SubTensorInsertOp>(user)) {
+      if (insertOp.source() == padOp.result())
+        changed |=
+            succeeded(tryVectorizeUser(rewriter, padOp, insertOp, padValue));
+    } else if (auto readOp = dyn_cast<vector::TransferReadOp>(user)) {
+      changed |= succeeded(tryVectorizeUser(rewriter, padOp, readOp, padValue));
+    } else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(user)) {
+      changed |= succeeded(tryVectorizeUser(rewriter, padOp, writeOp));
+    }
+  }
+  if (changed)
+    return success();
+  // End of shortcut: Must create an InitTensorOp...
+
+  // Bail on non-static shapes.
+  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
+  if (!resultShapedType.hasStaticShape())
+    return failure();
+
+  VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
+  if (!vectorType)
+    return failure();
+
   // TODO: if any pad_high is not a static 0, needs a mask. For now, just bail.
   if (llvm::any_of(padOp.getMixedHighPad(),
                    [&](OpFoldResult ofr) { return isNotZeroIndex(ofr); }))
@@ -715,6 +734,300 @@
   return success();
 }
 
+/// Helper function that retrieves the value of an IntegerAttr.
+static int64_t getIntFromAttr(Attribute attr) {
+  return attr.cast<IntegerAttr>().getInt();
+}
+
+/// Rewrite use of PadTensorOp result in SubTensorInsertOp. E.g.:
+/// ```
+/// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
+/// %r = subtensor_insert %0 into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
+///     : tensor<17x5xf32> into tensor<?x?x17x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %0 = vector.transfer_read %src[%c0, %c0], %padding
+///     : tensor<?x?xf32>, vector<17x5xf32>
+/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
+///     {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
+/// ```
+LogicalResult PadTensorOpVectorizationPattern::tryVectorizeUser(
+    PatternRewriter &rewriter, linalg::PadTensorOp padOp,
+    SubTensorInsertOp insertOp, Value padValue) const {
+  // Dynamic shapes not supported.
+  if (!padOp.result().getType().cast<ShapedType>().hasStaticShape())
+    return failure();
+
+  auto vecType = VectorType::get(padOp.getType().getShape(),
+                                 padOp.getType().getElementType());
+  unsigned vecRank = vecType.getRank();
+  unsigned tensorRank = insertOp.getType().getRank();
+
+  // Only unit stride supported.
+  if (llvm::any_of(insertOp.getMixedStrides(), [](auto s) {
+        return s.template is<Value>() ||
+               getIntFromAttr(s.template get<Attribute>()) != 1;
+      }))
+    return failure();
+
+  // Check if sizes match: Insert the entire tensor into most minor dims.
+  auto sizes = insertOp.getMixedSizes();
+  for (unsigned i = 0; i < tensorRank - vecRank; ++i) {
+    if (sizes[i].is<Value>() || getIntFromAttr(sizes[i].get<Attribute>()) != 1)
+      return failure();
+  }
+  for (unsigned i = tensorRank - vecRank; i < tensorRank; ++i) {
+    if (sizes[i].is<Value>() ||
+        getIntFromAttr(sizes[i].get<Attribute>()) !=
+            vecType.getDimSize(i + vecRank - tensorRank))
+      return failure();
+  }
+
+  // Create all new ops right before `insertOp`: its offsets and destination
+  // may be defined after `padOp`. Creating the TransferWriteOp at an earlier
+  // insertion point could create uses that are not dominated by their values.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(insertOp);
+
+  // Read may be out-of-bounds; missing elements are filled with `padValue`.
+  SmallVector<bool> readInBounds(vecRank, false);
+  auto readMap = AffineMapAttr::get(rewriter.getMultiDimIdentityMap(vecRank));
+  // Assuming that low indices of PadTensorOp are all zero. Must use a different
+  // starting point + masking for the vector read when the pattern is extended.
+  // The read accesses `padOp.source()`, which has rank `vecRank`.
+  SmallVector<Value> readIndices(
+      vecRank,
+      rewriter.create<ConstantOp>(padOp.getLoc(), rewriter.getIndexType(),
+                                  rewriter.getIndexAttr(0)));
+  auto read = rewriter.create<vector::TransferReadOp>(
+      padOp.getLoc(), vecType, padOp.source(), readIndices, readMap, padValue,
+      /*mask=*/Value(), rewriter.getBoolArrayAttr(readInBounds));
+
+  // Compute indices of TransferWriteOp.
+  SmallVector<Value> writeIndices;
+  llvm::for_each(insertOp.getMixedOffsets(), [&](auto o) {
+    if (o.template is<Value>()) {
+      writeIndices.push_back(o.template get<Value>());
+    } else {
+      // Convert int64 attr to index attr.
+      auto intAttr =
+          rewriter.getIndexAttr(getIntFromAttr(o.template get<Attribute>()));
+      writeIndices.push_back(rewriter.create<ConstantOp>(
+          padOp.getLoc(), rewriter.getIndexType(), intAttr));
+    }
+  });
+
+  // Write is fully in-bounds: subtensor_insert's source must fit into `dest`.
+  SmallVector<bool> inBounds(vecRank, true);
+  // Write to the most minor dimensions of the tensor.
+  auto writeMap = AffineMapAttr::get(AffineMap::getMinorIdentityMap(
+      tensorRank, vecRank, rewriter.getContext()));
+  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+      insertOp, insertOp.getType(), read.getResult(), insertOp.dest(),
+      writeIndices, writeMap, /*mask=*/Value(),
+      rewriter.getBoolArrayAttr(inBounds));
+
+  return success();
+}
+
+/// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
+/// ```
+/// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
+/// %r = vector.transfer_read %0[%c0, %c0], %cst
+///     {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %r = vector.transfer_read %src[%c0, %c0], %padding
+///     {in_bounds = [false, false]}
+///     : tensor<?x?xf32>, vector<17x5xf32>
+/// ```
+/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
+/// sure that the original padding value %cst was never used.
+LogicalResult PadTensorOpVectorizationPattern::tryVectorizeUser(
+    PatternRewriter &rewriter, linalg::PadTensorOp padOp,
+    vector::TransferReadOp xferOp, Value padValue) const {
+  if (xferOp.hasOutOfBoundsDim())
+    return failure();
+  if (xferOp.mask())
+    return failure();
+
+  rewriter.updateRootInPlace(xferOp, [&]() {
+    // The unpadded source may be smaller than the padded tensor: any dim may
+    // now be out-of-bounds, and such elements are filled with `padValue`.
+    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
+    xferOp->setAttr(xferOp.getInBoundsAttrName(),
+                    rewriter.getBoolArrayAttr(inBounds));
+    xferOp.sourceMutable().assign(padOp.source());
+    xferOp.paddingMutable().assign(padValue);
+  });
+
+  return success();
+}
+
+/// Check if `beforePadding` and `afterTrimming` have the same tensor size,
+/// i.e., same dimensions.
+///
+/// Dimensions may be static, dynamic or mix of both. However, if a dimension
+/// is static (resp. dynamic) in `beforePadding`, the corresponding dimension in
+/// `afterTrimming` must also be static (resp. dynamic); otherwise `false` is
+/// returned.
+///
+/// This is a conservative analysis. In case equal tensor sizes cannot be proven
+/// statically, this analysis returns `false` even though the tensor sizes may
+/// turn out to be equal at runtime.
+bool PadTensorOpVectorizationPattern::hasSameTensorSize(
+    Value beforePadding, SubTensorOp afterTrimming) const {
+  // Input to PadTensorOp may be a CastOp. Try with both CastOp result and
+  // CastOp operand.
+  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
+    if (hasSameTensorSize(castOp.source(), afterTrimming))
+      return true;
+
+  auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
+  auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
+  // Only RankedTensorType supported.
+  if (!t1 || !t2)
+    return false;
+  // Rank of both values must be the same.
+  if (t1.getRank() != t2.getRank())
+    return false;
+
+  // All static dimensions must be the same. Mixed cases (e.g., dimension static
+  // in `t1` but dynamic in `t2`) are not supported.
+  for (unsigned i = 0; i < t1.getRank(); ++i) {
+    if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
+      return false;
+    if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
+      return false;
+  }
+
+  // Nothing more to check if all dimensions are static.
+  if (t1.getNumDynamicDims() == 0)
+    return true;
+
+  // All dynamic sizes must be the same. This is more difficult to check and not
+  // all cases are supported. The only supported case at the moment is when
+  // `beforePadding` is a SubTensorOp (or a cast thereof).
+
+  // Apart from CastOp, only SubTensorOp is supported.
+  auto beforeSubtensor = beforePadding.getDefiningOp<SubTensorOp>();
+  if (!beforeSubtensor)
+    return false;
+
+  auto beforeSizes = beforeSubtensor.getMixedSizes();
+  auto afterSizes = afterTrimming.getMixedSizes();
+  assert(static_cast<int64_t>(beforeSizes.size()) == t1.getRank());
+  assert(static_cast<int64_t>(afterSizes.size()) == t2.getRank());
+
+  for (unsigned i = 0; i < t1.getRank(); ++i) {
+    // Skip static dimensions.
+    if (!t1.isDynamicDim(i))
+      continue;
+    OpFoldResult dim1 = beforeSizes[i];
+    OpFoldResult dim2 = afterSizes[i];
+
+    if (auto v1 = dim1.dyn_cast<Value>()) {
+      // Compare dynamic sizes.
+      auto v2 = dim2.dyn_cast<Value>();
+      if (!v2)
+        return false; // dim1 dynamic, but dim2 static
+      // Case 1: Values are identical.
+      if (v1 == v2)
+        continue;
+      // Case 2: Both values are identical AffineMinOps. (Should not happen if
+      // CSE is run.)
+      auto minOp1 = v1.getDefiningOp<AffineMinOp>();
+      auto minOp2 = v2.getDefiningOp<AffineMinOp>();
+      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
+          minOp1.operands() == minOp2.operands())
+        continue;
+      // Add additional cases as needed. Until then, dynamic sizes that cannot
+      // be proven equal must be treated as different.
+      return false;
+    } else {
+      // Compare static sizes.
+      auto s1 = getIntFromAttr(dim1.get<Attribute>());
+      auto a2 = dim2.dyn_cast<Attribute>();
+      if (!a2)
+        return false; // dim1 static, but dim2 dynamic
+      auto s2 = getIntFromAttr(a2);
+      if (s1 != s2)
+        return false;
+    }
+  }
+
+  // All tests passed.
+  return true;
+}
+
+/// Rewrite use of PadTensorOp result in TransferWriteOp.
+/// This pattern rewrites TransferWriteOps that write to a padded tensor value,
+/// where the same amount of padding is immediately removed again after the
+/// write. In such cases, the TransferWriteOp can write to the non-padded tensor
+/// value and apply out-of-bounds masking. E.g.:
+/// ```
+/// %0 = subtensor ...[...] [%s0, %s1] [1, 1] : tensor<...> to tensor<?x?xf32>
+/// %1 = linalg.pad_tensor %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
+/// %2 = vector.transfer_write %vec, %1[...]
+///     : vector<17x5xf32>, tensor<17x5xf32>
+/// %r = subtensor %2[0, 0] [%s0, %s1] [1, 1]
+///     : tensor<17x5xf32> to tensor<?x?xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %0 = subtensor ...[...] [%s0, %s1] [1, 1] : tensor<...> to tensor<?x?xf32>
+/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, tensor<?x?xf32>
+/// ```
+/// Note: It is important that the SubTensorOp %r resizes the result of the
+/// TransferWriteOp to exactly the size of the input of the PadTensorOp (with
+/// zero offsets and unit strides). Otherwise, %r's new (dynamic) dimensions
+/// would differ from %r's old dimensions.
+LogicalResult PadTensorOpVectorizationPattern::tryVectorizeUser(
+    PatternRewriter &rewriter, linalg::PadTensorOp padOp,
+    vector::TransferWriteOp xferOp) const {
+  OpBuilder::InsertionGuard guard(rewriter);
+
+  // TransferWriteOp result must be directly consumed by a SubTensorOp.
+  if (!xferOp->hasOneUse())
+    return failure();
+  auto trimPadding = dyn_cast<SubTensorOp>(*xferOp->user_begin());
+  if (!trimPadding)
+    return failure();
+  // Only static zero offsets supported when trimming padding.
+  if (llvm::any_of(trimPadding.getMixedOffsets(), [](auto s) {
+        return s.template is<Value>() ||
+               getIntFromAttr(s.template get<Attribute>()) != 0;
+      }))
+    return failure();
+  // Only unit strides supported when trimming padding. A strided SubTensorOp
+  // would select a different set of elements than the unpadded tensor holds.
+  if (llvm::any_of(trimPadding.getMixedStrides(), [](auto s) {
+        return s.template is<Value>() ||
+               getIntFromAttr(s.template get<Attribute>()) != 1;
+      }))
+    return failure();
+  // The new TransferWriteOp has the type of `padOp.source()` and replaces
+  // `trimPadding`, so both types must match. (They can differ when
+  // `hasSameTensorSize` looks through a tensor::CastOp.)
+  if (padOp.source().getType() != trimPadding.getType())
+    return failure();
+  // trimPadding must remove the same amount of padding that was added earlier.
+  if (!hasSameTensorSize(padOp.source(), trimPadding))
+    return failure();
+
+  rewriter.setInsertionPoint(xferOp);
+  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
+  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+      xferOp, padOp.source().getType(), xferOp.vector(), padOp.source(),
+      xferOp.indices(), xferOp.permutation_mapAttr(), xferOp.mask(),
+      rewriter.getBoolArrayAttr(inBounds));
+  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
+
+  return success();
+}
+
 // TODO: cleanup all the convolution vectorization patterns.
 template <class ConvOp, int N>
 LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -558,6 +558,97 @@
 
 // -----
 
+// CHECK-LABEL: func @pad_and_transfer_read
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C5:.*]] = constant 5.0
+// CHECK: %[[RESULT:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
+// CHECK: return %[[RESULT]]
+func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %c6 = constant 6.0 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[5, 7] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %c5 : f32
+  } : tensor<5x6xf32> to tensor<10x13xf32>
+  %1 = vector.transfer_read %0[%c0, %c0], %c6
+      : tensor<10x13xf32>, vector<7x9xf32>
+  return %1 : vector<7x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pad_and_transfer_write_static
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: vector<7x9xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32>
+// CHECK: return %[[RESULT]]
+func @pad_and_transfer_write_static(
+    %arg0: tensor<5x6xf32>, %arg1: vector<7x9xf32>) -> tensor<5x6xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[5, 7] {
+    ^bb0(%arg2: index, %arg3: index):
+      linalg.yield %c5 : f32
+  } : tensor<5x6xf32> to tensor<10x13xf32>
+  %1 = vector.transfer_write %arg1, %0[%c0, %c0]
+      : vector<7x9xf32>, tensor<10x13xf32>
+  %2 = subtensor %1[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
+  return %2 : tensor<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pad_and_transfer_write_dynamic_static
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: vector<7x9xf32>, %[[SIZE:.*]]: index, %[[PADDING:.*]]: index
+// CHECK-NOT: linalg.pad_tensor
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[SUB:.*]] = subtensor %[[ARG0]][0, 0] [%[[SIZE]], 6] [1, 1] : tensor<?x?xf32> to tensor<?x6xf32>
+// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[ARG1]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
+// CHECK: return %[[RESULT]]
+func @pad_and_transfer_write_dynamic_static(
+    %arg0: tensor<?x?xf32>, %arg1: vector<7x9xf32>, %size: index, %padding: index) -> tensor<?x6xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %s = subtensor %arg0[0, 0] [%size, 6] [1, 1]
+      : tensor<?x?xf32> to tensor<?x6xf32>
+  %0 = linalg.pad_tensor %s low[0, 0] high[%padding, 7] {
+    ^bb0(%arg2: index, %arg3: index):
+      linalg.yield %c5 : f32
+  } : tensor<?x6xf32> to tensor<?x13xf32>
+  %1 = vector.transfer_write %arg1, %0[%c0, %c0]
+      : vector<7x9xf32>, tensor<?x13xf32>
+  %2 = subtensor %1[0, 0] [%size, 6] [1, 1] : tensor<?x13xf32> to tensor<?x6xf32>
+  return %2 : tensor<?x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pad_and_subtensor_insert
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: tensor<12x13xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C5:.*]] = constant 5.0
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
+// CHECK: return %[[WRITE]]
+func @pad_and_subtensor_insert(
+    %arg0: tensor<5x6xf32>, %arg1: tensor<12x13xf32>) -> tensor<12x13xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[2, 3] {
+    ^bb0(%arg2: index, %arg3: index):
+      linalg.yield %c5 : f32
+  } : tensor<5x6xf32> to tensor<7x9xf32>
+  %r = subtensor_insert %0 into %arg1[0, 0][7, 9][1, 1] : tensor<7x9xf32> into tensor<12x13xf32>
+  return %r : tensor<12x13xf32>
+}
+
+// -----
+
 // CHECK-DAG: #[[$M0:.*]] = affine_map<(d0, d1) -> (d0, d1, 0)>
 // CHECK-LABEL: func @sum_exp