diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -291,6 +291,8 @@
       "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
   ];
+
+  let hasCanonicalizer = 1;
 }
 
 def Linalg_RangeOp :
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1119,6 +1119,34 @@
   return success();
 }
 
+namespace {
+/// Folds a linalg.pad_tensor whose low and high padding are all static
+/// zeros, i.e. when both hasZeroLowPad() and hasZeroHighPad() hold. The op
+/// is replaced by a tensor.cast from its source to its result type rather
+/// than by the source value itself, because the two types may differ (for
+/// example a tensor<?x?xf32> source with a tensor<4x4xf32> result).
+struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    // Only fold when neither the low nor the high side adds any padding.
+    if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
+      return failure();
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        padTensorOp, padTensorOp.result().getType(), padTensorOp.source());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the canonicalization patterns of linalg.pad_tensor.
+void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<FoldStaticZeroPadding>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1132,3 +1132,20 @@
 // CHECK-NEXT: %[[SUM1:.+]] = addi %[[SUM0]], %[[ARG2]] : index
 // CHECK-NEXT: %[[SUM2:.+]] = addi %[[SUM1]], %[[ARG3]] : index
 // CHECK-NEXT: linalg.yield %[[SUM2]] : index
+
+// -----
+
+// Checks that a pad_tensor with all-zero padding folds away entirely.
+func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+  %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
+  %1 = linalg.pad_tensor %0 low[%c0, %c0] high[%c0, %c0] {
+    ^bb0(%arg1: index, %arg2: index): // no predecessors
+      linalg.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<4x4xf32>
+  return %1 : tensor<4x4xf32>
+}
+// CHECK-LABEL: @tensor_pad_cast_fold
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
+// CHECK: return %[[ARG0]]