diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -674,10 +674,8 @@ /// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp and /// SubTensorInsertOp. For now, only constant padding values are supported. -/// Note: This rewrite is not yet a vectorization, but some of the generated ops -/// may be vectorized down the line (e.g., FillOp). -/// TODO: If there is enough static shape information, generate TransferReadOps -/// and TransferWriteOps instead of SubTensorInsertOp. +/// If the source has a static shape, a TransferReadOp and a TransferWriteOp are +/// generated instead of a SubTensorInsertOp. struct GenericPadTensorOpVectorizationPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -724,6 +722,20 @@ rewriter.create(padOp.getLoc(), init, padValue).result(); auto sourceType = padOp.getSourceType(); + + // Copy of source with static shape can be vectorized. + if (sourceType.hasStaticShape()) { + auto vecType = + VectorType::get(sourceType.getShape(), sourceType.getElementType()); + vectorizeStaticShapeSource(rewriter, padOp, fill, vecType); + return success(); + } + + // TODO: Vectorize dynamic source but static destination. + + // Source shape is not static (the result shape is not checked yet). Such + // PadTensorOps are not vectorized. Generate a SubTensorInsertOp instead. + + // Compute size of source of PadTensorOp. SmallVector srcSizes; for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) { @@ -742,6 +754,25 @@ return success(); } + + /// Vectorize the copy of a statically shaped PadTensorOp source into `dest`. + void vectorizeStaticShapeSource(PatternRewriter &rewriter, PadTensorOp padOp, + Value dest, VectorType vecType) const { + // Generate TransferReadOp that reads the entire source from offset 0. 
+ SmallVector readIndices( + vecType.getRank(), rewriter.create(padOp.getLoc(), 0)); + auto read = rewriter.create( + padOp.getLoc(), vecType, padOp.source(), readIndices); + + // Generate TransferWriteOp at the low-padding offsets. The destination + // dimensions may be dynamic, but the write cannot be out-of-bounds. (A + // large enough destination tensor is allocated in this pattern.) + auto writeIndices = + ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad()); + SmallVector inBounds(vecType.getRank(), true); + rewriter.replaceOpWithNewOp( + padOp, read, dest, writeIndices, inBounds); + } }; /// Base pattern for rewriting PadTensorOps whose result is consumed by a given diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -532,6 +532,27 @@ // ----- +// CHECK-LABEL: func @pad_static_source( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x5x2xf32>, %[[PAD:.*]]: f32 +// CHECK-NOT: linalg.pad_tensor +// CHECK-DAG: %[[C0:.*]] = constant 0 : index +// CHECK-DAG: %[[C2:.*]] = constant 2 : index +// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 6, 4] : tensor<2x6x4xf32> +// CHECK: %[[VEC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<2x6x4xf32> +// CHECK: %[[FILL:.*]] = vector.transfer_write %[[VEC]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<2x6x4xf32>, tensor<2x6x4xf32> +// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : tensor<2x5x2xf32>, vector<2x5x2xf32> +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[FILL]][%[[C0]], %[[C0]], %[[C2]]] {in_bounds = [true, true, true]} : vector<2x5x2xf32>, tensor<2x6x4xf32> +// CHECK: return %[[WRITE]] +func @pad_static_source(%arg0: tensor<2x5x2xf32>, %pad_value: f32) -> tensor<2x6x4xf32> { + %0 = linalg.pad_tensor %arg0 low[0, 0, 2] high[0, 1, 0] { + ^bb0(%arg1: index, %arg2: 
index, %arg3: index): + linalg.yield %pad_value : f32 + } : tensor<2x5x2xf32> to tensor<2x6x4xf32> + return %0 : tensor<2x6x4xf32> +} + +// ----- + // CHECK-LABEL: func @pad_static_dynamic( // CHECK-SAME: %[[SRC:.*]]: tensor<1x2x2x?xf32>, %[[LOW:.*]]: index, %[[HIGH:.*]]: index // CHECK-NOT: linalg.pad_tensor