diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
@@ -58,10 +58,12 @@
       DefaultValuedAttr<BoolAttr, "false">:$bufferize_function_boundaries,
       DefaultValuedAttr<BoolAttr, "true">:$create_deallocs,
       DefaultValuedAttr<BoolAttr, "false">:$test_analysis_only,
-      DefaultValuedAttr<BoolAttr, "false">:$print_conflicts);
+      DefaultValuedAttr<BoolAttr, "false">:$print_conflicts,
+      DefaultValuedAttr<StrAttr, "\"memref.copy\"">:$memcpy_op);
 
   let results = (outs TransformHandleTypeInterface:$transformed);
 
+  let hasVerifier = 1;
   let assemblyFormat = [{
     (`layout` `{` $function_boundary_type_conversion^ `}`)?
         $target attr-dict `:` functional-type($target, results)
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
@@ -25,6 +26,12 @@
 // OneShotBufferizeOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult transform::OneShotBufferizeOp::verify() {
+  if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
+    return emitOpError() << "unsupported memcpy op: '" << getMemcpyOp() << "'";
+  return success();
+}
+
 DiagnosedSilenceableFailure
 transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
                                      TransformResults &transformResults,
@@ -39,6 +46,20 @@
   if (getFunctionBoundaryTypeConversion().has_value())
     options.setFunctionBoundaryTypeConversion(
         *getFunctionBoundaryTypeConversion());
+  // Op used for buffer copies; verify() guarantees one of these two values.
+  if (getMemcpyOp() == "memref.copy") {
+    options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
+      b.create<memref::CopyOp>(loc, from, to);
+      return success();
+    };
+  } else if (getMemcpyOp() == "linalg.copy") {
+    options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
+      b.create<linalg::CopyOp>(loc, from, to);
+      return success();
+    };
+  } else {
+    llvm_unreachable("invalid copy op");
+  }
 
   auto payloadOps = state.getPayloadOps(getTarget());
   for (Operation *target : payloadOps) {
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
--- a/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
@@ -11,6 +11,7 @@
   MLIRIR
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
+  MLIRLinalgDialect
   MLIRParser
   MLIRPDLDialect
   MLIRSideEffectInterfaces
diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
@@ -28,6 +28,35 @@
 
 // -----
 
+// Emit linalg.copy instead of memref.copy.
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.bufferization.one_shot_bufferize %0 {memcpy_op = "linalg.copy"} : (!transform.any_op) -> !transform.any_op
+}
+
+// CHECK-LABEL: func @test_function(
+//  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>
+func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+
+  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+  // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+  // CHECK-NOT: memref.copy
+  // CHECK: linalg.copy ins(%[[A_memref]] : memref<{{.*}}>) outs(%[[alloc]]
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+  // CHECK: memref.dealloc %[[alloc]]
+  // CHECK: return %[[res_tensor]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
 // Test analysis of One-Shot Bufferize only.
 
 transform.sequence failures(propagate) {
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -11477,6 +11477,7 @@
         ":BufferizationTransformOpsIncGen",
         ":BufferizationTransforms",
         ":IR",
+        ":LinalgDialect",
         ":MemRefDialect",
         ":Parser",
         ":SideEffectInterfaces",