diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1625,6 +1625,7 @@
   Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
                                           resultTypes, newOperands);
   newCallOp->setAttrs(callOp->getAttrs());
+  callOp->erase();
   return success();
 }
 
@@ -2316,34 +2317,44 @@
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
   OpBuilder b(funcOp->getContext());
-  /// Start by bufferizing `funcOp` arguments.
+
+  // Start by bufferizing `funcOp` arguments.
   if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
     return failure();
 
-  // Walk in PreOrder to ensure ops with regions are handled before their body.
-  // Since walk has to be PreOrder, we need to erase ops that require it
-  // separately: this is the case for CallOp
-  SmallVector<Operation *> toErase;
-  if (funcOp
-          .walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
-            if (failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
-                                   &globalCreator)))
-              return failure();
-            // Register post-walk erasure, if necessary.
-            if (isa<CallOpInterface>(op))
-              if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
-                  llvm::any_of(op->getResultTypes(), isaTensor))
-                toErase.push_back(op);
-            return success();
-          })
-          .wasInterrupted())
+  // Bufferize the function body. `bufferizedOps` keeps track of ops that were
+  // already bufferized with pre-order traversal.
+  DenseSet<Operation *> bufferizedOps;
+  auto walkFunc = [&](Operation *op) -> WalkResult {
+    // Collect ops that need to be bufferized before `op`.
+    SmallVector<Operation *> preorderBufferize;
+    Operation *parentOp = op->getParentOp();
+    // scf::ForOp and TiledLoopOp must be bufferized before their blocks
+    // ("pre-order") because BBargs must be mapped when bufferizing children.
+    while (isa_and_nonnull<scf::ForOp, TiledLoopOp>(parentOp)) {
+      if (bufferizedOps.contains(parentOp))
+        break;
+      bufferizedOps.insert(parentOp);
+      preorderBufferize.push_back(parentOp);
+      parentOp = parentOp->getParentOp();
+    }
+
+    for (Operation *op : llvm::reverse(preorderBufferize))
+      if (failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
+                             &globalCreator)))
+        return failure();
+
+    if (!bufferizedOps.contains(op) &&
+        failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
+                           &globalCreator)))
+      return failure();
+    return success();
+  };
+  if (funcOp.walk(walkFunc).wasInterrupted())
     return failure();
 
   LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
 
-  for (Operation *op : toErase)
-    op->erase();
-
   return success();
 }
 
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -113,8 +113,8 @@
 func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
 {
-  // expected-error @+1 {{unsupported op with tensors}}
   %r = scf.if %b -> (tensor<4xf32>) {
+    // expected-error @+1 {{expected scf::ForOp parent for scf::YieldOp}}
     scf.yield %A : tensor<4xf32>
   } else {
     scf.yield %B : tensor<4xf32>
@@ -144,7 +144,7 @@
 // -----
 
 func @main() -> tensor<4xi32> {
-  // expected-error @+1 {{unsupported op with tensors}}
+  // expected-error @+1 {{expected result-less scf.execute_region containing op}}
   %r = scf.execute_region -> tensor<4xi32> {
     %A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
     scf.yield %A: tensor<4xi32>
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -305,6 +305,28 @@
 
 // -----
 
+// Ensure that the function bufferizes without error. This tests pre-order
+// traversal of scf.for loops during bufferization. No need to check the IR,
+// just want to make sure that it does not crash.
+
+// CHECK-LABEL: func @nested_scf_for
+func @nested_scf_for(%A : tensor<?xf32> {linalg.inplaceable = true},
+                     %v : vector<5xf32>) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %r1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%B = %A) -> tensor<?xf32> {
+    %r2 = scf.for %j = %c0 to %c10 step %c1 iter_args(%C = %B) -> tensor<?xf32> {
+      %w = vector.transfer_write %v, %C[%c0] : vector<5xf32>, tensor<?xf32>
+      scf.yield %w : tensor<?xf32>
+    }
+    scf.yield %r2 : tensor<?xf32>
+  }
+  return %r1 : tensor<?xf32>
+}
+
+// -----
+
 // CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 
 // CHECK-LABEL: func @scf_for_with_tensor.insert_slice
@@ -767,7 +789,7 @@
 // CHECK-LABEL: func @dominance_violation_bug_1
 func @dominance_violation_bug_1(%A : tensor<?x?xf32>, %idx : index) -> tensor<?x?xf32> {
   %f0 = arith.constant 0.0 : f32
-  
+
   %sA = tensor.extract_slice %A[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
   %ssA = tensor.extract_slice %sA[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
   %FA = linalg.fill(%f0, %ssA) : f32, tensor<4x4xf32> -> tensor<4x4xf32>
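
The core of the change above is the mixed traversal in the function-body walk: ops are still visited post-order, but any not-yet-bufferized scf::ForOp / TiledLoopOp ancestors are bufferized first (outermost to innermost) so that their block arguments are already mapped. The standalone C++ sketch below is not part of the patch; it mimics that `bufferizedOps` / `preorderBufferize` bookkeeping on a toy tree, and the names Node, isLoopLike, process, and walkPostOrder are illustrative only, not MLIR APIs.

// Minimal sketch of the pre-order-for-loops traversal, on a toy tree that
// stands in for the MLIR op hierarchy (assumed names, not part of the patch).
#include <cassert>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

struct Node {
  std::string name;
  bool isLoopLike = false; // Stand-in for scf::ForOp / TiledLoopOp.
  Node *parent = nullptr;
  std::vector<std::unique_ptr<Node>> children;

  Node *addChild(std::string childName, bool loopLike = false) {
    children.push_back(std::make_unique<Node>());
    Node *c = children.back().get();
    c->name = std::move(childName);
    c->isLoopLike = loopLike;
    c->parent = this;
    return c;
  }
};

// Stand-in for bufferizeOp(): just record the visitation order.
static void process(Node *n, std::vector<std::string> &order) {
  order.push_back(n->name);
}

// Post-order walk, except that not-yet-processed loop-like ancestors are
// processed first, outermost to innermost -- the same bookkeeping as
// `bufferizedOps` / `preorderBufferize` in the patch.
static void walkPostOrder(Node *n, std::set<Node *> &processed,
                          std::vector<std::string> &order) {
  for (auto &child : n->children)
    walkPostOrder(child.get(), processed, order);

  // Collect loop-like ancestors that still need processing.
  std::vector<Node *> preorder;
  for (Node *p = n->parent; p && p->isLoopLike; p = p->parent) {
    if (processed.count(p))
      break;
    processed.insert(p);
    preorder.push_back(p);
  }
  // Process them outermost first.
  for (auto it = preorder.rbegin(); it != preorder.rend(); ++it)
    process(*it, order);

  if (processed.insert(n).second)
    process(n, order);
}

int main() {
  // func { forA { forB { write } } }, mirroring the @nested_scf_for test.
  Node func{"func"};
  Node *forA = func.addChild("forA", /*loopLike=*/true);
  Node *forB = forA->addChild("forB", /*loopLike=*/true);
  forB->addChild("write");

  std::set<Node *> processed;
  std::vector<std::string> order;
  walkPostOrder(&func, processed, order);

  // Loops are handled before their bodies: forA, forB, write, func.
  for (const std::string &s : order)
    std::cout << s << "\n";
  assert(order == (std::vector<std::string>{"forA", "forB", "write", "func"}));
  return 0;
}

In the patch's terms, forA and forB play the role of the nested scf.for ops that must be bufferized before the vector.transfer_write in their bodies, while everything else is still handled by the ordinary post-order visit.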