diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2595,6 +2595,15 @@ LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { + if (auto copyOp = dyn_cast(op)) { + assert(copyOp.hasBufferSemantics()); + if (copyOp.input() == copyOp.output() && + copyOp.inputPermutation() == copyOp.outputPermutation()) { + rewriter.eraseOp(op); + return success(); + } + } + if (!isa(op)) return failure(); LinalgOp genericOp = cast(op); diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -789,3 +789,15 @@ // CHECK: %[[INSERTED:.+]] = subtensor_insert %{{.+}} into %[[FILL]] // CHECK: %[[RESULT:.+]] = tensor.cast %[[INSERTED]] // CHECK: return %[[RESULT]] + +// ----- + +// CHECK-LABEL: @self_copy +func @self_copy(%arg0 : memref<2x3x?x4xf32>) { + +// CHECK-NOT: linalg.copy + linalg.copy(%arg0, %arg0): memref<2x3x?x4xf32>, memref<2x3x?x4xf32> + +// CHECK: return + return +}