diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -671,52 +671,6 @@
   return result;
 }
 
-/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
-/// TransferWriteOp. For now, this only applies when all low and high paddings
-/// are determined to be zero.
-struct GenericPadTensorOpVectorizationPattern
-    : public OpRewritePattern<PadTensorOp> {
-  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(PadTensorOp padOp,
-                                PatternRewriter &rewriter) const override {
-    /// Given an OpFoldResult, return true if its value is guaranteed to be a
-    /// zero integer.
-    auto isZeroInt = [&](OpFoldResult ofr) {
-      return isEqualConstantIntOrValue(ofr, rewriter.getIndexAttr(0)); };
-    // Low padding must be static 0.
-    if (!llvm::all_of(padOp.getMixedLowPad(), isZeroInt)) return failure();
-    // High padding must be static 0.
-    if (!llvm::all_of(padOp.getMixedHighPad(), isZeroInt)) return failure();
-    // Pad value must be a constant.
-    auto padValue = padOp.getConstantPaddingValue();
-    if (!padValue) return failure();
-
-    // Bail on non-static shapes.
-    auto resultShapedType = padOp.result().getType().cast<ShapedType>();
-    if (!resultShapedType.hasStaticShape())
-      return failure();
-    VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
-    if (!vectorType)
-      return failure();
-
-    // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
-    // TransferWriteOp@[0..0].
-    SmallVector<Value> indices(
-        resultShapedType.getRank(),
-        rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
-    Value read = rewriter.create<vector::TransferReadOp>(
-        padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
-    Value init = rewriter.create<InitTensorOp>(
-        padOp.getLoc(), resultShapedType.getShape(),
-        resultShapedType.getElementType());
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
-                                                         indices);
-
-    return success();
-  }
-};
-
 /// Base pattern for rewriting PadTensorOps whose result is consumed by a given
 /// operation type OpTy.
 template <typename OpTy>
@@ -995,13 +949,14 @@
 
 void mlir::linalg::populatePadTensorOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
-  patterns.add<GenericPadTensorOpVectorizationPattern>(
-      patterns.getContext(), baseBenefit);
+  // TODO: Canonicalizer handles simple cases where low = 0 and high = 0, but a
+  // generic vectorization pattern is still missing.
+
   // Try these specialized patterns first before resorting to the generic one.
   patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
                PadTensorOpVectorizationWithTransferWritePattern,
                PadTensorOpVectorizationWithSubTensorInsertPattern>(
-      patterns.getContext(), baseBenefit.getBenefit() + 1);
+      patterns.getContext(), baseBenefit);
 }
 
 // TODO: cleanup all the convolution vectorization patterns.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1148,3 +1148,21 @@
 // CHECK-LABEL: @tensor_pad_cast
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<4x4xf32>
 // CHECK:        return %[[ARG0]]
+
+// -----
+
+// CHECK-LABEL: func @pad_static_zero_cast(
+// CHECK-SAME:   %[[ARG0:.*]]: tensor<?x?x?xf32>
+// CHECK-NOT:    linalg.pad_tensor
+// CHECK:        %[[RESULT:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+// CHECK:        return %[[RESULT]]
+func @pad_static_zero_cast(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
+  %c0 = constant 0 : index
+  %0 = linalg.pad_tensor %arg0 low[0, %c0, 0] high[0, 0, %c0] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index):
+      linalg.yield %pad_value : f32
+  } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+
+  return %0 : tensor<2x3x4xf32>
+}
+
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -512,27 +512,6 @@
 
 // -----
 
-// CHECK-LABEL: func @pad_static
-// CHECK-NOT: linalg.pad_tensor
-func @pad_static(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]]
-  // CHECK-SAME: : tensor<?x?x?xf32>, vector<2x3x4xf32>
-  // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32>
-  // CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]]
-  // CHECK-SAME: {in_bounds = [true, true, true]} : vector<2x3x4xf32>, tensor<2x3x4xf32>
-  %c0 = constant 0 : index
-  %0 = linalg.pad_tensor %arg0 low[0, %c0, 0] high[0, 0, %c0] {
-    ^bb0(%arg1: index, %arg2: index, %arg3: index):
-      linalg.yield %pad_value : f32
-  } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
-
-  // CHECK: return %[[WRITTEN]] : tensor<2x3x4xf32>
-  return %0 : tensor<2x3x4xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @pad_static_high_padding
 // CHECK:         linalg.pad_tensor
 func @pad_static_high_padding(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
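Note (not part of the patch): a minimal sketch of how the remaining specialized
pad_tensor vectorization patterns might be driven after this change, assuming
the greedy rewrite driver available at this revision; the wrapper function
runPadVectorization is a hypothetical name:

  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Hypothetical driver: only the specialized patterns are populated now
  // (the zero-padding case is folded to tensor.cast by -canonicalize), so we
  // collect them and apply them greedily to the function.
  static void runPadVectorization(mlir::FuncOp funcOp) {
    mlir::RewritePatternSet patterns(funcOp.getContext());
    mlir::linalg::populatePadTensorOpVectorizationPatterns(
        patterns, /*baseBenefit=*/1);
    (void)mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  }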