diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -406,6 +406,7 @@
     $source `,` $target attr-dict `:` type($source) `to` type($target)
   }];
 
+  let hasCanonicalizer = 1;
   let verifier = ?;
 }
 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -438,6 +438,61 @@
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
+//===----------------------------------------------------------------------===//
+// CopyOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// If the source/target of a CopyOp is a CastOp that does not modify the shape
+/// and element type, the cast can be skipped. Such CastOps only cast the
+/// layout of the type.
+struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CopyOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    bool modified = false;
+
+    // Check source.
+    if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
+      // Compare the type before the cast with the type after the cast. The
+      // cast is skippable only if shape and element type are unchanged.
+      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
+      auto toType = castOp.getType().dyn_cast<MemRefType>();
+
+      if (fromType && toType) {
+        if (fromType.getShape() == toType.getShape() &&
+            fromType.getElementType() == toType.getElementType()) {
+          rewriter.updateRootInPlace(
+              copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
+          modified = true;
+        }
+      }
+    }
+
+    // Check target.
+    if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
+      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
+      auto toType = castOp.getType().dyn_cast<MemRefType>();
+
+      if (fromType && toType) {
+        if (fromType.getShape() == toType.getShape() &&
+            fromType.getElementType() == toType.getElementType()) {
+          rewriter.updateRootInPlace(
+              copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
+          modified = true;
+        }
+      }
+    }
+
+    return success(modified);
+  }
+};
+} // namespace
+
+void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<FoldCopyOfCast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
@@ -31,7 +31,7 @@
   // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
   // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
   // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
-  // CHECK: memref.copy %[[A_memref]], %[[casted]]
+  // CHECK: memref.copy %[[A_memref]], %[[alloc]]
   // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
   %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
@@ -95,4 +95,4 @@
     scf.yield %10 : tensor<?xf32>
   }
   return %5: tensor<?xf32>
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
@@ -159,7 +159,7 @@
   // CHECK-TENSOR: %[[alloc:.*]] = memref.alloc
   // CHECK-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]]
   // CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
-  // CHECK-TENSOR: memref.copy %[[t1_memref]], %[[casted]]
+  // CHECK-TENSOR: memref.copy %[[t1_memref]], %[[alloc]]
   // CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]]
   %0 = tensor.insert %f into %t1[%c0] : tensor<?xf32>
   // CHECK-TENSOR: return %[[casted_tensor]]
@@ -177,7 +177,7 @@
   // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
   // CHECK-SCF: %[[alloc:.*]] = memref.alloc
   // CHECK-SCF: %[[casted:.*]] = memref.cast %[[alloc]]
-  // CHECK-SCF: memref.copy %[[t1_memref]], %[[casted]]
+  // CHECK-SCF: memref.copy %[[t1_memref]], %[[alloc]]
   // CHECK-SCF: %[[scf_for:.*]] = scf.for %[[iv:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[arg0:.*]] = %[[casted]]) -> ({{.*}}) {
   %0 = scf.for %iv = %c0 to %sz step %step iter_args(%arg0 = %t1) -> tensor<?xf32> {
   // CHECK-SCF: %[[arg0_tensor:.*]] = bufferization.to_tensor %[[arg0]]
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -510,3 +510,18 @@
 
 // CHECK-LABEL: func @atomicrmw_cast_fold
 // CHECK-NEXT: memref.atomic_rmw addf %arg0, %arg1[%arg2] : (f32, memref<4xf32>) -> f32
+
+// -----
+
+#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+func @copy_of_cast(%m1: memref<?xf32>, %m2: memref<*xf32>) {
+  %casted1 = memref.cast %m1 : memref<?xf32> to memref<?xf32, #map>
+  %casted2 = memref.cast %m2 : memref<*xf32> to memref<?xf32, #map>
+  memref.copy %casted1, %casted2 : memref<?xf32, #map> to memref<?xf32, #map>
+  return
+}
+
+// CHECK-LABEL: func @copy_of_cast(
+//  CHECK-SAME:     %[[m1:.*]]: memref<?xf32>, %[[m2:.*]]: memref<*xf32>
+//       CHECK:   %[[casted2:.*]] = memref.cast %[[m2]]
+//       CHECK:   memref.copy %[[m1]], %[[casted2]]
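
Note: for reference, a sketch of the rewrite the new pattern performs, using IR
shaped like the copy_of_cast test above (illustrative only, not part of the
patch). The cast of %m1 is layout-only, so the copy can read from %m1 directly;
the cast of %m2 is kept because its unranked source fails the
dyn_cast<MemRefType> check in the pattern:

  // Before canonicalization:
  %casted1 = memref.cast %m1 : memref<?xf32> to memref<?xf32, #map>
  %casted2 = memref.cast %m2 : memref<*xf32> to memref<?xf32, #map>
  memref.copy %casted1, %casted2 : memref<?xf32, #map> to memref<?xf32, #map>

  // After canonicalization (%casted1 is now dead and can be DCE'd):
  %casted2 = memref.cast %m2 : memref<*xf32> to memref<?xf32, #map>
  memref.copy %m1, %casted2 : memref<?xf32> to memref<?xf32, #map>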