diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -406,6 +406,7 @@
     $source `,` $target attr-dict `:` type($source) `to` type($target)
   }];
 
+  let hasCanonicalizer = 1;
   let verifier = ?;
 }
 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -438,6 +438,61 @@
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
+//===----------------------------------------------------------------------===//
+// CopyOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Bypass a CastOp feeding the source/target of a CopyOp when the cast keeps
+/// the shape and element type, i.e. it only changes the layout. Casts from or
+/// to unranked memrefs are left alone. Fails if neither operand was updated.
+struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CopyOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    bool modified = false;
+
+    // Source: compare the cast's operand type against its *result* type.
+    if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
+      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
+      auto toType = castOp.getType().dyn_cast<MemRefType>();
+
+      if (fromType && toType) {
+        if (fromType.getShape() == toType.getShape() &&
+            fromType.getElementType() == toType.getElementType()) {
+          rewriter.updateRootInPlace(
+              copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
+          modified = true;
+        }
+      }
+    }
+
+    // Target: same check, applied to the cast defining the copy destination.
+    if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
+      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
+      auto toType = castOp.getType().dyn_cast<MemRefType>();
+
+      if (fromType && toType) {
+        if (fromType.getShape() == toType.getShape() &&
+            fromType.getElementType() == toType.getElementType()) {
+          rewriter.updateRootInPlace(
+              copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
+          modified = true;
+        }
+      }
+    }
+
+    return success(modified);
+  }
+};
+} // namespace
+
+void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<FoldCopyOfCast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -510,3 +510,18 @@
 
 // CHECK-LABEL: func @atomicrmw_cast_fold
 // CHECK-NEXT: memref.atomic_rmw addf %arg0, %arg1[%arg2] : (f32, memref<4xf32>) -> f32
+
+// -----
+
+#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+func @copy_of_cast(%m1: memref<?xf32>, %m2: memref<*xf32>) {
+  %casted1 = memref.cast %m1 : memref<?xf32> to memref<?xf32, #map>
+  %casted2 = memref.cast %m2 : memref<*xf32> to memref<?xf32, #map>
+  memref.copy %casted1, %casted2 : memref<?xf32, #map> to memref<?xf32, #map>
+  return
+}
+
+// CHECK-LABEL: func @copy_of_cast(
+//  CHECK-SAME:     %[[m1:.*]]: memref<?xf32>, %[[m2:.*]]: memref<*xf32>
+//       CHECK:   %[[casted2:.*]] = memref.cast %[[m2]]
+//       CHECK:   memref.copy %[[m1]], %[[casted2]]