diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -168,6 +168,7 @@
       custom<CopyOpRegion>($region, ref(type($input)), ref(type($input)))
   }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
   let skipDefaultBuilders = 1;
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -426,6 +426,33 @@
                        SideEffects::DefaultResource::get());
 }
 
+namespace {
+/// Remove copy operations that copy data in place, i.e. that are no-ops.
+/// Requirements are:
+/// 1) The input and output values are identical.
+/// 2) The input and output permutation maps are identical; a self copy with
+///    different permutations reorders the data and must be kept.
+struct EraseIdentityCopyOp : public OpRewritePattern<CopyOp> {
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CopyOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    // linalg.copy is expected to operate on buffers only.
+    assert(copyOp.hasBufferSemantics());
+    if (copyOp.input() != copyOp.output() ||
+        copyOp.inputPermutation() != copyOp.outputPermutation())
+      return failure();
+    rewriter.eraseOp(copyOp);
+    return success();
+  }
+};
+} // namespace
+
+void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<EraseIdentityCopyOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // FillOp
 //===----------------------------------------------------------------------===//
@@ -2615,15 +2642,6 @@
 
   LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
-    if (auto copyOp = dyn_cast<CopyOp>(*op)) {
-      assert(copyOp.hasBufferSemantics());
-      if (copyOp.input() == copyOp.output() &&
-          copyOp.inputPermutation() == copyOp.outputPermutation()) {
-        rewriter.eraseOp(op);
-        return success();
-      }
-    }
-
     if (!isa<GenericOp>(op))
       return failure();
     if (!op.hasTensorSemantics())
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -661,6 +661,20 @@
 
 // -----
 
+// CHECK-LABEL: @self_copy_with_permutation
+func @self_copy_with_permutation(%arg0 : memref<2x3x?x4xf32>) {
+
+// CHECK: linalg.copy
+  linalg.copy(%arg0, %arg0)
+    {inputPermutation = affine_map<(i, j, k, l) -> (j, k, i, l)>,
+     outputPermutation = affine_map<(i, j, k, l) -> (i, j, k, l)>} : memref<2x3x?x4xf32>, memref<2x3x?x4xf32>
+
+// CHECK: return
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_fill_reshape()
 func @fold_fill_reshape() -> tensor<6x4xf32> {
   %zero = constant 0.0 : f32