diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -878,6 +878,23 @@
   LogicalResult matchAndRewrite(PadTensorOp padOp,
                                 PatternRewriter &rewriter) const override;
+
+private:
+  bool hasSameRuntimeValue(Value v1, Value v2) const;
+
+  LogicalResult tryVectorizeUser(PatternRewriter &rewriter,
+                                 linalg::PadTensorOp padOp,
+                                 SubTensorInsertOp insertOp,
+                                 Value padValue) const;
+
+  LogicalResult tryVectorizeUser(PatternRewriter &rewriter,
+                                 linalg::PadTensorOp padOp,
+                                 vector::TransferReadOp xferOp,
+                                 Value padValue) const;
+
+  LogicalResult tryVectorizeUser(PatternRewriter &rewriter,
+                                 linalg::PadTensorOp padOp,
+                                 vector::TransferWriteOp xferOp) const;
 };
 
 /// Match and rewrite for the pattern:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -694,6 +694,24 @@
   if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
     return failure();
 
+  // Shortcut: Try to avoid creating an InitTensorOp.
+  bool changed = false;
+  // Collect users in a vector first, because some users may be
+  // replaced/removed while iterating.
+  for (auto *user : llvm::to_vector<4>(padOp->getUsers())) {
+    if (auto insertOp = dyn_cast<SubTensorInsertOp>(user)) {
+      if (insertOp.source() == padOp.result())
+        changed |=
+            succeeded(tryVectorizeUser(rewriter, padOp, insertOp, padValue));
+    } else if (auto readOp = dyn_cast<vector::TransferReadOp>(user)) {
+      changed |= succeeded(tryVectorizeUser(rewriter, padOp, readOp, padValue));
+    } else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(user)) {
+      changed |= succeeded(tryVectorizeUser(rewriter, padOp, writeOp));
+    }
+  }
+  if (changed)
+    return success();
+  // End of shortcut: Must create an InitTensorOp...
+
   // TODO: if any pad_high is not a static 0, needs a mask. For now, just bail.
   if (llvm::any_of(padOp.getMixedHighPad(),
                    [&](OpFoldResult ofr) { return isNotZeroIndex(ofr); }))
     return failure();
@@ -715,6 +733,200 @@
   return success();
 }
 
+/// Helper function that retrieves the value of an IntegerAttr.
+static int64_t getIntFromAttr(Attribute attr) {
+  return attr.cast<IntegerAttr>().getInt();
+}
+
+/// Rewrite use of PadTensorOp result in SubTensorInsertOp. E.g.:
+/// ```
+/// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
+/// %r = subtensor_insert %0 into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
+///     : tensor<17x5xf32> into tensor<?x?x17x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %0 = vector.transfer_read %src[%c0, %c0], %padding
+///     : tensor<?x?xf32>, vector<17x5xf32>
+/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
+///     {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
+/// ```
+LogicalResult PadTensorOpVectorizationPattern::tryVectorizeUser(
+    PatternRewriter &rewriter, linalg::PadTensorOp padOp,
+    SubTensorInsertOp insertOp, Value padValue) const {
+  auto vecType = VectorType::get(padOp.getType().getShape(),
+                                 padOp.getType().getElementType());
+  unsigned vecRank = vecType.getRank();
+  unsigned tensorRank = insertOp.getType().getRank();
+
+  // Only unit stride supported.
+  if (llvm::any_of(insertOp.getMixedStrides(), [](auto s) {
+        return s.template is<Value>() ||
+               getIntFromAttr(s.template get<Attribute>()) != 1;
+      }))
+    return failure();
+
+  // Check if sizes match: Insert the entire tensor into the most minor dims.
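+  // For example, inserting a padded tensor<17x5xf32> into a rank-4
+  // destination requires the mixed sizes to be the static values
+  // [1, 1, 17, 5]: the leading (tensorRank - vecRank) sizes must all be 1,
+  // the remaining sizes must match the padded shape, and any dynamic size
+  // bails out of the pattern.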
+  auto sizes = insertOp.getMixedSizes();
+  for (unsigned i = 0; i < tensorRank - vecRank; ++i) {
+    if (sizes[i].is<Value>() || getIntFromAttr(sizes[i].get<Attribute>()) != 1)
+      return failure();
+  }
+  for (unsigned i = tensorRank - vecRank; i < tensorRank; ++i) {
+    if (sizes[i].is<Value>() ||
+        getIntFromAttr(sizes[i].get<Attribute>()) !=
+            vecType.getDimSize(i + vecRank - tensorRank))
+      return failure();
+  }
+
+  // Read is out-of-bounds (no dim is marked in-bounds) and will be padded.
+  SmallVector<bool> outOfBounds(vecRank, false);
+  auto readMap = AffineMapAttr::get(rewriter.getMultiDimIdentityMap(vecRank));
+  // Assuming that the low indices of PadTensorOp are all %c0. Must use a
+  // different starting point + masking for the vector read when the pattern
+  // is extended.
+  auto read = rewriter.create<vector::TransferReadOp>(
+      padOp.getLoc(), vecType, padOp.source(), padOp.low(), readMap, padValue,
+      /*mask=*/Value(), rewriter.getBoolArrayAttr(outOfBounds));
+
+  // Compute indices of the TransferWriteOp.
+  SmallVector<Value> indices;
+  llvm::for_each(insertOp.getMixedOffsets(), [&](auto o) {
+    if (o.template is<Value>()) {
+      indices.push_back(o.template get<Value>());
+    } else {
+      // Convert the int64 attr to an index attr and materialize a constant.
+      auto intAttr =
+          rewriter.getIndexAttr(getIntFromAttr(o.template get<Attribute>()));
+      indices.push_back(rewriter.create<ConstantOp>(
+          padOp.getLoc(), rewriter.getIndexType(), intAttr));
+    }
+  });
+
+  // Write is fully in-bounds.
+  SmallVector<bool> inBounds(vecRank, true);
+  // Write to the most minor dimensions of the tensor.
+  auto writeMap = AffineMapAttr::get(AffineMap::getMinorIdentityMap(
+      tensorRank, vecRank, rewriter.getContext()));
+  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+      insertOp, insertOp.getType(), read.getResult(), insertOp.dest(), indices,
+      writeMap, /*mask=*/Value(), rewriter.getBoolArrayAttr(inBounds));
+
+  return success();
+}
+
+/// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
+/// ```
+/// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
+/// %r = vector.transfer_read %0[%c0, %c0], %cst
+///     {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %r = vector.transfer_read %src[%c0, %c0], %padding
+///     : tensor<?x?xf32>, vector<17x5xf32>
+/// ```
+/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
+/// sure that the original padding value %cst was never used.
+LogicalResult PadTensorOpVectorizationPattern::tryVectorizeUser(
+    PatternRewriter &rewriter, linalg::PadTensorOp padOp,
+    vector::TransferReadOp xferOp, Value padValue) const {
+  if (xferOp.hasOutOfBoundsDim())
+    return failure();
+  if (xferOp.mask())
+    return failure();
+
+  rewriter.updateRootInPlace(xferOp, [&]() {
+    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
+    xferOp->setAttr(xferOp.getInBoundsAttrName(),
+                    rewriter.getBoolArrayAttr(inBounds));
+    xferOp.sourceMutable().assign(padOp.source());
+    xferOp.paddingMutable().assign(padValue);
+  });
+
+  return success();
+}
+
+/// Check if both values have the same runtime value. Used in
+/// `tryVectorizeUser`.
+bool PadTensorOpVectorizationPattern::hasSameRuntimeValue(Value v1,
+                                                          Value v2) const {
+  // Case 1: Both values are identical.
+  if (v1 == v2)
+    return true;
+
+  // Case 2: Both values are identical AffineMinOps. (Should not happen if
+  // CSE is run.)
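+  // (Such values typically arise from tiling, where subtensor sizes are
+  // computed as an affine.min of the tile size and the remaining problem
+  // size. The check below is deliberately conservative: it only matches
+  // single-operand AffineMinOps with identical maps and identical operands.)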
+  auto minOp1 = v1.getDefiningOp<AffineMinOp>();
+  auto minOp2 = v2.getDefiningOp<AffineMinOp>();
+  if (minOp1 && minOp2) {
+    return minOp1.getAffineMap() == minOp2.getAffineMap() &&
+           minOp1.getNumOperands() == 1 && minOp2.getNumOperands() == 1 &&
+           minOp1->getOperand(0) == minOp2->getOperand(0);
+  }
+
+  return false;
+}
+
+/// Rewrite use of PadTensorOp result in TransferWriteOp.
+/// This pattern rewrites TransferWriteOps that write to a padded tensor value,
+/// where the same amount of padding is immediately removed again after the
+/// write. In such cases, the TransferWriteOp can write to the non-padded
+/// tensor value and apply out-of-bounds masking. E.g.:
+/// ```
+/// %0 = subtensor ...[...] [%s0, %s1] [1, 1] : tensor<...> to tensor<?x?xf32>
+/// %1 = linalg.pad_tensor %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
+/// %2 = vector.transfer_write %vec, %1[...]
+///     : vector<17x5xf32>, tensor<17x5xf32>
+/// %r = subtensor %2[0, 0] [%s0, %s1] [1, 1]
+///     : tensor<17x5xf32> to tensor<?x?xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %0 = subtensor ...[...] [%s0, %s1] [1, 1] : tensor<...> to tensor<?x?xf32>
+/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, tensor<?x?xf32>
+/// ```
+/// Note: It is important that the SubTensorOp %r resizes the result of the
+/// TransferWriteOp to the same size as the input of the PadTensorOp (or an
+/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
+/// from %r's old dimensions.
+LogicalResult PadTensorOpVectorizationPattern::tryVectorizeUser(
+    PatternRewriter &rewriter, linalg::PadTensorOp padOp,
+    vector::TransferWriteOp xferOp) const {
+  OpBuilder::InsertionGuard guard(rewriter);
+
+  // TransferWriteOp result must be directly consumed by a SubTensorOp.
+  if (!xferOp->hasOneUse())
+    return failure();
+  auto trimPadding = dyn_cast<SubTensorOp>(*xferOp->user_begin());
+  if (!trimPadding)
+    return failure();
+  // The source of the PadTensorOp must be defined by a SubTensorOp.
+  auto *padInput = padOp.source().getDefiningOp();
+  if (!padInput)
+    return failure();
+  auto subtensorOp = dyn_cast<SubTensorOp>(padInput);
+  if (!subtensorOp)
+    return failure();
+  // Sizes of both SubTensorOps must be identical.
+  if (subtensorOp.sizes().size() != trimPadding.sizes().size())
+    return failure();
+  if (llvm::any_of(llvm::zip(subtensorOp.sizes(), trimPadding.sizes()),
+                   [&](auto d) {
+                     return !hasSameRuntimeValue(std::get<0>(d),
+                                                 std::get<1>(d));
+                   }))
+    return failure();
+
+  rewriter.setInsertionPoint(xferOp);
+  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
+  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+      xferOp, subtensorOp.getType(), xferOp.vector(), subtensorOp,
+      xferOp.indices(), xferOp.permutation_mapAttr(), xferOp.mask(),
+      rewriter.getBoolArrayAttr(inBounds));
+  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
+
+  return success();
+}
+
 // TODO: cleanup all the convolution vectorization patterns.
 template <typename ConvOp, int N>
 LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
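
A minimal usage sketch, not part of the patch: one way to exercise the extended pattern greedily over a single function. The wrapper `applyPadVectorization` and the surrounding boilerplate are illustrative assumptions; only `PadTensorOpVectorizationPattern` comes from this diff.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Apply PadTensorOpVectorizationPattern (plus folding) until fixpoint.
// `applyPadVectorization` is a hypothetical helper, not part of the patch.
static void applyPadVectorization(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<linalg::PadTensorOpVectorizationPattern>(ctx);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}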