diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2583,6 +2583,15 @@ LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { + if (auto copyOp = dyn_cast(op)) { + assert(copyOp.hasBufferSemantics()); + if (copyOp.input() == copyOp.output() && + copyOp.inputPermutation() == copyOp.outputPermutation()) { + rewriter.eraseOp(op); + return success(); + } + } + if (!isa(op)) return failure(); LinalgOp genericOp = cast(op); diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -767,3 +767,13 @@ // CHECK: %[[D0:.+]] = dim %[[ARG0]], %[[C4]] // CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]] // CHECK: return %[[C5]], %[[D1]] + +// ----- + +// CHECK-LABEL: @self_copy +func @self_copy(%arg0 : memref<2x3x?x4xf32>) { +// CHECK-NOT: linalg.copy + linalg.copy(%arg0, %arg0): memref<2x3x?x4xf32>, memref<2x3x?x4xf32> +// CHECK: return + return +}