diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -486,11 +486,26 @@
     return success(modified);
   }
 };
+
+/// Fold memref.copy(%x, %x): when the source and target operands are the same
+/// value, the copy op is erased.
+struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CopyOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    if (copyOp.source() != copyOp.target())
+      return failure();
+
+    rewriter.eraseOp(copyOp);
+    return success();
+  }
+};
 } // namespace
 
 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<FoldCopyOfCast>(context);
+  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -525,3 +525,13 @@
 //  CHECK-SAME:     %[[m1:.*]]: memref<?xf32>, %[[m2:.*]]: memref<*xf32>
 //       CHECK:   %[[casted2:.*]] = memref.cast %[[m2]]
 //       CHECK:   memref.copy %[[m1]], %[[casted2]]
+
+// -----
+
+func @self_copy(%m1: memref<?xf32>) {
+  memref.copy %m1, %m1 : memref<?xf32> to memref<?xf32>
+  return
+}
+
+// CHECK-LABEL: func @self_copy
+//  CHECK-NEXT:   return