diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -178,6 +178,8 @@
     /// value for `index`.
     OperandRange getSuccessorEntryOperands(unsigned index);
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def IfOp : SCF_Op<"if",
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -359,6 +359,125 @@
   });
 }
 
+namespace {
+// Fold away ForOp iter arguments whose region argument is yielded unchanged
+// by the terminator (i.e. `scf.yield` forwards the iter arg as-is). Such an
+// iter arg keeps its initial value on every iteration, so the corresponding
+// result is just the init operand, which is defined outside of the ForOp
+// region. These can be forwarded after simplifying the op inits, yields and
+// returns.
+//
+// The implementation uses `mergeBlocks` (or `mergeBlockBefore` when no iter
+// arg survives) to steal the content of the original ForOp and avoid cloning.
+struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
+  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::ForOp forOp,
+                                PatternRewriter &rewriter) const final {
+    bool canonicalize = false;
+    Block &block = forOp.region().front();
+    auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
+
+    // `newBlockTransferArgs` is a flat vector that keeps the 1-1 mapping from
+    // the original block arguments (iv first, then iter args) to their
+    // replacements. It plays the role of a BlockAndValueMapping for the
+    // particular use case of calling into `mergeBlocks`/`mergeBlockBefore`.
+    // `keepMask[i]` is true iff the i-th iter arg is kept in the new loop.
+    SmallVector<bool, 4> keepMask;
+    keepMask.reserve(yieldOp.getNumOperands());
+    SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newResultValues;
+    newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
+    newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
+    newIterArgs.reserve(forOp.getNumIterOperands());
+    newResultValues.reserve(forOp.getNumResults());
+    for (auto it : llvm::zip(forOp.getIterOperands(),   // iter from outside
+                             forOp.getRegionIterArgs(), // iter inside region
+                             yieldOp.getOperands()      // iter yield
+                             )) {
+      // Forwarded is `true` when the region `iter` argument is yielded.
+      bool forwarded = (std::get<1>(it) == std::get<2>(it));
+      keepMask.push_back(!forwarded);
+      canonicalize |= forwarded;
+      if (forwarded) {
+        // The iter arg never changes: both its in-region uses and the loop
+        // result can use the init value directly.
+        newBlockTransferArgs.push_back(std::get<0>(it));
+        newResultValues.push_back(std::get<0>(it));
+        continue;
+      }
+      newIterArgs.push_back(std::get<0>(it));
+      newBlockTransferArgs.push_back(Value()); // placeholder with null value
+      newResultValues.push_back(Value());      // placeholder with null value
+    }
+
+    if (!canonicalize)
+      return failure();
+
+    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+        forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(),
+        newIterArgs);
+    Block &newBlock = newForOp.region().front();
+
+    // Replace the null placeholders with newly constructed values.
+    newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
+    for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
+         idx != e; ++idx) {
+      Value &blockTransferArg = newBlockTransferArgs[1 + idx];
+      Value &newResultVal = newResultValues[idx];
+      assert((blockTransferArg && newResultVal) ||
+             (!blockTransferArg && !newResultVal));
+      if (!blockTransferArg) {
+        blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
+        newResultVal = newForOp.getResult(collapsedIdx++);
+      }
+    }
+
+    Block &oldBlock = forOp.region().front();
+    assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
+           "unexpected argument size mismatch");
+
+    // No results case: the scf::ForOp builder already created a zero
+    // result terminator. Merge before this terminator and just get rid of the
+    // original terminator that has been merged in.
+    if (newIterArgs.empty()) {
+      auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+      rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
+      rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
+      rewriter.replaceOp(forOp, newResultValues);
+      return success();
+    }
+
+    // No terminator case: merge and rewrite the merged terminator. The
+    // filtered yield operands are read from the merged terminator (rather
+    // than recorded before the merge) because merging replaces the uses of
+    // the old block arguments with `newBlockTransferArgs`.
+    auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(mergedTerminator);
+      SmallVector<Value, 4> filteredOperands;
+      filteredOperands.reserve(newIterArgs.size());
+      for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
+        if (keepMask[idx])
+          filteredOperands.push_back(mergedTerminator.getOperand(idx));
+      rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
+                                    filteredOperands);
+    };
+
+    rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
+    auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+    cloneFilteredTerminator(mergedYieldOp);
+    rewriter.eraseOp(mergedYieldOp);
+    rewriter.replaceOp(forOp, newResultValues);
+    return success();
+  }
+};
+} // namespace
+
+void ForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                        MLIRContext *context) {
+  results.insert<ForOpIterArgsFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // IfOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -137,3 +137,38 @@
 // CHECK: call @side_effect() : () -> ()
 // CHECK: }
 // CHECK: return
+
+// -----
+
+func @make_i32() -> i32
+
+func @for_yields_2(%lb : index, %ub : index, %step : index) -> i32 {
+  %a = call @make_i32() : () -> (i32)
+  %b = scf.for %i = %lb to %ub step %step iter_args(%0 = %a) -> i32 {
+    scf.yield %0 : i32
+  }
+  return %b : i32
+}
+
+// CHECK-LABEL: func @for_yields_2
+//  CHECK-NEXT:   %[[R:.*]] = call @make_i32() : () -> i32
+//  CHECK-NEXT:   return %[[R]] : i32
+
+func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
+  %a = call @make_i32() : () -> (i32)
+  %b = call @make_i32() : () -> (i32)
+  %r:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %a, %2 = %b) -> (i32, i32, i32) {
+    %c = call @make_i32() : () -> (i32)
+    scf.yield %0, %c, %2 : i32, i32, i32
+  }
+  return %r#0, %r#1, %r#2 : i32, i32, i32
+}
+
+// CHECK-LABEL: func @for_yields_3
+//  CHECK-NEXT:   %[[a:.*]] = call @make_i32() : () -> i32
+//  CHECK-NEXT:   %[[b:.*]] = call @make_i32() : () -> i32
+//  CHECK-NEXT:   %[[r1:.*]] = scf.for {{.*}} iter_args(%{{.*}} = %[[a]]) -> (i32) {
+//  CHECK-NEXT:     %[[c:.*]] = call @make_i32() : () -> i32
+//  CHECK-NEXT:     scf.yield %[[c]] : i32
+//  CHECK-NEXT:   }
+//  CHECK-NEXT:   return %[[a]], %[[r1]], %[[b]] : i32, i32, i32