diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -1110,6 +1110,8 @@ YieldOp getYieldOp(); Block::BlockArgListType getBeforeArguments(); Block::BlockArgListType getAfterArguments(); + Block *getBeforeBody() { return &getBefore().front(); } + Block *getAfterBody() { return &getAfter().front(); } }]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -542,10 +542,8 @@ rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); // Inline both regions. - Block *after = &whileOp.getAfter().front(); - Block *afterLast = &whileOp.getAfter().back(); - Block *before = &whileOp.getBefore().front(); - Block *beforeLast = &whileOp.getBefore().back(); + Block *after = whileOp.getAfterBody(); + Block *before = whileOp.getBeforeBody(); rewriter.inlineRegionBefore(whileOp.getAfter(), continuation); rewriter.inlineRegionBefore(whileOp.getBefore(), after); @@ -556,14 +554,14 @@ // Replace terminators with branches. Assuming bodies are SESE, which holds // given only the patterns from this file, we only need to look at the last // block. This should be reconsidered if we allow break/continue in SCF. - rewriter.setInsertionPointToEnd(beforeLast); - auto condOp = cast(beforeLast->getTerminator()); + rewriter.setInsertionPointToEnd(before); + auto condOp = cast(before->getTerminator()); rewriter.replaceOpWithNewOp(condOp, condOp.getCondition(), after, condOp.getArgs(), continuation, ValueRange()); - rewriter.setInsertionPointToEnd(afterLast); - auto yieldOp = cast(afterLast->getTerminator()); + rewriter.setInsertionPointToEnd(after); + auto yieldOp = cast(after->getTerminator()); rewriter.replaceOpWithNewOp(yieldOp, before, yieldOp.getResults()); @@ -577,12 +575,7 @@ LogicalResult DoWhileLowering::matchAndRewrite(WhileOp whileOp, PatternRewriter &rewriter) const { - if (!llvm::hasSingleElement(whileOp.getAfter())) - return rewriter.notifyMatchFailure(whileOp, - "do-while simplification applicable to " - "single-block 'after' region only"); - - Block &afterBlock = whileOp.getAfter().front(); + Block &afterBlock = *whileOp.getAfterBody(); if (!llvm::hasSingleElement(afterBlock)) return rewriter.notifyMatchFailure(whileOp, "do-while simplification applicable " @@ -601,8 +594,7 @@ rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); // Only the "before" region should be inlined. - Block *before = &whileOp.getBefore().front(); - Block *beforeLast = &whileOp.getBefore().back(); + Block *before = whileOp.getBeforeBody(); rewriter.inlineRegionBefore(whileOp.getBefore(), continuation); // Branch to the "before" region. @@ -610,8 +602,8 @@ rewriter.create(whileOp.getLoc(), before, whileOp.getInits()); // Loop around the "before" region based on condition. - rewriter.setInsertionPointToEnd(beforeLast); - auto condOp = cast(beforeLast->getTerminator()); + rewriter.setInsertionPointToEnd(before); + auto condOp = cast(before->getTerminator()); rewriter.replaceOpWithNewOp(condOp, condOp.getCondition(), before, condOp.getArgs(), continuation, ValueRange()); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -3177,19 +3177,19 @@ } ConditionOp WhileOp::getConditionOp() { - return cast(getBefore().front().getTerminator()); + return cast(getBeforeBody()->getTerminator()); } YieldOp WhileOp::getYieldOp() { - return cast(getAfter().front().getTerminator()); + return cast(getAfterBody()->getTerminator()); } Block::BlockArgListType WhileOp::getBeforeArguments() { - return getBefore().front().getArguments(); + return getBeforeBody()->getArguments(); } Block::BlockArgListType WhileOp::getAfterArguments() { - return getAfter().front().getArguments(); + return getAfterBody()->getArguments(); } void WhileOp::getSuccessorRegions(std::optional index, @@ -3260,8 +3260,7 @@ /// Prints a `while` op. void scf::WhileOp::print(OpAsmPrinter &p) { - printInitializationList(p, getBefore().front().getArguments(), getInits(), - " "); + printInitializationList(p, getBeforeArguments(), getInits(), " "); p << " : "; p.printFunctionalType(getInits().getTypes(), getResults().getTypes()); p << ' '; @@ -3411,7 +3410,7 @@ LogicalResult matchAndRewrite(WhileOp op, PatternRewriter &rewriter) const override { - Block &afterBlock = op.getAfter().front(); + Block &afterBlock = *op.getAfterBody(); Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments(); ConditionOp condOp = op.getConditionOp(); OperandRange condOpArgs = condOp.getArgs(); @@ -3493,7 +3492,7 @@ &newWhile.getBefore(), /*insertPt*/ {}, ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs); - Block &beforeBlock = op.getBefore().front(); + Block &beforeBlock = *op.getBeforeBody(); SmallVector newBeforeBlockArgs(beforeBlock.getNumArguments()); // For each i-th before block argument we find it's replacement value as :- // 1. If i-th before block argument is a loop invariant, we fetch it's @@ -3563,7 +3562,7 @@ LogicalResult matchAndRewrite(WhileOp op, PatternRewriter &rewriter) const override { - Block &beforeBlock = op.getBefore().front(); + Block &beforeBlock = *op.getBeforeBody(); ConditionOp condOp = op.getConditionOp(); OperandRange condOpArgs = condOp.getArgs(); @@ -3616,7 +3615,7 @@ *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {}, newAfterBlockType, newAfterBlockArgLocs); - Block &afterBlock = op.getAfter().front(); + Block &afterBlock = *op.getAfterBody(); // Since a new scf.condition op was created, we need to fetch the new // `after` block arguments which will be used while replacing operations of // previous scf.while's `after` blocks. We'd also be fetching new result @@ -3733,7 +3732,7 @@ rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(), newWhile.getBefore().begin()); - Block &afterBlock = op.getAfter().front(); + Block &afterBlock = *op.getAfterBody(); rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs); rewriter.replaceOp(op, newResults); @@ -3774,8 +3773,7 @@ if (!cmp) return failure(); bool changed = false; - for (auto tup : - llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) { + for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) { for (size_t opIdx = 0; opIdx < 2; opIdx++) { if (std::get<0>(tup) != cmp.getOperand(opIdx)) continue; @@ -3839,8 +3837,8 @@ } } - Block &beforeBlock = op.getBefore().front(); - Block &afterBlock = op.getAfter().front(); + Block &beforeBlock = *op.getBeforeBody(); + Block &afterBlock = *op.getAfterBody(); beforeBlock.eraseArguments(argsToErase); @@ -3848,8 +3846,8 @@ auto newWhileOp = rewriter.create(loc, op.getResultTypes(), newInits, /*beforeBody*/ nullptr, /*afterBody*/ nullptr); - Block &newBeforeBlock = newWhileOp.getBefore().front(); - Block &newAfterBlock = newWhileOp.getAfter().front(); + Block &newBeforeBlock = *newWhileOp.getBeforeBody(); + Block &newAfterBlock = *newWhileOp.getAfterBody(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(yield); @@ -3899,8 +3897,8 @@ auto newWhileOp = rewriter.create( loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr, /*afterBody*/ nullptr); - Block &newBeforeBlock = newWhileOp.getBefore().front(); - Block &newAfterBlock = newWhileOp.getAfter().front(); + Block &newBeforeBlock = *newWhileOp.getBeforeBody(); + Block &newAfterBlock = *newWhileOp.getAfterBody(); SmallVector afterArgsMapping; SmallVector resultsMapping; @@ -3917,8 +3915,8 @@ rewriter.replaceOpWithNewOp(condOp, condOp.getCondition(), argsRange); - Block &beforeBlock = op.getBefore().front(); - Block &afterBlock = op.getAfter().front(); + Block &beforeBlock = *op.getBeforeBody(); + Block &afterBlock = *op.getAfterBody(); rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlock.getArguments()); diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -760,13 +760,6 @@ const BufferizationOptions &options) const { auto whileOp = cast(op); - assert(whileOp.getBefore().getBlocks().size() == 1 && - "regions with multiple blocks not supported"); - Block *beforeBody = &whileOp.getBefore().front(); - assert(whileOp.getAfter().getBlocks().size() == 1 && - "regions with multiple blocks not supported"); - Block *afterBody = &whileOp.getAfter().front(); - // Indices of all bbArgs that have tensor type. These are the ones that // are bufferized. The "before" and "after" regions may have different args. DenseSet indicesBefore = getTensorIndices(whileOp.getInits()); @@ -827,7 +820,7 @@ rewriter.setInsertionPointToStart(newBeforeBody); SmallVector newBeforeArgs = getBbArgReplacements( rewriter, newWhileOp.getBeforeArguments(), indicesBefore); - rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs); + rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs); // Set up new iter_args and move the loop body block to the new op. // The old block uses tensors, so wrap the (memref) bbArgs of the new block @@ -835,7 +828,7 @@ rewriter.setInsertionPointToStart(newAfterBody); SmallVector newAfterArgs = getBbArgReplacements( rewriter, newWhileOp.getAfterArguments(), indicesAfter); - rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs); + rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs); // Replace loop results. replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults()); diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -57,7 +57,7 @@ // arguments to the 'after' region. auto *beforeBlock = rewriter.createBlock( &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs); - rewriter.setInsertionPointToStart(&whileOp.getBefore().front()); + rewriter.setInsertionPointToStart(whileOp.getBeforeBody()); auto cmpOp = rewriter.create( whileOp.getLoc(), arith::CmpIPredicate::slt, beforeBlock->getArgument(0), forOp.getUpperBound());