diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1611,6 +1611,7 @@
   Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
                                           resultTypes, newOperands);
   newCallOp->setAttrs(callOp->getAttrs());
+  callOp->erase();
   return success();
 }
 
@@ -2259,17 +2260,52 @@
             DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes,
             GlobalCreator *globalCreator) {
   OpBuilder b(op->getContext());
+
+  // Return the first child op or nullptr for ops without children.
+  // TODO: Is there a better way?
+  auto getFirstChildOp = [](Operation *op) {
+    Operation *firstChildOp = nullptr;
+    op->walk([&](Operation *childOp) -> WalkResult {
+      firstChildOp = childOp;
+      return WalkResult::interrupt();
+    });
+    return firstChildOp;
+  };
+
+  // Some ops require a pre-order traversal. Bufferize such ops right before
+  // bufferizing their first child op.
+  if (getFirstChildOp(op->getParentOp()) == op) {
+    // op is the first child inside op's parent.
+    LogicalResult preOrder =
+        TypeSwitch<Operation *, LogicalResult>(op->getParentOp())
+            // bbArgs must be mapped in bvm before bufferizing the loop body.
+            .Case([&](auto op) {
+              LDBG("Begin bufferize:\n" << op << '\n');
+              return bufferize(b, op, bvm, aliasInfo);
+            })
+            .Default([](Operation *op) { return success(); });
+    if (failed(preOrder))
+      return failure();
+  }
+
+  // Handling of ops with regular post-order traversal.
   return TypeSwitch<Operation *, LogicalResult>(op)
       // Skip BufferCast and TensorLoad ops.
       .Case<memref::BufferCastOp, memref::TensorLoadOp>(
          [&](auto) { return success(); })
-      .Case([&](auto op) {
-        LDBG("Begin bufferize:\n" << op << '\n');
-        return bufferize(b, op, bvm, aliasInfo);
+      // Handle pre-order traversal ops that have no children.
+      .Case([&](auto op) {
+        if (!getFirstChildOp(op))
+          return bufferize(b, op, bvm, aliasInfo);
+        return success();
       })
+      .Case(
+          [&](auto op) {
+            LDBG("Begin bufferize:\n" << op << '\n');
+            return bufferize(b, op, bvm, aliasInfo);
+          })
       .Case([&](CallOpInterface op) {
         LDBG("Begin bufferize:\n" << op << '\n');
         if (!bufferizedFunctionTypes)
@@ -2302,34 +2338,23 @@
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
   OpBuilder b(funcOp->getContext());
-  /// Start by bufferizing `funcOp` arguments.
+
+  // Start by bufferizing `funcOp` arguments.
   if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
     return failure();
 
-  // Walk in PreOrder to ensure ops with regions are handled before their body.
-  // Since walk has to be PreOrder, we need to erase ops that require it
-  // separately: this is the case for CallOp
-  SmallVector<Operation *> toErase;
-  if (funcOp
-          .walk([&](Operation *op) -> WalkResult {
-            if (failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
-                                   &globalCreator)))
-              return failure();
-            // Register post-walk erasure, if necessary.
-            if (isa<CallOpInterface>(op))
-              if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
-                  llvm::any_of(op->getResultTypes(), isaTensor))
-                toErase.push_back(op);
-            return success();
-          })
-          .wasInterrupted())
+  // Bufferize the function body.
+  auto walkFunc = [&](Operation *op) -> WalkResult {
+    if (failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
+                           &globalCreator)))
+      return failure();
+    return success();
+  };
+  if (funcOp.walk(walkFunc).wasInterrupted())
     return failure();
 
   LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
 
-  for (Operation *op : toErase)
-    op->erase();
-
   return success();
 }
 
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -111,19 +111,6 @@
 
 // -----
 
-func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
-{
-  // expected-error @+1 {{unsupported op with tensors}}
-  %r = scf.if %b -> (tensor<4xf32>) {
-    scf.yield %A : tensor<4xf32>
-  } else {
-    scf.yield %B : tensor<4xf32>
-  }
-  return %r: tensor<4xf32>
-}
-
-// -----
-
 func @unknown_op(%A : tensor<4xf32>) -> tensor<4xf32>
 {
   // expected-error @+1 {{unsupported op with tensors}}
@@ -144,7 +131,7 @@
 // -----
 
 func @main() -> tensor<4xi32> {
-  // expected-error @+1 {{unsupported op with tensors}}
+  // expected-error @+1 {{expected result-less scf.execute_region containing op}}
   %r = scf.execute_region -> tensor<4xi32> {
     %A = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
     scf.yield %A: tensor<4xi32>