diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -279,6 +279,8 @@
       "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
   ];
+
+  let hasCanonicalizer = 1;
 }
 
 def Linalg_RangeOp :
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1119,6 +1119,34 @@
   return success();
 }
 
+/// Fold away PadTensorOp where both low and high padding are zero.
+struct FoldEmptyPadTensorPattern : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override {
+    /// Given an OpFoldResult, return true if its value is guaranteed to be a
+    /// zero integer.
+    auto isZeroInt = [&](OpFoldResult ofr) {
+      return isEqualConstantIntOrValue(ofr, rewriter.getIndexAttr(0));
+    };
+    // Fold only if all low and high padding values are zero.
+    if (!llvm::all_of(padOp.getMixedLowPad(), isZeroInt))
+      return failure();
+    if (!llvm::all_of(padOp.getMixedHighPad(), isZeroInt))
+      return failure();
+
+    // Insert a cast so that the result type does not change.
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, padOp.getResultType(),
+                                                padOp.source());
+    return success();
+  }
+};
+
+void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<FoldEmptyPadTensorPattern>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -31,6 +31,7 @@
 
 using namespace mlir;
 using namespace mlir::linalg;
+using namespace mlir::linalg::detail;
 
 using llvm::dbgs;
 
@@ -787,52 +788,6 @@
   return true;
 }
 
-/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
-/// TransferWriteOp. For now, this only applies when all low and high paddings
-/// are determined to be zero.
-struct GenericPadTensorOpVectorizationPattern
-    : public OpRewritePattern<PadTensorOp> {
-  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(PadTensorOp padOp,
-                                PatternRewriter &rewriter) const override {
-    /// Given an OpFoldResult, return true if its value is guaranteed to be a
-    /// zero integer.
-    auto isZeroInt = [&](OpFoldResult ofr) {
-      return isEqualConstantIntOrValue(ofr, rewriter.getIndexAttr(0)); };
-    // Low padding must be static 0.
-    if (!llvm::all_of(padOp.getMixedLowPad(), isZeroInt)) return failure();
-    // High padding must be static 0.
-    if (!llvm::all_of(padOp.getMixedHighPad(), isZeroInt)) return failure();
-    // Pad value must be a constant.
-    auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
-    if (!padValue) return failure();
-
-    // Bail on non-static shapes.
-    auto resultShapedType = padOp.result().getType().cast<ShapedType>();
-    if (!resultShapedType.hasStaticShape())
-      return failure();
-    VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
-    if (!vectorType)
-      return failure();
-
-    // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
-    // TransferWriteOp@[0..0].
-    SmallVector<Value> indices(
-        resultShapedType.getRank(),
-        rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
-    Value read = rewriter.create<vector::TransferReadOp>(
-        padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
-    Value init = rewriter.create<InitTensorOp>(
-        padOp.getLoc(), resultShapedType.getShape(),
-        resultShapedType.getElementType());
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
-                                                         indices);
-
-    return success();
-  }
-};
-
 /// Base pattern for rewriting PadTensorOps whose result is consumed by a given
 /// operation type OpTy.
 template <typename OpTy>
@@ -1074,13 +1029,14 @@
 
 void mlir::linalg::populatePadTensorOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
-  patterns.add<GenericPadTensorOpVectorizationPattern>(
-      patterns.getContext(), baseBenefit);
+  // TODO: The canonicalizer handles simple cases where low = 0 and high = 0,
+  // but a generic vectorization pattern is still missing.
+  // Try these specialized patterns first before resorting to the generic one.
   patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
                PadTensorOpVectorizationWithTransferWritePattern,
                PadTensorOpVectorizationWithSubTensorInsertPattern>(
-      patterns.getContext(), baseBenefit.getBenefit() + 1);
+      patterns.getContext(), baseBenefit);
 }
 
 // TODO: cleanup all the convolution vectorization patterns.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1132,3 +1132,36 @@
 // CHECK-NEXT:   %[[SUM1:.+]] = addi %[[SUM0]], %[[ARG2]] : index
 // CHECK-NEXT:   %[[SUM2:.+]] = addi %[[SUM1]], %[[ARG3]] : index
 // CHECK-NEXT:   linalg.yield %[[SUM2]] : index
+
+// -----
+
+// CHECK-LABEL: func @pad_static_zero_cast(
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<?x?x?xf32>
+//   CHECK-NOT:   linalg.pad_tensor
+//       CHECK:   %[[RESULT:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+//       CHECK:   return %[[RESULT]]
+func @pad_static_zero_cast(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
+  %c0 = constant 0 : index
+  %0 = linalg.pad_tensor %arg0 low[0, %c0, 0] high[0, 0, %c0] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index):
+      linalg.yield %pad_value : f32
+  } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+
+  return %0 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pad_static_zero(
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<2x3x4xf32>
+//   CHECK-NOT:   linalg.pad_tensor
+//       CHECK:   return %[[ARG0]] : tensor<2x3x4xf32>
+func @pad_static_zero(%arg0: tensor<2x3x4xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
+  %c0 = constant 0 : index
+  %0 = linalg.pad_tensor %arg0 low[0, %c0, 0] high[0, 0, %c0] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index):
+      linalg.yield %pad_value : f32
+  } : tensor<2x3x4xf32> to tensor<2x3x4xf32>
+
+  return %0 : tensor<2x3x4xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -512,27 +512,6 @@
 
 // -----
 
-// CHECK-LABEL: func @pad_static
-//   CHECK-NOT: linalg.pad_tensor
-func @pad_static(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]]
-  // CHECK-SAME:   : tensor<?x?x?xf32>, vector<2x3x4xf32>
-  // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32>
-  // CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]]
-  // CHECK-SAME:   {in_bounds = [true, true, true]} : vector<2x3x4xf32>, tensor<2x3x4xf32>
-  %c0 = constant 0 : index
-  %0 = linalg.pad_tensor %arg0 low[0, %c0, 0] high[0, 0, %c0] {
-    ^bb0(%arg1: index, %arg2: index, %arg3: index):
-      linalg.yield %pad_value : f32
-  } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
-
-  // CHECK: return %[[WRITTEN]] : tensor<2x3x4xf32>
-  return %0 : tensor<2x3x4xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @pad_static_high_padding
 //       CHECK:   linalg.pad_tensor
 func @pad_static_high_padding(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
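---

Reviewer note: the new fold is covered by the canonicalize.mlir tests above,
which run through `mlir-opt -canonicalize`. For driving the fold
programmatically, here is a minimal sketch; `foldZeroPads` and the surrounding
setup are illustrative assumptions, not part of this patch:

  #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  using namespace mlir;

  // Collect PadTensorOp's canonicalization patterns (which, with this patch,
  // include FoldEmptyPadTensorPattern) and apply them greedily to a module.
  static LogicalResult foldZeroPads(ModuleOp module) {
    MLIRContext *ctx = module.getContext();
    RewritePatternSet patterns(ctx);
    linalg::PadTensorOp::getCanonicalizationPatterns(patterns, ctx);
    return applyPatternsAndFoldGreedily(module, std::move(patterns));
  }

Because the pattern always replaces the pad with a tensor.cast, a trivial cast
(same source and result type, as in @pad_static_zero) is cleaned up by the
existing tensor.cast folder in the same canonicalization run.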