diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -1709,36 +1709,16 @@
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
   OpBuilder b(funcOp->getContext());
-  /// Start by bufferizing `funcOp` arguments.
+
+  // Start by bufferizing `funcOp` arguments.
   if (failed(bufferize(b, funcOp, bvm, aliasInfo, allocationFns)))
     return failure();
 
   // Cannot erase ops during the traversal. Do that afterwards.
   SmallVector<Operation *> toErase;
 
-  // Bufferize the function body. `bufferizedOps` keeps track ops that were
-  // already bufferized with pre-order traversal.
-  DenseSet<Operation *> bufferizedOps;
-  auto walkFunc = [&](Operation *op) -> WalkResult {
-    // Collect ops that need to be bufferized before `op`.
-    SmallVector<Operation *> preorderBufferize;
-    Operation *parentOp = op->getParentOp();
-    // scf::ForOp and TiledLoopOp must be bufferized before their blocks
-    // ("pre-order") because BBargs must be mapped when bufferizing children.
-    while (isa_and_nonnull<scf::ForOp, TiledLoopOp>(parentOp)) {
-      if (bufferizedOps.contains(parentOp))
-        break;
-      bufferizedOps.insert(parentOp);
-      preorderBufferize.push_back(parentOp);
-      parentOp = parentOp->getParentOp();
-    }
-
-    for (Operation *op : llvm::reverse(preorderBufferize))
-      if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
-                             &bufferizedFunctionTypes)))
-        return failure();
-
-    if (!bufferizedOps.contains(op) &&
-        failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
+  auto walkFunc = [&](Operation *op) -> WalkResult {
+    if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
                            &bufferizedFunctionTypes)))
       return failure();
@@ -1750,7 +1730,11 @@
     return success();
   };
-  if (funcOp.walk(walkFunc).wasInterrupted())
+
+  // Bufferize ops pre-order, i.e., bufferize ops first, then their children.
+  // This is needed for ops with blocks that have BlockArguments. These must be
+  // mapped before bufferizing the children.
+  if (funcOp.walk<WalkOrder::PreOrder>(walkFunc).wasInterrupted())
     return failure();
 
   LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
@@ -2772,31 +2756,40 @@
                           BlockAndValueMapping &bvm,
                           BufferizationAliasInfo &aliasInfo,
                           AllocationCallbacks &allocationFn) const {
-    auto ifOp = cast<scf::IfOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
+    // scf::IfOp is bufferized after scf::YieldOp in the else branch.
+    return success();
+  }
+};
 
-    for (OpResult opResult : ifOp->getResults()) {
-      if (!opResult.getType().isa<TensorType>())
-        continue;
-      // TODO: Atm we bail on unranked TensorType because we don't know how to
-      // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
-      assert(opResult.getType().isa<RankedTensorType>() &&
-             "unsupported unranked tensor");
+/// Bufferize the scf::IfOp. This function is called after the YieldOp was
+/// bufferized.
+static LogicalResult bufferizeIfOp(scf::IfOp ifOp, OpBuilder &b,
+                                   BlockAndValueMapping &bvm,
+                                   BufferizationAliasInfo &aliasInfo,
+                                   AllocationCallbacks &allocationFn) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(ifOp);
 
-      Value resultBuffer =
-          getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
-      if (!resultBuffer)
-        return failure();
+  for (OpResult opResult : ifOp->getResults()) {
+    if (!opResult.getType().isa<TensorType>())
+      continue;
+    // TODO: Atm we bail on unranked TensorType because we don't know how to
+    // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
+    assert(opResult.getType().isa<RankedTensorType>() &&
+           "unsupported unranked tensor");
 
-      aliasInfo.createAliasInfoEntry(resultBuffer);
-      map(bvm, opResult, resultBuffer);
-    }
+    Value resultBuffer =
+        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
+    if (!resultBuffer)
+      return failure();
 
-    return success();
+    aliasInfo.createAliasInfoEntry(resultBuffer);
+    map(bvm, opResult, resultBuffer);
   }
-};
+
+  return success();
+}
 
 struct ForOpInterface
     : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                     scf::ForOp> {
@@ -2867,6 +2863,39 @@
   }
 };
 
+/// Bufferize the scf::ForOp. This function is called after the YieldOp was
+/// bufferized.
+static LogicalResult bufferizeForOp(scf::ForOp forOp, OpBuilder &b,
+                                    BlockAndValueMapping &bvm,
+                                    BufferizationAliasInfo &aliasInfo,
+                                    AllocationCallbacks &allocationFn) {
+  auto yieldOp = cast<scf::YieldOp>(&forOp.region().front().back());
+  for (OpOperand &operand : yieldOp->getOpOperands()) {
+    auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+    if (!tensorType)
+      continue;
+
+    OpOperand &forOperand = forOp.getOpOperandForResult(
+        forOp->getResult(operand.getOperandNumber()));
+    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+    Value yieldedBuffer = lookup(bvm, operand.get());
+    Value bbArgBuffer = lookup(bvm, bbArg);
+    if (!aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, bbArgBuffer)) {
+      // TODO: this could get resolved with copies but it can also turn into
+      // swaps so we need to be careful about order of copies.
+      return yieldOp->emitError()
+             << "Yield operand #" << operand.getOperandNumber()
+             << " does not bufferize to an equivalent buffer to the matching"
+             << " enclosing scf::for operand";
+    }
+
+    // Buffers are equivalent so the work is already done and we just yield
+    // the bbArg so that it later canonicalizes away.
+    operand.set(bbArg);
+  }
+  return success();
+}
+
 struct YieldOpInterface
     : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                     scf::YieldOp> {
@@ -2892,11 +2921,6 @@
                           AllocationCallbacks &allocationFn) const {
     auto yieldOp = cast<scf::YieldOp>(op);
 
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    // Cannot create IR past a yieldOp.
-    b.setInsertionPoint(yieldOp);
-
     if (auto execOp = dyn_cast<scf::ExecuteRegionOp>(yieldOp->getParentOp())) {
       if (execOp->getNumResults() != 0)
         return execOp->emitError(
             "expected result-less scf.execute_region containing op");
       return success();
     }
 
@@ -2904,37 +2928,19 @@
-    if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp()))
-      return success();
-
-    scf::ForOp forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
-    if (!forOp)
-      return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp");
-    for (OpOperand &operand : yieldOp->getOpOperands()) {
-      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
-      if (!tensorType)
-        continue;
+    // Bufferize scf::IfOp after bufferizing the scf::YieldOp in the else
+    // branch.
+    if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
+      if (ifOp.elseYield() != yieldOp)
+        return success();
+      return bufferizeIfOp(ifOp, b, bvm, aliasInfo, allocationFn);
+    }
 
-      OpOperand &forOperand = forOp.getOpOperandForResult(
-          forOp->getResult(operand.getOperandNumber()));
-      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-      Value yieldedBuffer = lookup(bvm, operand.get());
-      Value bbArgBuffer = lookup(bvm, bbArg);
-      if (!aliasInfo.areEquivalentBufferizedValues(yieldedBuffer,
-                                                   bbArgBuffer)) {
-        // TODO: this could get resolved with copies but it can also turn into
-        // swaps so we need to be careful about order of copies.
-        return yieldOp->emitError()
-               << "Yield operand #" << operand.getOperandNumber()
-               << " does not bufferize to an equivalent buffer to the matching"
-               << " enclosing scf::for operand";
-      }
+    // Bufferize scf::ForOp after bufferizing the scf::YieldOp.
+    if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp()))
+      return bufferizeForOp(forOp, b, bvm, aliasInfo, allocationFn);
 
-      // Buffers are equivalent so the work is already done and we just yield
-      // the bbArg so that it later canonicalizes away.
-      operand.set(bbArg);
-    }
-    return success();
+    return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp");
   }
 };
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -888,3 +888,32 @@
   return %r : tensor<?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @scf_if_inside_scf_for
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[c10:.*]] = arith.constant 10 : index
+// CHECK: scf.for %{{.*}} = %[[c0]] to %[[c10]] step %[[c1]] {
+// CHECK:   scf.if %{{.*}} {
+// CHECK:   } else {
+// CHECK:     vector.transfer_write
+// CHECK:   }
+// CHECK: }
+func @scf_if_inside_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                            %v: vector<5xf32>, %idx: index,
+                            %cond: i1) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%bb = %t1) -> (tensor<?xf32>) {
+    %r2 = scf.if %cond -> (tensor<?xf32>) {
+      scf.yield %bb : tensor<?xf32>
+    } else {
+      %t2 = vector.transfer_write %v, %bb[%idx] : vector<5xf32>, tensor<?xf32>
+      scf.yield %t2 : tensor<?xf32>
+    }
+    scf.yield %r2 : tensor<?xf32>
+  }
+  return %r : tensor<?xf32>
+}