diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -767,6 +767,13 @@
   const ModuleBufferizationState &moduleState =
       getModuleBufferizationState(state);
 
+  // "linalg.inplaceable" overrides other writability decisions. This is
+  // currently used for testing only.
+  if (BoolAttr inplaceAttr = funcOp.getArgAttrOfType<BoolAttr>(
+          bbArg.getArgNumber(),
+          BufferizableOpInterface::kInplaceableAttrName))
+    return inplaceAttr.getValue();
+
   // In a first approximation:
   // =========================
   // If the function is called, we can allocate on the caller side which lets
@@ -775,13 +782,6 @@
   if (moduleState.callerMap.find(funcOp) != moduleState.callerMap.end())
     return true;
 
-  // Set the function arguments marked with inplaceable to be known as
-  // bufferizing to a writeable memory.
-  BoolAttr inplaceAttr = funcOp.getArgAttrOfType<BoolAttr>(
-      bbArg.getArgNumber(), BufferizableOpInterface::kInplaceableAttrName);
-  if (inplaceAttr && inplaceAttr.getValue())
-    return true;
-
   // All other function arguments are not writable.
   return false;
 }
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -12,7 +12,8 @@
 // -----
 
 // CHECK-LABEL: func @extract_slice_fun
-func @extract_slice_fun(%A : tensor<?xf32>, %B : tensor<?xf32> {linalg.inplaceable = true})
+func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = false},
+                        %B : tensor<?xf32> {linalg.inplaceable = true})
   -> (tensor<4xf32>, tensor<8xf32>)
 {
   // tensor.extract_slice is not used in a write, it is not compelled to
@@ -33,10 +34,9 @@
 // -----
 
 // CHECK-LABEL: func @insert_slice_fun
-func @insert_slice_fun(
-    %A : tensor<?xf32>,
-    %B : tensor<?xf32> {linalg.inplaceable = true},
-    %C : tensor<4xf32>)
+func @insert_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = false},
+                       %B : tensor<?xf32> {linalg.inplaceable = true},
+                       %C : tensor<4xf32> {linalg.inplaceable = false})
   -> (tensor<?xf32>, tensor<?xf32>)
 {
   // must bufferize out of place.
@@ -57,9 +57,8 @@
 // -----
 
 // CHECK-LABEL: func @conflict_on_B
-func @conflict_on_B(
-    %A : tensor<4x4xf32> {linalg.inplaceable = true},
-    %B : tensor<4x4xf32> {linalg.inplaceable = true})
+func @conflict_on_B(%A : tensor<4x4xf32> {linalg.inplaceable = true},
+                    %B : tensor<4x4xf32> {linalg.inplaceable = true})
   -> (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>)
 {
   // matmul output operand interferes with input operand.
@@ -96,7 +95,8 @@
 
 // CHECK-LABEL: func @extract_slice_extract_slice
 func @extract_slice_extract_slice(
-    %A : tensor<?xf32> {linalg.inplaceable = true}, %B : tensor<?xf32>)
+    %A : tensor<?xf32> {linalg.inplaceable = true},
+    %B : tensor<?xf32> {linalg.inplaceable = false})
   -> (tensor<2xf32>, tensor<2xf32>)
 {
   // tensor.extract_slice is not used in a write, it is not compelled to
@@ -125,7 +125,9 @@
     %A : tensor<?xf32> {linalg.inplaceable = true},
     %A2 : tensor<4xf32> {linalg.inplaceable = true},
     %A3 : tensor<2xf32> {linalg.inplaceable = true},
-    %B : tensor<?xf32>, %B2 : tensor<4xf32>, %B3 : tensor<2xf32>)
+    %B : tensor<?xf32> {linalg.inplaceable = false},
+    %B2 : tensor<4xf32> {linalg.inplaceable = false},
+    %B3 : tensor<2xf32> {linalg.inplaceable = false})
   -> (tensor<?xf32>, tensor<?xf32>)
 {
   // CHECK: {__inplace_results_attr__ = ["true"]}
@@ -150,7 +152,8 @@
 // CHECK-LABEL: func @extract_slice_nonmatching_insert_slice
 func @extract_slice_nonmatching_insert_slice(
     %A : tensor<?xf32> {linalg.inplaceable = true},
-    %B : tensor<?xf32>, %idx: index)
+    %B : tensor<?xf32> {linalg.inplaceable = false},
+    %idx: index)
   -> (tensor<?xf32>, tensor<?xf32>)
 {
   // %r1 bufferizes inplace because %A is inplaceable.
@@ -188,7 +191,7 @@
 // CHECK-LABEL: func @extract_slice_matching_insert_slice
 func @extract_slice_matching_insert_slice(
     %A : tensor<?xf32> {linalg.inplaceable = true},
-    %B : tensor<?xf32>)
+    %B : tensor<?xf32> {linalg.inplaceable = false})
   -> (tensor<?xf32>, tensor<?xf32>)
 {
   // %r1 bufferizes inplace because %A is inplaceable.
@@ -225,7 +228,9 @@
 
 // CHECK-LABEL: @read_of_matching_insert_slice_source
 func @read_of_matching_insert_slice_source(
-    %A : tensor<?xf32> {linalg.inplaceable = true}, %idx : index, %idx2 : index)
+    %A : tensor<?xf32> {linalg.inplaceable = true},
+    %idx : index,
+    %idx2 : index)
   -> (tensor<?xf32>, vector<5xf32>)
 {
   %cst = arith.constant 0.0 : f32
@@ -254,7 +259,9 @@
 
 // CHECK-LABEL: @read_of_matching_insert_slice_source_interleaved
 func @read_of_matching_insert_slice_source_interleaved(
-    %A : tensor<?xf32> {linalg.inplaceable = true}, %idx : index, %idx2 : index,
+    %A : tensor<?xf32> {linalg.inplaceable = true},
+    %idx : index,
+    %idx2 : index,
     %idx3 : index)
   -> (tensor<?xf32>, vector<5xf32>)
 {
@@ -296,8 +303,8 @@
 
 // CHECK-LABEL: func @extract_slice_linalg_readonly_use
 func @extract_slice_linalg_readonly_use(
-    %A : tensor<?x?xf32>,
-    %B : tensor<4x4xf32>,
+    %A : tensor<?x?xf32> {linalg.inplaceable = false},
+    %B : tensor<4x4xf32> {linalg.inplaceable = false},
     %C : tensor<4x4xf32> {linalg.inplaceable = true})
   -> (tensor<4x4xf32>, tensor<4x4xf32>)
 {
@@ -330,8 +337,8 @@
 
 // CHECK-LABEL: func @extract_slice_to_linalg_write_use
 func @extract_slice_to_linalg_write_use(
-    %A : tensor<4x4xf32>,
-    %B : tensor<?x?xf32>,
+    %A : tensor<4x4xf32> {linalg.inplaceable = false},
+    %B : tensor<?x?xf32> {linalg.inplaceable = false},
     %C : tensor<?x?xf32> {linalg.inplaceable = true})
   -> (tensor<4x4xf32>, tensor<4x4xf32>)
 {
@@ -370,9 +377,15 @@
 
 // CHECK-LABEL: func @insert_slice_double_extract_slice
 func @insert_slice_double_extract_slice(
-    %s1: index, %s2: index, %s3: index, %s4: index, %A: tensor<8x6xf32>,
-    %B: tensor<6x6xf32>, %C: tensor<30x20xf32> {linalg.inplaceable = true})
-  -> tensor<30x20xf32> {
+    %s1: index,
+    %s2: index,
+    %s3: index,
+    %s4: index,
+    %A: tensor<8x6xf32> {linalg.inplaceable = false},
+    %B: tensor<6x6xf32> {linalg.inplaceable = false},
+    %C: tensor<30x20xf32> {linalg.inplaceable = true})
+  -> tensor<30x20xf32>
+{
   // CHECK: tensor.extract_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
   %15 = tensor.extract_slice %C[%s3, %s4] [%s1, %s2] [1, 1] : tensor<30x20xf32> to tensor<?x?xf32>
@@ -402,8 +415,8 @@
 
 // CHECK-LABEL: func @extract_slice_to_linalg_write_use
 func @extract_slice_to_linalg_write_use(
-    %A : tensor<4x4xf32>,
-    %B : tensor<?x?xf32>,
+    %A : tensor<4x4xf32> {linalg.inplaceable = false},
+    %B : tensor<?x?xf32> {linalg.inplaceable = false},
     %C : tensor<?x?xf32> {linalg.inplaceable = true})
   -> (tensor<4x4xf32>, tensor<4x4xf32>)
 {
@@ -444,7 +457,7 @@
 
 // CHECK-LABEL: func @nested_extract_slice_and_insert
 func @nested_extract_slice_and_insert(
-    %A : tensor<?x?xf32>,
+    %A : tensor<?x?xf32> {linalg.inplaceable = false},
     %B : tensor<?x?xf32> {linalg.inplaceable = true},
     %C : tensor<?x?xf32> {linalg.inplaceable = true},
     %idx : index,
@@ -535,9 +548,12 @@
 // -----
 
 // CHECK-LABEL: func @scf_for_yield_only
-func @scf_for_yield_only(%A : tensor<?xf32>,
-                         %B : tensor<?xf32> {linalg.inplaceable = true},
-                         %lb : index, %ub : index, %step : index)
+func @scf_for_yield_only(
+    %A : tensor<?xf32> {linalg.inplaceable = false},
+    %B : tensor<?xf32> {linalg.inplaceable = true},
+    %lb : index,
+    %ub : index,
+    %step : index)
   -> (tensor<?xf32>, tensor<?xf32>)
 {
   // CHECK: scf.for
@@ -562,10 +578,13 @@
 // -----
 
 // CHECK-LABEL: func @scf_for_with_tensor.insert_slice
-func @scf_for_with_tensor.insert_slice(%A : tensor<?xf32>,
-                                       %B : tensor<?xf32> {linalg.inplaceable = true},
-                                       %C : tensor<4xf32>,
-                                       %lb : index, %ub : index, %step : index)
+func @scf_for_with_tensor.insert_slice(
+    %A : tensor<?xf32> {linalg.inplaceable = false},
+    %B : tensor<?xf32> {linalg.inplaceable = true},
+    %C : tensor<4xf32> {linalg.inplaceable = false},
+    %lb : index,
+    %ub : index,
+    %step : index)
   -> (tensor<?xf32>, tensor<?xf32>)
 {
   // CHECK: scf.for
@@ -597,9 +616,12 @@
 func private @some_use(tensor<?xf32>) -> ()
 
 // CHECK-LABEL: func @scf_for_deps
-func @scf_for_deps(%A : tensor<?xf32> {linalg.inplaceable = true},
-                   %B : tensor<?xf32> {linalg.inplaceable = true},
-                   %lb : index, %ub : index, %step : index)
+func @scf_for_deps(
+    %A : tensor<?xf32> {linalg.inplaceable = true},
+    %B : tensor<?xf32> {linalg.inplaceable = true},
+    %lb : index,
+    %ub : index,
+    %step : index)
   -> (tensor<?xf32>, tensor<?xf32>)
 {
   // %r0 must be out of place because one use of %t in the subsequent production
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -6,7 +6,10 @@
 // RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
 
 // CHECK-LABEL: func @transfer_read(%{{.*}}: memref<?xf32, #[[$map_1d_dyn]]>) -> vector<4xf32> {
-func @transfer_read(%A : tensor<?xf32>) -> (vector<4xf32>) {
+func @transfer_read(
+    %A : tensor<?xf32> {linalg.inplaceable = false})
+  -> (vector<4xf32>)
+{
   %c0 = arith.constant 0 : index
   %f0 = arith.constant 0.0 : f32
 
@@ -23,7 +26,10 @@
 
 // CHECK-LABEL: func @fill_inplace(
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>
-func @fill_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}) -> tensor<?xf32> {
+func @fill_inplace(
+    %A : tensor<?xf32> {linalg.inplaceable = true})
+  -> tensor<?xf32>
+{
   // CHECK: %[[F0:.*]] = arith.constant 0.000000e+00 : f32
   %f0 = arith.constant 0.0 : f32
 
@@ -40,7 +46,7 @@
 // -----
 
 // CHECK-LABEL: func @tensor_extract(%{{.*}}: memref<?xf32, #[[$map_1d_dyn]]>) -> f32 {
-func @tensor_extract(%A : tensor<?xf32>) -> (f32) {
+func @tensor_extract(%A : tensor<?xf32> {linalg.inplaceable = false}) -> (f32) {
   %c0 = arith.constant 0 : index
 
   // CHECK: %[[RES:.*]] = memref.load {{.*}} : memref<?xf32, #[[$map_1d_dyn]]>
@@ -57,7 +63,10 @@
 /// No linalg.inplaceable flag, must allocate.
 // CHECK-LABEL: func @not_inplace(
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>) -> memref<?xf32> {
-func @not_inplace(%A : tensor<?xf32>) -> tensor<?xf32> {
+func @not_inplace(
+    %A : tensor<?xf32> {linalg.inplaceable = false})
+  -> tensor<?xf32>
+{
   // CHECK: %[[F0:.*]] = arith.constant 0.000000e+00 : f32
   %f0 = arith.constant 0.0 : f32
 
@@ -77,7 +86,10 @@
 
 // CHECK-LABEL: func @not_inplace
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref) {
-func @not_inplace(%A : tensor<?x?xf32> {linalg.inplaceable = true}) -> tensor<?x?xf32> {
+func @not_inplace(
+    %A : tensor<?x?xf32> {linalg.inplaceable = true})
+  -> tensor<?x?xf32>
+{
   %f0 = arith.constant 0.0 : f32
 
   /// Cross-op multiple uses of %A, the first op which has interfering reads must alloc.
@@ -161,9 +173,9 @@
 // CHECK-SAME: %[[A1:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>,
 // CHECK-SAME: %[[t0:[a-zA-Z0-9]*]]: memref<4xf32, #[[$map_1d_dyn]]>,
 // CHECK-SAME: %[[t1:[a-zA-Z0-9]*]]: memref<4xf32, #[[$map_1d_dyn]]>
-func @insert_slice_fun(%A0 : tensor<?xf32>,
+func @insert_slice_fun(%A0 : tensor<?xf32> {linalg.inplaceable = false},
                        %A1 : tensor<?xf32> {linalg.inplaceable = true},
-                       %t0 : tensor<4xf32>,
+                       %t0 : tensor<4xf32> {linalg.inplaceable = false},
                        %t1 : tensor<4xf32> {linalg.inplaceable = true})
   -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
 {
@@ -208,7 +220,9 @@
 // CHECK-LABEL: func @insert_slice_fun
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>
 // CHECK-SAME: %[[t:[a-zA-Z0-9]*]]: memref<4xf32, #[[$map_1d_dyn]]>
-func @insert_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+func @insert_slice_fun(
+    %A : tensor<?xf32> {linalg.inplaceable = true},
+    %t : tensor<4xf32> {linalg.inplaceable = false})
   -> tensor<?xf32>
 {
   %f0 = arith.constant 0.0 : f32
@@ -234,7 +248,9 @@
 // CHECK-LABEL: func @insert_slice_fun
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>
 // CHECK-SAME: %[[t:[a-zA-Z0-9]*]]: memref<4xf32, #[[$map_1d_dyn]]>
-func @insert_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+func @insert_slice_fun(
+    %A : tensor<?xf32> {linalg.inplaceable = true},
+    %t : tensor<4xf32> {linalg.inplaceable = false})
   -> tensor<?xf32>
 {
   %f0 = arith.constant 0.0 : f32
@@ -260,7 +276,9 @@
 // CHECK-LABEL: func @insert_slice_fun_not_inplace
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>
 // CHECK-SAME: %[[t:[a-zA-Z0-9]*]]: memref<4xf32, #[[$map_1d_dyn]]>
-func @insert_slice_fun_not_inplace(%A : tensor<?xf32>, %t : tensor<4xf32>)
+func @insert_slice_fun_not_inplace(
+    %A : tensor<?xf32> {linalg.inplaceable = false},
+    %t : tensor<4xf32> {linalg.inplaceable = false})
   -> tensor<?xf32>
 {
   // CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) {alignment = 128 : i64} : memref<?xf32>
@@ -285,7 +303,7 @@
 // CHECK-LABEL: func @scf_for_yield_only
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>
 // CHECK-SAME: %[[t:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>
-func @scf_for_yield_only(%A : tensor<?xf32>,
+func @scf_for_yield_only(%A : tensor<?xf32> {linalg.inplaceable = false},
                          %B : tensor<?xf32> {linalg.inplaceable = true},
                          %lb : index, %ub : index, %step : index)
   -> (tensor<?xf32>, tensor<?xf32>)
@@ -340,9 +358,9 @@
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>
 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<4xf32, #[[$map_1d_dyn]]>
 func @scf_for_with_tensor.insert_slice(
-    %A : tensor<?xf32>,
+    %A : tensor<?xf32> {linalg.inplaceable = false},
     %B : tensor<?xf32> {linalg.inplaceable = true},
-    %C : tensor<4xf32>,
+    %C : tensor<4xf32> {linalg.inplaceable = false},
     %lb : index, %ub : index, %step : index)
   -> (tensor<?xf32>, tensor<?xf32>)
 {
@@ -567,8 +585,13 @@
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>
 // CHECK-SAME: %[[c:[a-zA-Z0-9]*]]: memref
-func @tiled_dot(%A: tensor<?xf32>, %B: tensor<?xf32>, %c: tensor<f32> {linalg.inplaceable = true},
-                %effecting: memref<?xf32>) -> tensor<f32> {
+func @tiled_dot(
+    %A: tensor<?xf32> {linalg.inplaceable = false},
+    %B: tensor<?xf32> {linalg.inplaceable = false},
+    %c: tensor<f32> {linalg.inplaceable = true},
+    %effecting: memref<?xf32>)
+  -> tensor<f32>
+{
   %c3 = arith.constant 3 : index
   %c0 = arith.constant 0 : index
 
@@ -719,9 +742,9 @@
 // CHECK-SAME: %[[A:[0-9a-zA-Z]*]]: memref<?xf32>
 // CHECK-SAME: %[[B:[0-9a-zA-Z]*]]: memref<?xf32>
 // CHECK-SAME: %[[C:[0-9a-zA-Z]*]]: memref<?xf32, #[[$map_1d_dyn]]>
-func @entry(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>},
-            %B : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>},
-            %C : tensor<?xf32>) {
+func @entry(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>, linalg.inplaceable = false},
+            %B : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>, linalg.inplaceable = false},
+            %C : tensor<?xf32> {linalg.inplaceable = false}) {
 // CHECK-NEXT: %[[CASTED_B:.*]] = memref.cast %[[B]] : memref<?xf32> to memref<?xf32, #[[$map_1d_dyn]]>
 // CHECK-NEXT: call @callee(%[[A]], %[[CASTED_B]], %[[C]])
   call @callee(%A, %B, %C) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> ()
@@ -809,7 +832,7 @@
 // CHECK: %[[cast:.*]] = memref.cast %[[alloc]]
 func @tensor_cast_not_in_place(
     %A : tensor<?xf32> {linalg.inplaceable = true},
-    %B : tensor<?xf32>, %idx: index)
+    %B : tensor<?xf32> {linalg.inplaceable = false}, %idx: index)
   -> (tensor<?xf32>)
 {
   %r0 = tensor.cast %A : tensor<?xf32> to tensor<4xf32>
@@ -827,7 +850,11 @@
 /// errors in the def-use chains.
 
 // CHECK-LABEL: func @dominance_violation_bug_1
-func @dominance_violation_bug_1(%A : tensor<?x?xf32>, %idx : index) -> tensor<?x?xf32> {
+func @dominance_violation_bug_1(
+    %A : tensor<?x?xf32> {linalg.inplaceable = false},
+    %idx : index)
+  -> tensor<?x?xf32>
+{
   %f0 = arith.constant 0.0 : f32
 
   %sA = tensor.extract_slice %A[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
@@ -958,7 +985,11 @@
 
 // CHECK-LABEL: func @scf_if_non_equiv_yields(
 // CHECK-SAME: %[[cond:.*]]: i1, %[[A:.*]]: memref<{{.*}}>, %[[B:.*]]: memref<{{.*}}>) -> memref<{{.*}}>
-func @scf_if_non_equiv_yields(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
+func @scf_if_non_equiv_yields(
+    %b : i1,
+    %A : tensor<4xf32> {linalg.inplaceable = false},
+    %B : tensor<4xf32> {linalg.inplaceable = false})
+  -> tensor<4xf32>
 {
   // CHECK: %[[r:.*]] = select %[[cond]], %[[A]], %[[B]]
   %r = scf.if %b -> (tensor<4xf32>) {
@@ -1092,7 +1123,9 @@
 
 // -----
 
-func @gather_like(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xi32>,
+func @gather_like(
+    %arg0 : tensor<?x?xf32> {linalg.inplaceable = false},
+    %arg1 : tensor<?xi32> {linalg.inplaceable = false},
     %arg2 : tensor<?x?xf32> {linalg.inplaceable = true}) -> tensor<?x?xf32> {
   %0 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0)>,
@@ -1101,9 +1134,9 @@
     ins(%arg1 : tensor<?xi32>) outs(%arg2 : tensor<?x?xf32>) {
     ^bb0(%arg3: i32, %arg4 : f32):
       %iv1 = linalg.index 1 : index
-        %1 = arith.index_cast %arg3: i32 to index
-        %2 = tensor.extract %arg0[%1, %iv1] : tensor<?x?xf32>
-        linalg.yield %2 : f32
+      %1 = arith.index_cast %arg3: i32 to index
+      %2 = tensor.extract %arg0[%1, %iv1] : tensor<?x?xf32>
+      linalg.yield %2 : f32
   } -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
@@ -1124,7 +1157,8 @@
 // CHECK-SAME: %[[t1:.*]]: memref, %[[t2:.*]]: memref, %[[t3:.*]]: memref
 func @linalg_op_bufferizes_inplace_with_input(
     %t1: tensor<?x?xf32> {linalg.inplaceable = true},
-    %t2: tensor<?xf32>, %t3: tensor<?x?xf32>,
+    %t2: tensor<?xf32> {linalg.inplaceable = false},
+    %t3: tensor<?x?xf32> {linalg.inplaceable = false},
     %s1: index, %s2: index, %cst: f32) -> tensor<?x?xf32> {
   // CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[t1]] : {{.*}})
   %r = linalg.generic {
@@ -1146,7 +1180,9 @@
 // CHECK-LABEL: func @linalg_op_bufferizes_out_of_place_with_input
 // CHECK-SAME: %[[t1:.*]]: memref, %[[t2:.*]]: memref, %[[t3:.*]]: memref
 func @linalg_op_bufferizes_out_of_place_with_input(
-    %t1: tensor<?x?xf32>, %t2: tensor<?xf32>, %t3: tensor<?x?xf32>,
+    %t1: tensor<?x?xf32> {linalg.inplaceable = false},
+    %t2: tensor<?xf32> {linalg.inplaceable = false},
+    %t3: tensor<?x?xf32> {linalg.inplaceable = false},
     %s1: index, %s2: index, %cst: f32) -> tensor<?x?xf32> {
   // CHECK: %[[alloc:.*]] = memref.alloc
   // CHECK: linalg.copy(%[[t1]], %[[alloc]])
@@ -1172,7 +1208,8 @@
 // CHECK-SAME: %[[t1:.*]]: memref, %[[t2:.*]]: memref, %[[t3:.*]]: memref
 func @linalg_op_output_cannot_alias_with_input(
     %t1: tensor<?x?xf32> {linalg.inplaceable = true},
-    %t2: tensor<?xf32>, %t3: tensor<?x?xf32> {linalg.inplaceable = true},
+    %t2: tensor<?xf32> {linalg.inplaceable = false},
+    %t3: tensor<?x?xf32> {linalg.inplaceable = true},
     %s1: index, %s2: index, %cst: f32) -> tensor<?x?xf32> {
   // CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[t3]] : {{.*}})
   %r = linalg.generic {
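Note on the attribute that this change makes explicit in the tests: the sketch below is illustrative only and is not part of the patch; the function and value names are made up, and the op syntax follows the style of the test files above. It shows how a test function now spells out writability for every tensor bbArg instead of leaving arguments unannotated.

func @writability_example(%W : tensor<4xf32> {linalg.inplaceable = true},
                          %R : tensor<4xf32> {linalg.inplaceable = false})
  -> (tensor<4xf32>, tensor<4xf32>)
{
  %f0 = arith.constant 0.0 : f32
  // %W is declared writable, so this write is expected to bufferize in place.
  %w = linalg.fill(%f0, %W) : f32, tensor<4xf32> -> tensor<4xf32>
  // %R is explicitly non-writable, so this write is expected to bufferize
  // out of place (a new buffer is allocated).
  %r = linalg.fill(%f0, %R) : f32, tensor<4xf32> -> tensor<4xf32>
  return %w, %r : tensor<4xf32>, tensor<4xf32>
}

Run through the same pass as in the RUN lines above (e.g. with test-analysis-only), the first fill should be tagged __inplace_results_attr__ = ["true"] and the second one "false"; with the C++ change at the top of this patch, the explicit linalg.inplaceable attribute takes precedence over the caller-based writability heuristic.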