diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -844,7 +844,5 @@
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
-      return failure();
     // Check all indexing maps are identity.
     if (llvm::any_of(genericOp.getIndexingMaps(),
                      [](AffineMap map) { return !map.isIdentity(); }))
@@ -859,5 +857,23 @@
     if (!yieldOp)
       return failure();

+    // In the buffer case, we need to check exact buffer equality: the op is a
+    // no-op only if its single input is the very same buffer as its single
+    // output (identity indexing maps were checked above) and its body yields
+    // one of its own block arguments. Yielding a value defined above the op
+    // (e.g. a constant) makes it a fill, which must not be erased.
+    if (genericOp.hasBufferSemantics()) {
+      if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1 ||
+          genericOp.getInputOperand(0)->get() !=
+              genericOp.getOutputOperand(0)->get() ||
+          yieldOp->getNumOperands() != 1)
+        return failure();
+      auto yieldArg = yieldOp->getOperand(0).dyn_cast<BlockArgument>();
+      if (!yieldArg || yieldArg.getOwner() != yieldOp->getBlock())
+        return failure();
+      rewriter.eraseOp(genericOp);
+      return success();
+    }
+
     // Get the argument number of the returned values. That is the operand
     // number to use for replacing uses of this operation.
@@ -877,5 +893,6 @@
       returnedArgs.push_back(returnedArg);
     }
+
     if (returnedArgs.size() != genericOp->getNumResults())
       return failure();
     rewriter.replaceOp(genericOp, returnedArgs);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -584,2 +584,37 @@
   return %r2 : index
 }
+
+// -----
+
+// CHECK: func @fold_self_copy
+func @fold_self_copy(%0 : memref<4x16xf32>) {
+// CHECK-NEXT: return
+  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                   affine_map<(d0, d1) -> (d0, d1)>],
+                  iterator_types = ["parallel", "parallel"]}
+    ins(%0 : memref<4x16xf32>)
+    outs(%0 : memref<4x16xf32>) {
+  ^bb0(%arg4: f32, %arg5: f32):
+    linalg.yield %arg4 : f32
+  }
+  return
+}
+
+// -----
+
+// A buffer generic whose input aliases its output but whose body yields a
+// value defined above the op is a fill, not a self-copy: it must be kept.
+// CHECK-LABEL: func @no_fold_self_copy_yielding_value_from_above
+//       CHECK:   linalg.generic
+func @no_fold_self_copy_yielding_value_from_above(%0 : memref<4x16xf32>,
+                                                  %cst : f32) {
+  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                   affine_map<(d0, d1) -> (d0, d1)>],
+                  iterator_types = ["parallel", "parallel"]}
+    ins(%0 : memref<4x16xf32>)
+    outs(%0 : memref<4x16xf32>) {
+  ^bb0(%arg4: f32, %arg5: f32):
+    linalg.yield %cst : f32
+  }
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/inlining.mlir b/mlir/test/Dialect/Linalg/inlining.mlir
--- a/mlir/test/Dialect/Linalg/inlining.mlir
+++ b/mlir/test/Dialect/Linalg/inlining.mlir
@@ -27,5 +27,6 @@
     ^bb(%0 : f32, %1 : f32) :
-      linalg.yield %0 : f32
+      %2 = arith.addf %0, %0: f32
+      linalg.yield %2 : f32
   }
   return
 }