diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
@@ -121,6 +121,10 @@
 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
 Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
                      Value lhs, Value rhs);
+
+/// Invert an integer comparison predicate, i.e. return its logical negation
+/// (e.g. eq -> ne).
+arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred);
 } // namespace arith
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -40,7 +40,7 @@
 }
 
 /// Invert an integer comparison predicate.
-static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) {
+arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
   switch (pred) {
   case arith::CmpIPredicate::eq:
     return arith::CmpIPredicate::ne;
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -2443,11 +2443,85 @@
     return success();
   }
 };
+
+/// Replace comparisons in the do block that are equivalent to the loop
+/// condition with `true`, and comparisons that are its exact inverse (see
+/// arith::invertPredicate) with `false`: the do block only runs when the
+/// condition held. A compare matches when it has the same (or inverted)
+/// predicate, uses the block argument that receives the forwarded condition
+/// operand in the same operand position, and shares the other operand.
+///
+/// scf.while (..) : (i32, ...) -> ... {
+///  %z = ... : i32
+///  %condition = cmpi pred %z, %a
+///  scf.condition(%condition) %z : i32, ...
+/// } do {
+/// ^bb0(%arg0: i32, ...):
+///    %condition2 = cmpi pred %arg0, %a
+///    %condition3 = cmpi invert(pred) %arg0, %a
+///    use(%condition2, %condition3)
+///    ...
+///
+/// becomes
+/// scf.while (..) : (i32, ...) -> ... {
+///  %z = ... : i32
+///  %condition = cmpi pred %z, %a
+///  scf.condition(%condition) %z : i32, ...
+/// } do {
+/// ^bb0(%arg0: i32, ...):
+///    use(%true, %false)
+///    ...
+struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
+  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    auto cond = op.getConditionOp();
+    auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
+    if (!cmp)
+      return failure();
+    bool changed = false;
+    for (auto tup :
+         llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
+      Value forwarded = std::get<0>(tup);
+      BlockArgument arg = std::get<1>(tup);
+      // The forwarded value may be either operand of the loop condition.
+      for (unsigned opIdx = 0; opIdx < 2; ++opIdx) {
+        if (forwarded != cmp->getOperand(opIdx))
+          continue;
+        unsigned otherIdx = 1 - opIdx;
+        // Replacing a compare erases it and drops its uses, so advance the
+        // use iterator before rewriting. A matching compare uses `arg` only
+        // once (its other operand is shared with `cmp`, which cannot use an
+        // after-region block argument), so the next use is never dropped.
+        for (OpOperand &use : llvm::make_early_inc_range(arg.getUses())) {
+          auto cmp2 = dyn_cast<arith::CmpIOp>(use.getOwner());
+          if (!cmp2 || use.getOperandNumber() != opIdx ||
+              cmp2->getOperand(otherIdx) != cmp->getOperand(otherIdx))
+            continue;
+          bool knownValue;
+          if (cmp2.getPredicate() == cmp.getPredicate())
+            knownValue = true;
+          else if (cmp2.getPredicate() ==
+                   arith::invertPredicate(cmp.getPredicate()))
+            knownValue = false;
+          else
+            continue;
+
+          rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, knownValue,
+                                                            /*width=*/1);
+          changed = true;
+        }
+      }
+    }
+    return success(changed);
+  }
+};
 } // namespace
 
 void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
-  results.insert(context);
+  results.insert(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -872,6 +872,81 @@
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[res]] : i32
 
+// CHECK-LABEL: @while_cmp_true
+func @while_cmp_true(%arg0 : i32) {
+  %0 = scf.while () : () -> i32 {
+    %val = "test.val"() : () -> i32
+    %condition = arith.cmpi ne, %val, %arg0 : i32
+    scf.condition(%condition) %val : i32
+  } do {
+  ^bb0(%val2: i32):
+    %condition2 = arith.cmpi ne, %val2, %arg0 : i32
+    "test.use"(%condition2, %val2) : (i1, i32) -> ()
+    scf.yield
+  }
+  return
+}
+// CHECK-NEXT: %[[true:.+]] = arith.constant true
+// CHECK-NEXT: %{{.+}} = scf.while : () -> i32 {
+// CHECK-NEXT:   %[[val:.+]] = "test.val"
+// CHECK-NEXT:   %[[cmp:.+]] = arith.cmpi ne, %[[val]], %arg0 : i32
+// CHECK-NEXT:   scf.condition(%[[cmp]]) %[[val]] : i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%arg1: i32): // no predecessors
+// CHECK-NEXT:   "test.use"(%[[true]], %arg1) : (i1, i32) -> ()
+// CHECK-NEXT:   scf.yield
+// CHECK-NEXT: }
+
+// CHECK-LABEL: @while_cmp_false
+func @while_cmp_false(%arg0 : i32) {
+  %0 = scf.while () : () -> i32 {
+    %val = "test.val"() : () -> i32
+    %condition = arith.cmpi ne, %val, %arg0 : i32
+    scf.condition(%condition) %val : i32
+  } do {
+  ^bb0(%val2: i32):
+    %condition2 = arith.cmpi eq, %val2, %arg0 : i32
+    "test.use"(%condition2, %val2) : (i1, i32) -> ()
+    scf.yield
+  }
+  return
+}
+// CHECK-NEXT: %[[false:.+]] = arith.constant false
+// CHECK-NEXT: %{{.+}} = scf.while : () -> i32 {
+// CHECK-NEXT:   %[[val:.+]] = "test.val"
+// CHECK-NEXT:   %[[cmp:.+]] = arith.cmpi ne, %[[val]], %arg0 : i32
+// CHECK-NEXT:   scf.condition(%[[cmp]]) %[[val]] : i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%arg1: i32): // no predecessors
+// CHECK-NEXT:   "test.use"(%[[false]], %arg1) : (i1, i32) -> ()
+// CHECK-NEXT:   scf.yield
+// CHECK-NEXT: }
+
+// CHECK-LABEL: @while_cmp_swapped
+func @while_cmp_swapped(%arg0 : i32) {
+  %0 = scf.while () : () -> i32 {
+    %val = "test.val"() : () -> i32
+    %condition = arith.cmpi slt, %arg0, %val : i32
+    scf.condition(%condition) %val : i32
+  } do {
+  ^bb0(%val2: i32):
+    %condition2 = arith.cmpi sge, %arg0, %val2 : i32
+    "test.use"(%condition2, %val2) : (i1, i32) -> ()
+    scf.yield
+  }
+  return
+}
+// CHECK-NEXT: %[[false:.+]] = arith.constant false
+// CHECK-NEXT: %{{.+}} = scf.while : () -> i32 {
+// CHECK-NEXT:   %[[val:.+]] = "test.val"
+// CHECK-NEXT:   %[[cmp:.+]] = arith.cmpi slt, %arg0, %[[val]] : i32
+// CHECK-NEXT:   scf.condition(%[[cmp]]) %[[val]] : i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%arg1: i32): // no predecessors
+// CHECK-NEXT:   "test.use"(%[[false]], %arg1) : (i1, i32) -> ()
+// CHECK-NEXT:   scf.yield
+// CHECK-NEXT: }
+
 // -----
 
 // CHECK-LABEL: @combineIfs