diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -239,6 +239,22 @@
   return isWritten;
 }
 
+/// Annotate `bbArg` of `funcOp` with a "bufferization.access" string attribute
+/// derived from `isRead` and `isWritten`: "read-write", "read", "write" or
+/// "none".
+static void annotateFuncArgAccess(FuncOp funcOp, BlockArgument bbArg,
+                                  bool isRead, bool isWritten) {
+  // Map the two flags to the attribute value; "none" if neither is set.
+  const char *accessStr = "none";
+  if (isRead)
+    accessStr = isWritten ? "read-write" : "read";
+  else if (isWritten)
+    accessStr = "write";
+  OpBuilder builder(funcOp.getContext());
+  funcOp.setArgAttr(bbArg.getArgNumber(), "bufferization.access",
+                    builder.getStringAttr(accessStr));
+}
+
 /// Determine which FuncOp bbArgs are read and which are written. If this
 /// PostAnalysisStepFn is run on a function with unknown ops, it will
 /// conservatively assume that such ops bufferize to a read + write.
@@ -263,9 +279,13 @@ for (BlockArgument bbArg : funcOp.getArguments()) { if (!bbArg.getType().isa()) continue; - if (state.isValueRead(bbArg)) + bool isRead = state.isValueRead(bbArg); + bool isWritten = isValueWritten(bbArg, state, aliasInfo); + if (state.getOptions().testAnalysisOnly) + annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten); + if (isRead) moduleState.readBbArgs.insert(bbArg); - if (isValueWritten(bbArg, state, aliasInfo)) + if (isWritten) moduleState.writtenBbArgs.insert(bbArg); } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -11,9 +11,11 @@ // ----- -// CHECK-LABEL: func @extract_slice_fun +// CHECK-LABEL: func @extract_slice_fun( func @extract_slice_fun(%A : tensor {linalg.inplaceable = false}, +// CHECK-SAME: bufferization.access = "read" %B : tensor {linalg.inplaceable = true}) +// CHECK-SAME: bufferization.access = "read" -> (tensor<4xf32>, tensor<8xf32>) { // tensor.extract_slice is not used in a write, it is not compelled to @@ -33,10 +35,13 @@ // ----- -// CHECK-LABEL: func @insert_slice_fun +// CHECK-LABEL: func @insert_slice_fun( func @insert_slice_fun(%A : tensor {linalg.inplaceable = false}, +// CHECK-SAME: bufferization.access = "read" %B : tensor {linalg.inplaceable = true}, +// CHECK-SAME: bufferization.access = "read-write" %C : tensor<4xf32> {linalg.inplaceable = false}) +// CHECK-SAME: bufferization.access = "read" -> (tensor, tensor) { // must bufferize out of place. 
@@ -56,9 +61,11 @@ // ----- -// CHECK-LABEL: func @conflict_on_B +// CHECK-LABEL: func @conflict_on_B( func @conflict_on_B(%A : tensor<4x4xf32> {linalg.inplaceable = true}, +// CHECK-SAME: bufferization.access = "read" %B : tensor<4x4xf32> {linalg.inplaceable = true}) +// CHECK-SAME: bufferization.access = "read-write" -> (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) { // matmul output operand interferes with input operand. @@ -93,10 +100,12 @@ // ----- -// CHECK-LABEL: func @extract_slice_extract_slice +// CHECK-LABEL: func @extract_slice_extract_slice( func @extract_slice_extract_slice( %A : tensor {linalg.inplaceable = true}, +// CHECK-SAME: bufferization.access = "read" %B : tensor {linalg.inplaceable = false}) +// CHECK-SAME: bufferization.access = "read" -> (tensor<2xf32>, tensor<2xf32>) { // tensor.extract_slice is not used in a write, it is not compelled to @@ -120,14 +129,20 @@ // ----- -// CHECK-LABEL: func @insert_slice_insert_slice +// CHECK-LABEL: func @insert_slice_insert_slice( func @insert_slice_insert_slice( %A : tensor {linalg.inplaceable = true}, +// CHECK-SAME: bufferization.access = "read-write" %A2 : tensor<4xf32> {linalg.inplaceable = true}, +// CHECK-SAME: bufferization.access = "read-write" %A3 : tensor<2xf32> {linalg.inplaceable = true}, +// CHECK-SAME: bufferization.access = "read" %B : tensor {linalg.inplaceable = false}, +// CHECK-SAME: bufferization.access = "read" %B2 : tensor<4xf32> {linalg.inplaceable = false}, +// CHECK-SAME: bufferization.access = "read" %B3 : tensor<2xf32> {linalg.inplaceable = false}) +// CHECK-SAME: bufferization.access = "read" -> (tensor, tensor) { // CHECK: {__inplace_operands_attr__ = ["true", "true"]} @@ -888,12 +903,16 @@ // prioritizing the tensor.insert_slice ops. 
//===----------------------------------------------------------------------===// +// CHECK-LABEL: func @insert_slice_chain( func @insert_slice_chain( %v1: vector<32x90xf32>, %v2: vector<30x90xf32>, %arg0: tensor<62x126xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, +// CHECK-SAME: bufferization.access = "none" %arg1: tensor<126x90xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, +// CHECK-SAME: bufferization.access = "none" %arg2: tensor<62x90xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) +// CHECK-SAME: bufferization.access = "write" -> tensor<62x90xf32> attributes {passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} { %c0 = arith.constant 0 : index @@ -968,10 +987,13 @@ iterator_types = ["parallel"] } -// CHECK-LABEL: func @linalg_op_same_out_tensors +// CHECK-LABEL: func @linalg_op_same_out_tensors( func @linalg_op_same_out_tensors( %t1: tensor {linalg.inplaceable = true}, - %t2: tensor {linalg.inplaceable = true}) -> (tensor, tensor){ +// CHECK-SAME: bufferization.access = "read-write" + %t2: tensor {linalg.inplaceable = true}) +// CHECK-SAME: bufferization.access = "write" + -> (tensor, tensor){ // CHECK: linalg.generic // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"] @@ -999,10 +1021,12 @@ iterator_types = ["parallel"] } -// CHECK-LABEL: func @linalg_op_same_out_tensors_2 +// CHECK-LABEL: func @linalg_op_same_out_tensors_2( func @linalg_op_same_out_tensors_2( %t1: tensor {linalg.inplaceable = true}, +// CHECK-SAME: bufferization.access = "read-write" %t2: tensor {linalg.inplaceable = true}) +// CHECK-SAME: bufferization.access = "write" -> (tensor, tensor, tensor){ // CHECK: linalg.generic