diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -2467,33 +2467,18 @@ assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() && "expected a funcOp definition with a body"); - // Collect ops so we can build our own traversal. - SmallVector otherOps; - SmallVector insertSliceOps; + // Collect ops so we can build our own reverse traversal. + SmallVector ops; funcOp.walk([&](Operation *op) { - if (auto insertSliceOp = dyn_cast(op)) - return insertSliceOps.push_back(insertSliceOp); // No tensors => no buffers. if (none_of(op->getOperandTypes(), isaTensor) && none_of(op->getResultTypes(), isaTensor)) return; - otherOps.push_back(op); + ops.push_back(op); }); - // First, analyze InsertSliceOp greedily: we almost never want to bufferize - // the tensor "inserted into" to become out-of-place. This implementation - // does not distinguish between different InsertSliceOp. If we want - // finer-grained behavior, we could order the InsertSliceOp with some metric. - for (InsertSliceOp insertSliceOp : reverse(insertSliceOps)) { - OpOperand &destOpOperand = insertSliceOp->getOpOperand(1); - if (failed(bufferizableInPlaceAnalysis( - destOpOperand, getInplaceableOpResult(destOpOperand), aliasInfo, - domInfo))) - return failure(); - } - // Walk ops in reverse for better interference analysis. - for (Operation *op : reverse(otherOps)) { + for (Operation *op : reverse(ops)) { for (OpOperand &opOperand : op->getOpOperands()) { if (OpResult result = getInplaceableOpResult(opOperand)) if (result.getType().isa() && diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -705,6 +705,53 @@ // ----- +//===----------------------------------------------------------------------===// +// Chain of tensor.insert_slice is better traversed in reverse order without +// prioritizing the tensor.insert_slice ops. +//===----------------------------------------------------------------------===// + +func @insert_slice_chain( + %v1: vector<32x90xf32>, + %v2: vector<30x90xf32>, + %arg0: tensor<62x126xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<126x90xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<62x90xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + -> tensor<62x90xf32> attributes {passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} +{ + %c0 = constant 0 : index + %cst = constant 0.000000e+00 : f32 + + // CHECK: linalg.fill + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %0 = linalg.fill(%cst, %arg2) : f32, tensor<62x90xf32> -> tensor<62x90xf32> + + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["false"] + // TODO: in order to have this extract_slice bufferize inplace, we need to write a range + // analysis and determine that intersection([0, 32)x[0, 90), [32, 62)x[0, 90)) is empty. + %2 = tensor.extract_slice %0[0, 0] [32, 90] [1, 1] : tensor<62x90xf32> to tensor<32x90xf32> + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %7 = vector.transfer_write %v1, %2[%c0, %c0] {in_bounds = [true, true]} : vector<32x90xf32>, tensor<32x90xf32> + // CHECK: tensor.insert_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %8 = tensor.insert_slice %7 into %0[0, 0] [32, 90] [1, 1] : tensor<32x90xf32> into tensor<62x90xf32> + + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %10 = tensor.extract_slice %8[32, 0] [30, 90] [1, 1] : tensor<62x90xf32> to tensor<30x90xf32> + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %14 = vector.transfer_write %v2, %10[%c0, %c0] {in_bounds = [true, true]} : vector<30x90xf32>, tensor<30x90xf32> + // CHECK: tensor.insert_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %15 = tensor.insert_slice %14 into %8[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32> + + return %15 : tensor<62x90xf32> +} + +// ----- + //===----------------------------------------------------------------------===// // Insert point issue cases. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -269,34 +269,6 @@ return %r0: tensor } -// ----- - -// CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> - -// CHECK-LABEL: func @insert_slice_fun_not_inplace -// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref -// CHECK-SAME: %[[t:[a-zA-Z0-9]*]]: memref<4xf32, #[[$map_1d_dyn]]> -func @insert_slice_fun_not_inplace(%A : tensor {linalg.inplaceable = true}, %t : tensor<4xf32>) - -> (tensor, tensor) -{ - %f0 = constant 0.0 : f32 - - // tensor.insert_slice is bufferized first, %A is inplaceable so we can make this inplace - // CHECK-DAG: %[[SV_A:.*]] = memref.subview %[[A]][0] [4] [1] : memref to memref<4xf32, {{.*}}> - // CHECK-DAG: linalg.copy(%[[t]], %[[SV_A]]) : memref<4xf32, {{.*}}>, memref<4xf32, {{.*}}> - %r0 = tensor.insert_slice %t into %A[0][4][1] : tensor<4xf32> into tensor - - // fill would interfere with %r0 that is also being returned. - // So we need to bufferize it out of place and make a new alloc. - // CHECK-DAG: %[[ALLOC:.*]] = memref.alloc({{.*}}) {alignment = 128 : i64} : memref - // CHECK: linalg.fill(%{{.*}}, %[[ALLOC]] - %r1 = linalg.fill(%f0, %A) : f32, tensor -> tensor - - // CHECK: memref.dealloc %[[ALLOC]] : memref - // CHECK: return %[[ALLOC]] : memref - return %r1, %r0: tensor, tensor -} - //===----------------------------------------------------------------------===// // Simple loop cases //===----------------------------------------------------------------------===//