diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -444,6 +444,43 @@
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// CopyOp
+//===----------------------------------------------------------------------===//
+
+def CopyOp : MemRef_Op<"copy",
+    [CopyOpInterface, SameOperandsElementType, SameOperandsShape]> {
+
+  let description = [{
+    Copies the data from the source to the destination memref.
+
+    Usage:
+
+    ```mlir
+    memref.copy %arg0, %arg1 : memref<?xf32> to memref<?xf32>
+    ```
+
+    Source and destination are expected to have the same element type and shape.
+    Otherwise, the result is undefined. They may have different layouts.
+  }];
+
+  let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "the memref to copy from",
+                           [MemRead]>:$source,
+                       Arg<AnyRankedOrUnrankedMemRef, "the memref to copy to",
+                           [MemWrite]>:$target);
+
+  let extraClassDeclaration = [{
+    Value getSource() { return source(); }
+    Value getTarget() { return target(); }
+  }];
+
+  let assemblyFormat = [{
+    $source `,` $target attr-dict `:` type($source) `to` type($target)
+  }];
+
+  let verifier = ?;  // No custom verifier; the Same* traits do the checking.
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -215,3 +215,19 @@
   %0 = memref.get_global @gv : memref<3xf32>
   return
 }
+
+// -----
+
+func @copy_different_shape(%arg0: memref<2xf32>, %arg1: memref<3xf32>) {
+  // expected-error @+1 {{op requires the same shape for all operands}}
+  memref.copy %arg0, %arg1 : memref<2xf32> to memref<3xf32>
+  return
+}
+
+// -----
+
+func @copy_different_eltype(%arg0: memref<2xf32>, %arg1: memref<2xf16>) {
+  // expected-error @+1 {{op requires the same element type for all operands}}
+  memref.copy %arg0, %arg1 : memref<2xf32> to memref<2xf16>
+  return
+}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -69,6 +69,16 @@
   return
 }
 
+// CHECK-LABEL: func @memref_copy
+func @memref_copy() {
+  %0 = memref.alloc() : memref<2xf32>
+  %1 = memref.cast %0 : memref<2xf32> to memref<*xf32>
+  %2 = memref.alloc() : memref<2xf32>
+  %3 = memref.cast %2 : memref<2xf32> to memref<*xf32>
+  memref.copy %1, %3 : memref<*xf32> to memref<*xf32>
+  return
+}
+
 // CHECK-LABEL: func @memref_dealloc
 func @memref_dealloc() {
   %0 = memref.alloc() : memref<2xf32>