diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -671,6 +671,75 @@
   return result;
 }
 
+/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp and
+/// SubTensorInsertOp. For now, only constant padding values are supported.
+/// Note: This rewrite is not yet a vectorization, but some of the generated ops
+/// may be vectorized down the line (e.g., FillOp).
+/// TODO: If there is enough static shape information, generate TransferReadOps
+/// and TransferWriteOps instead of SubTensorInsertOp.
+struct GenericPadTensorOpVectorizationPattern
+    : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const final {
+    // Given an OpFoldResult, return an index-typed value.
+    auto getIdxValue = [&](OpFoldResult ofr) {
+      if (auto val = ofr.dyn_cast<Value>())
+        return val;
+      return rewriter.create<ConstantIndexOp>(
+          padOp.getLoc(), getIntFromAttr(ofr.get<Attribute>())).getResult();
+    };
+
+    // Pad value must be a constant.
+    auto padValue = padOp.getConstantPaddingValue();
+    if (!padValue) return failure();
+
+    auto resultType = padOp.getResultType();
+    // Compute size of InitTensorOp. Any combination of static/dynamic is
+    // supported.
+    SmallVector<Value> dynSizes;
+    SmallVector<int64_t> staticSizes;
+    for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
+      if (resultType.isDynamicDim(dim)) {
+        auto srcSize = rewriter.createOrFold<memref::DimOp>(
+            padOp.getLoc(), padOp.source(), dim);
+        // Add low and high padding value.
+        auto plusLow = rewriter.createOrFold<AddIOp>(
+            padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
+        auto plusHigh = rewriter.createOrFold<AddIOp>(
+            padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
+        dynSizes.push_back(plusHigh);
+      }
+      staticSizes.push_back(resultType.getDimSize(dim));
+    }
+
+    Value init = rewriter.create<InitTensorOp>(
+        padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
+    Value fill =
+        rewriter.create<FillOp>(padOp.getLoc(), init, padValue).result();
+
+    auto sourceType = padOp.getSourceType();
+    // Compute size of source of PadTensorOp.
+    SmallVector<OpFoldResult> srcSizes;
+    for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
+      if (sourceType.isDynamicDim(dim)) {
+        srcSizes.push_back(rewriter.createOrFold<memref::DimOp>(
+            padOp.getLoc(), padOp.source(), dim));
+      } else {
+        srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
+      }
+    }
+    // Strides of SubTensorInsertOp are all 1.
+    SmallVector<OpFoldResult> strides(sourceType.getRank(),
+                                      rewriter.getIndexAttr(1));
+    rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
+        padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);
+
+    return success();
+  }
+};
+
 /// Base pattern for rewriting PadTensorOps whose result is consumed by a given
 /// operation type OpTy.
 template <typename OpTy>
@@ -949,14 +1018,13 @@
 
 void mlir::linalg::populatePadTensorOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
-  // TODO: Canonicalizer handles simple cases where low = 0 and high = 0, but a
-  // generic vectorization pattern is still missing.
-
+  patterns.add<GenericPadTensorOpVectorizationPattern>(
+      patterns.getContext(), baseBenefit);
   // Try these specialized patterns first before resorting to the generic one.
   patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
                PadTensorOpVectorizationWithTransferWritePattern,
                PadTensorOpVectorizationWithSubTensorInsertPattern>(
-      patterns.getContext(), baseBenefit);
+      patterns.getContext(), baseBenefit.getBenefit() + 1);
 }
 
 // TODO: cleanup all the convolution vectorization patterns.
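For reference, a minimal before/after sketch in MLIR of what the new pattern emits, borrowing the shapes and offsets from the @pad_static test below. The function name @pad_rewrite_sketch and the constant pad value are hypothetical; in the test itself, the linalg.fill is further vectorized into vector.broadcast + vector.transfer_write by the other patterns in the pass:

// Hypothetical input, as in @pad_static below:
//   %res = linalg.pad_tensor %src low[0, 0, 2] high[0, 1, 0] {
//   ^bb0(%i: index, %j: index, %k: index):
//     linalg.yield %cst : f32
//   } : tensor<2x?x2xf32> to tensor<2x3x4xf32>
//
// The pattern rewrites it to:
func @pad_rewrite_sketch(%src: tensor<2x?x2xf32>) -> tensor<2x3x4xf32> {
  %cst = constant 0.0 : f32   // constant pad value (assumed)
  %c1 = constant 1 : index
  // Tensor of the padded result shape, filled with the pad value. A dynamic
  // result size would instead be computed as srcSize + lowPad + highPad.
  %init = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32>
  %fill = linalg.fill(%init, %cst) : tensor<2x3x4xf32>, f32 -> tensor<2x3x4xf32>
  // Insert the source at the low-pad offsets; all strides are 1.
  %dim1 = memref.dim %src, %c1 : tensor<2x?x2xf32>
  %res = subtensor_insert %src into %fill[0, 0, 2] [2, %dim1, 2] [1, 1, 1]
      : tensor<2x?x2xf32> into tensor<2x3x4xf32>
  return %res : tensor<2x3x4xf32>
}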
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -512,21 +512,44 @@
 // -----
 
-// CHECK-LABEL: func @pad_static_high_padding
-// CHECK: linalg.pad_tensor
-func @pad_static_high_padding(%arg0: tensor<2x2x4xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
-  %0 = linalg.pad_tensor %arg0 low[0, 0, 0] high[0, 1, 0] {
+// CHECK-LABEL: func @pad_static(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x?x2xf32>, %[[PAD:.*]]: f32
+// CHECK-NOT: linalg.pad_tensor
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[INIT:.*]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32>
+// CHECK-DAG: %[[VEC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<2x3x4xf32>
+// CHECK: %[[FILL:.*]] = vector.transfer_write %[[VEC]], %[[INIT]]{{.*}} : vector<2x3x4xf32>, tensor<2x3x4xf32>
+// CHECK-DAG: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[RESULT:.*]] = subtensor_insert %[[ARG0]] into %[[FILL]][0, 0, 2] [2, %[[DIM1]], 2] [1, 1, 1] : tensor<2x?x2xf32> into tensor<2x3x4xf32>
+// CHECK: return %[[RESULT]]
+func @pad_static(%arg0: tensor<2x?x2xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0, 2] high[0, 1, 0] {
     ^bb0(%arg1: index, %arg2: index, %arg3: index):
       linalg.yield %pad_value : f32
-    } : tensor<2x2x4xf32> to tensor<2x3x4xf32>
+    } : tensor<2x?x2xf32> to tensor<2x3x4xf32>
   return %0 : tensor<2x3x4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func @pad_dynamic
-// CHECK: linalg.pad_tensor
-func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
+// CHECK-LABEL: func @pad_static_dynamic(
+// CHECK-SAME: %[[SRC:.*]]: tensor<1x2x2x?xf32>, %[[LOW:.*]]: index, %[[HIGH:.*]]: index
+// CHECK-NOT: linalg.pad_tensor
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = constant 3 : index
+// CHECK-DAG: %[[C5:.*]] = constant 5 : index
+// CHECK: %[[V0:.*]] = addi %[[LOW]], %[[C2]] : index
+// CHECK: %[[V1:.*]] = addi %[[V0]], %[[C3]] : index
+// CHECK: %[[V2:.*]] = addi %[[HIGH]], %[[C5]] : index
+// CHECK: %[[DIM3:.*]] = memref.dim %[[SRC]], %[[C3]] : tensor<1x2x2x?xf32>
+// CHECK: %[[V4:.*]] = addi %[[DIM3]], %[[C3]] : index
+// CHECK: %[[V5:.*]] = addi %[[V4]], %[[C2]] : index
+// CHECK: %[[INIT:.*]] = linalg.init_tensor [6, %[[V1]], %[[V2]], %[[V5]]] : tensor<6x?x?x?xf32>
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]], %{{.*}}) : tensor<6x?x?x?xf32>, f32 -> tensor<6x?x?x?xf32>
+// CHECK: %[[SRCDIM:.*]] = memref.dim %[[SRC]], %[[C3]] : tensor<1x2x2x?xf32>
+// CHECK: %[[RESULT:.*]] = subtensor_insert %[[SRC]] into %[[FILL]][2, %[[LOW]], 3, 3] [1, 2, 2, %[[SRCDIM]]] [1, 1, 1, 1] : tensor<1x2x2x?xf32> into tensor<6x?x?x?xf32>
+// CHECK: return %[[RESULT]]
+func @pad_static_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
                          %pad_value: f32) -> tensor<6x?x?x?xf32> {
   %0 = linalg.pad_tensor %arg0 low[2, %low, 3, 3] high[3, 3, %high, 2] {
     ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):