diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -784,6 +784,141 @@ } }; +/// Rewrite use of PadTensorOp result in TransferWriteOp. +/// This pattern rewrites TransferWriteOps that write to a padded tensor value, +/// where the same amount of padding is immediately removed again after the +/// write. In such cases, the TransferWriteOp can write to the non-padded tensor +/// value and apply out-of-bounds masking. E.g.: +/// ``` +/// %0 = subtensor ...[...] [%s0, %s1] [1, 1] : tensor<...> to tensor +/// %1 = linalg.pad_tensor %0 ... : tensor to tensor<17x5xf32> +/// %2 = vector.transfer_write %vec, %1[...] +/// : vector<17x5xf32>, tensor<17x5xf32> +/// %r = subtensor %2[0, 0] [%s0, %s1] [1, 1] +/// : tensor<17x5xf32> to tensor +/// ``` +/// is rewritten to: +/// ``` +/// %0 = subtensor ...[...] [%s0, %s1] [1, 1] : tensor<...> to tensor +/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, tensor +/// ``` +/// Note: It is important that the SubTensorOp %r resizes the result of the +/// TransferWriteOp to the same size as the input of the PadTensorOp (or an even +/// smaller size). Otherwise, %r's new (dynamic) dimensions would differ from +/// %r's old dimensions. +/// +/// This rewrite is possible if: +/// - Low padding is static 0. +/// - `xferOp` has exactly one use, which is a SubTensorOp. This SubTensorOp +/// trims the same amount of padding that was added beforehand. +/// - Single, scalar padding value. +struct PadTensorOpVectorizationWithTransferWritePattern + : public VectorizePadTensorOpUserPattern { + using VectorizePadTensorOpUserPattern + ::VectorizePadTensorOpUserPattern; + + LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp, + vector::TransferWriteOp xferOp) const override { + // Low padding must be static 0. 
+ if (!padOp.hasZeroLowPad()) return failure(); + // Pad value must be a constant. + auto padValue = padOp.getConstantPaddingValue(); + if (!padValue) return failure(); + // TransferWriteOp result must be directly consumed by a SubTensorOp. + if (!xferOp->hasOneUse()) return failure(); + auto trimPadding = dyn_cast(*xferOp->user_begin()); + if (!trimPadding) return failure(); + // Only static zero offsets supported when trimming padding. + if (!trimPadding.hasZeroOffset()) return failure(); + // trimPadding must remove the amount of padding that was added earlier. + if (!hasSameTensorSize(padOp.source(), trimPadding)) return failure(); + + SmallVector inBounds(xferOp.getVectorType().getRank(), false); + auto newXferOp = rewriter.replaceOpWithNewOp( + xferOp, padOp.source().getType(), xferOp.vector(), padOp.source(), + xferOp.indices(), xferOp.permutation_mapAttr(), xferOp.mask(), + rewriter.getBoolArrayAttr(inBounds)); + rewriter.replaceOp(trimPadding, newXferOp->getResult(0)); + + return success(); + } + + /// Check if `beforePadding` and `afterTrimming` have the same tensor size, + /// i.e., same dimensions. + /// + /// Dimensions may be static, dynamic or mix of both. In case of dynamic + /// dimensions, this function tries to infer the (static) tensor size by + /// looking at the defining op and utilizing op-specific knowledge. + /// + /// This is a conservative analysis. In case equal tensor sizes cannot be + /// proven statically, this analysis returns `false` even though the tensor + /// sizes may turn out to be equal at runtime. + bool hasSameTensorSize(Value beforePadding, SubTensorOp afterTrimming) const { + // If the input to PadTensorOp is a CastOp, try with both CastOp result + // and CastOp operand. 
+ if (auto castOp = beforePadding.getDefiningOp()) + if (hasSameTensorSize(castOp.source(), afterTrimming)) return true; + + auto t1 = beforePadding.getType().dyn_cast(); + auto t2 = afterTrimming.getType().dyn_cast(); + // Only RankedTensorType supported. + if (!t1 || !t2) return false; + // Rank of both values must be the same. + if (t1.getRank() != t2.getRank()) return false; + + // All static dimensions must be the same. Mixed cases (e.g., dimension + // static in `t1` but dynamic in `t2`) are not supported. + for (unsigned i = 0; i < t1.getRank(); ++i) { + if (t1.isDynamicDim(i) != t2.isDynamicDim(i)) + return false; + if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i)) + return false; + } + + // Nothing more to check if all dimensions are static. + if (t1.getNumDynamicDims() == 0) return true; + + // All dynamic sizes must be the same. The only supported case at the moment + // is when `beforePadding` is a SubTensorOp (or a cast thereof). + + // Apart from CastOp, only SubTensorOp is supported. + auto beforeSubtensor = beforePadding.getDefiningOp(); + if (!beforeSubtensor) return false; + + assert(static_cast(t1.getRank()) + == beforeSubtensor.getMixedSizes().size()); + assert(static_cast(t2.getRank()) + == afterTrimming.getMixedSizes().size()); + + for (unsigned i = 0; i < t1.getRank(); ++i) { + // Skip static dimensions. + if (!t1.isDynamicDim(i)) continue; + auto size1 = beforeSubtensor.getMixedSizes()[i]; + auto size2 = afterTrimming.getMixedSizes()[i]; + + // Case 1: Same value or same constant int. + if (isEqualConstantIntOrValue(size1, size2)) continue; + + // Other cases: Take a deeper look at defining ops of values. + auto v1 = size1.dyn_cast(); + auto v2 = size2.dyn_cast(); + if (!v1 || !v2) return false; + + // Case 2: Both values are identical AffineMinOps. (Should not happen if + // CSE is run.) 
+ auto minOp1 = v1.getDefiningOp(); + auto minOp2 = v2.getDefiningOp(); + if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() + && minOp1.operands() == minOp2.operands()) continue; + // No case matched (add more above as needed): sizes not proven equal. + return false; + } + + // All tests passed. + return true; + } +}; + /// Rewrite use of PadTensorOp result in SubtensorInsertOp. E.g.: /// ``` /// %0 = linalg.pad_tensor %src ... : tensor to tensor<17x5xf32> @@ -807,8 +942,8 @@ /// - Single, scalar padding value. struct PadTensorOpVectorizationWithSubTensorInsertPattern : public VectorizePadTensorOpUserPattern { - using VectorizePadTensorOpUserPattern< - SubTensorInsertOp>::VectorizePadTensorOpUserPattern; + using VectorizePadTensorOpUserPattern + ::VectorizePadTensorOpUserPattern; LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp, SubTensorInsertOp insertOp) const override { @@ -864,6 +999,7 @@ patterns.getContext(), baseBenefit); // Try these specialized patterns first before resorting to the generic one. 
patterns.add( patterns.getContext(), baseBenefit.getBenefit() + 1); } diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -580,6 +580,54 @@ // ----- +// Static shapes: pad 5x6 to 10x13, write a 7x9 vector, trim back to 5x6. +// CHECK-LABEL: func @pad_and_transfer_write_static +// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: vector<7x9xf32> +// CHECK-NOT: linalg.pad_tensor +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32> +// CHECK: return %[[RESULT]] +func @pad_and_transfer_write_static( + %arg0: tensor<5x6xf32>, %arg1: vector<7x9xf32>) -> tensor<5x6xf32> { + %c0 = constant 0 : index + %c5 = constant 5.0 : f32 + %0 = linalg.pad_tensor %arg0 low[0, 0] high[5, 7] { + ^bb0(%arg2: index, %arg3: index): + linalg.yield %c5 : f32 + } : tensor<5x6xf32> to tensor<10x13xf32> + %1 = vector.transfer_write %arg1, %0[%c0, %c0] + : vector<7x9xf32>, tensor<10x13xf32> + %2 = subtensor %1[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32> + return %2 : tensor<5x6xf32> +} + +// ----- +// Dynamic dim: the source and the trimming subtensor share the same %size. +// CHECK-LABEL: func @pad_and_transfer_write_dynamic_static +// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: vector<7x9xf32>, %[[SIZE:.*]]: index, %[[PADDING:.*]]: index +// CHECK-NOT: linalg.pad_tensor +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[SUB:.*]] = subtensor %[[ARG0]][0, 0] [%[[SIZE]], 6] [1, 1] : tensor to tensor +// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[ARG1]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor +// CHECK: return %[[RESULT]] +func @pad_and_transfer_write_dynamic_static( + %arg0: tensor, %arg1: vector<7x9xf32>, %size: index, %padding: index) -> tensor { + %c0 = constant 0 : index + %c5 = constant 5.0 : f32 + %s = subtensor %arg0[0, 0] [%size, 6] [1, 1] + : tensor to tensor + %0 = linalg.pad_tensor %s low[0, 0] high[%padding, 7] { + 
+ ^bb0(%arg2: index, %arg3: index): + linalg.yield %c5 : f32 + } : tensor to tensor + %1 = vector.transfer_write %arg1, %0[%c0, %c0] + : vector<7x9xf32>, tensor + %2 = subtensor %1[0, 0] [%size, 6] [1, 1] : tensor to tensor + return %2 : tensor +} + +// ----- + // CHECK-LABEL: func @pad_and_subtensor_insert // CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: tensor<12x13xf32> // CHECK-NOT: linalg.pad_tensor