diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -586,7 +586,13 @@
 
   let extraClassDeclaration = [{
     OperandRange getSuccessorEntryOperands(unsigned index);
+    /// Returns the `scf.condition` terminator of the "before" region.
+    ConditionOp getConditionOp();
+    /// Returns the arguments of the entry block of the "after" region.
+    Block::BlockArgListType getAfterArguments();
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1694,6 +1694,14 @@
   return inits();
 }
 
+ConditionOp WhileOp::getConditionOp() {
+  return cast<ConditionOp>(before().front().getTerminator());
+}
+
+Block::BlockArgListType WhileOp::getAfterArguments() {
+  return after().front().getArguments();
+}
+
 void WhileOp::getSuccessorRegions(Optional<unsigned> index,
                                   ArrayRef<Attribute> operands,
                                   SmallVectorImpl<RegionSuccessor> &regions) {
@@ -1835,6 +1843,66 @@
   return success(afterTerminator != nullptr);
 }
 
+namespace {
+/// Replace uses of the condition within the do block with true, since otherwise
+/// the block would not be evaluated.
+///
+/// scf.while (..) : (i1, ...) -> ... {
+///  %condition = call @evaluate_condition() : () -> i1
+///  scf.condition(%condition) %condition : i1, ...
+/// } do {
+/// ^bb0(%arg0: i1, ...):
+///    use(%arg0)
+///    ...
+///
+/// becomes
+/// scf.while (..) : (i1, ...) -> ... {
+///  %condition = call @evaluate_condition() : () -> i1
+///  scf.condition(%condition) %condition : i1, ...
+/// } do {
+/// ^bb0(%arg0: i1, ...):
+///    use(%true)
+///    ...
+struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    auto term = op.getConditionOp();
+
+    // Lazily materialized `true` constant, shared by every "after" argument
+    // that forwards the condition so that at most one constant is created.
+    Value constantTrue = nullptr;
+
+    bool replaced = false;
+    for (auto yieldedAndBlockArgs :
+         llvm::zip(term.args(), op.getAfterArguments())) {
+      if (std::get<0>(yieldedAndBlockArgs) == term.condition()) {
+        if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
+          if (!constantTrue)
+            constantTrue = rewriter.create<ConstantOp>(
+                op.getLoc(), term.condition().getType(),
+                rewriter.getBoolAttr(true));
+
+          // Route the in-place mutation through the rewriter so the pattern
+          // driver is notified that operations nested in `op` changed.
+          rewriter.updateRootInPlace(op, [&] {
+            std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
+          });
+          replaced = true;
+        }
+      }
+    }
+    return success(replaced);
+  }
+};
+} // namespace
+
+void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context) {
+  results.insert<WhileConditionTruth>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -724,3 +724,27 @@
 // CHECK-NEXT: scf.yield %[[sv2]] : i32
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[if]], %arg1 : i32, i64
+
+// -----
+
+// CHECK-LABEL: @while_cond_true
+func @while_cond_true() {
+  %0 = scf.while () : () -> i1 {
+    %condition = "test.condition"() : () -> i1
+    scf.condition(%condition) %condition : i1
+  } do {
+  ^bb0(%arg0: i1):
+    "test.use"(%arg0) : (i1) -> ()
+    scf.yield
+  }
+  return
+}
+// CHECK-NEXT:         %[[true:.+]] = constant true
+// CHECK-NEXT:         %{{.+}} = scf.while : () -> i1 {
+// CHECK-NEXT:           %[[cmp:.+]] = "test.condition"() : () -> i1
+// CHECK-NEXT:           scf.condition(%[[cmp]]) %[[cmp]] : i1
+// CHECK-NEXT:         } do {
+// CHECK-NEXT:         ^bb0(%arg0: i1):  // no predecessors
+// CHECK-NEXT:           "test.use"(%[[true]]) : (i1) -> ()
+// CHECK-NEXT:           scf.yield
+// CHECK-NEXT:         }