diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -76,6 +76,9 @@
     %c = bufferization.alloc_tensor(%d1, %d2) size_hint = %noe
       : tensor<?x?xf32, #SparseMatrix>
     ```
+
+    Note: An `alloc_tensor` with a `copy` should also be expressed as an
+    `alloc_tensor` without `copy`, followed by a `copy_tensor`.
   }];
 
   let arguments = (ins Variadic<Index>:$dynamic_sizes,
@@ -202,6 +205,51 @@
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// CopyTensorOp
+//===----------------------------------------------------------------------===//
+
+def Bufferization_CopyTensorOp : Bufferization_Op<"copy_tensor",
+    [BufferizableOpInterface, SameOperandsAndResultType,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+  let summary = "copy a tensor";
+
+  let description = [{
+    Copy the contents of the source tensor into the destination tensor. This
+    operation is guaranteed to bufferize to a memory copy.
+
+    The result is the updated destination tensor. The source, the destination
+    and the result must all have the same type.
+  }];
+
+  let arguments = (ins AnyTensor:$source,
+                       AnyTensor:$dest);
+
+  let results = (outs AnyTensor:$result);
+
+  let extraClassDeclaration = [{
+    LogicalResult bufferize(RewriterBase &rewriter,
+                            const BufferizationOptions &options);
+
+    bool bufferizesToMemoryRead(OpOperand &opOperand,
+                                const AnalysisState &state);
+
+    bool bufferizesToMemoryWrite(OpOperand &opOperand,
+                                 const AnalysisState &state);
+
+    AliasingOpResultList getAliasingOpResults(
+        OpOperand &opOperand, const AnalysisState &state);
+
+    // NOTE(review): `AnyTensor` also admits unranked tensors, for which this
+    // cast asserts; confirm whether the operands should be `AnyRankedTensor`.
+    RankedTensorType getType() {
+      return ::llvm::cast<RankedTensorType>(getResult().getType());
+    }
+  }];
+
+  let assemblyFormat = "$source `,` $dest attr-dict `:` type($source)";
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocTensorOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -444,6 +444,60 @@
   return getOperand(getIndexOfDynamicSize(idx));
 }
 
+//===----------------------------------------------------------------------===//
+// CopyTensorOp
+//===----------------------------------------------------------------------===//
+
+bool CopyTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
+                                          const AnalysisState &state) {
+  // Only the source is read; the destination is overwritten entirely.
+  if (&opOperand == &getOperation()->getOpOperand(0) /*source*/)
+    return true;
+  return false;
+}
+
+bool CopyTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
+                                           const AnalysisState &state) {
+  // Only the destination buffer is written to.
+  if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
+    return true;
+  return false;
+}
+
+AliasingOpResultList
+CopyTensorOp::getAliasingOpResults(OpOperand &opOperand,
+                                   const AnalysisState &state) {
+  // The result is the destination buffer after the copy; the source does not
+  // alias the result.
+  if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
+    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
+  return {};
+}
+
+LogicalResult CopyTensorOp::bufferize(RewriterBase &rewriter,
+                                      const BufferizationOptions &options) {
+  // memref.copy operates on buffers, so both the source and the destination
+  // tensors must be converted to buffers first.
+  FailureOr<Value> sourceBuffer = getBuffer(rewriter, getSource(), options);
+  if (failed(sourceBuffer))
+    return failure();
+  FailureOr<Value> destBuffer = getBuffer(rewriter, getDest(), options);
+  if (failed(destBuffer))
+    return failure();
+  rewriter.create<memref::CopyOp>(getLoc(), *sourceBuffer, *destBuffer);
+  // The result is the (now overwritten) destination buffer.
+  replaceOpWithBufferizedValues(rewriter, getOperation(), *destBuffer);
+  return success();
+}
+
+LogicalResult CopyTensorOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  // The result is the updated `dest`, so its sizes are those of `dest`.
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
+  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CloneOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -208,3 +208,19 @@
   %0 = tensor.cast %arg0 : tensor<*xi32> to tensor<*xi32>
   return %0 : tensor<*xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor_copy(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<5xf32>)
+func.func @tensor_copy(%arg0: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[m:.*]] = bufferization.to_memref %[[arg0]]
+  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+  // CHECK: memref.copy %[[m]], %[[alloc]]
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
+  // CHECK: memref.dealloc %[[alloc]]
+  // CHECK: return %[[r]]
+  %dest = bufferization.alloc_tensor() : tensor<5xf32>
+  %0 = bufferization.copy_tensor %arg0, %dest : tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -95,3 +95,11 @@
   // expected-error @+1{{attribute '"bufferization.writable"' not supported as an op attribute by the bufferization dialect}}
   arith.constant {bufferization.writable = true} 0 : index
 }
+
+// -----
+
+// expected-note @below{{prior use here}}
+func.func @invalid_tensor_copy(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
+  // expected-error @below{{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<5xf32>'}}
+  bufferization.copy_tensor %arg0, %arg1 : tensor<?xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -57,3 +57,11 @@
   bufferization.dealloc_tensor %arg0 : tensor<4xi32>
   return
 }
+
+// CHECK-LABEL: func @test_copy_tensor_op
+func.func @test_copy_tensor_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>)
+    -> tensor<?xf32> {
+  // CHECK: bufferization.copy_tensor {{.*}} : tensor<?xf32>
+  %1 = bufferization.copy_tensor %arg0, %arg1 : tensor<?xf32>
+  return %1 : tensor<?xf32>
+}