diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -279,6 +279,8 @@
       "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
   ];
+
+  let hasFolder = 1;
 }
 
 def Linalg_RangeOp :
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1119,6 +1119,15 @@
   return success();
 }
 
+/// Folds a PadTensorOp to its source when the result type has a static shape
+/// and is identical to the source type. The operands argument is unused: the
+/// decision is made on the types alone.
+OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
+  if (getResultType().hasStaticShape() && getResultType() == getSourceType())
+    return source();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -893,6 +893,22 @@
 
 // -----
 
+// CHECK-LABEL: func @pad_tensor_same_static_shape(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK: return %[[ARG0]]
+func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+    -> tensor<5x6xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.pad_tensor %arg0 low[%a, 0] high[0, %a] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %cst : f32
+  } : tensor<5x6xf32> to tensor<5x6xf32>
+  return %0 : tensor<5x6xf32>
+}
+
+// -----
+
 func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
 {
   %c1 = constant 1 : index