diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -930,13 +930,89 @@
     return success();
   }
 };
+
+/// Propagates the value of a `cond_br` condition into its successors.
+///
+/// If a successor block can only be entered through one edge of the
+/// `cond_br`, every use of the condition located directly in that block is
+/// replaced with the constant implied by that edge: `true` for the true
+/// destination, `false` for the false destination.
+///
+///   cond_br %arg0, ^trueB, ^falseB
+///
+///   ^trueB:
+///     "test.consumer1"(%arg0) : (i1) -> ()
+///     ...
+///
+///   ^falseB:
+///     "test.consumer2"(%arg0) : (i1) -> ()
+///     ...
+///
+/// ->
+///
+///   cond_br %arg0, ^trueB, ^falseB
+///   ^trueB:
+///     "test.consumer1"(%true) : (i1) -> ()
+///     ...
+///
+///   ^falseB:
+///     "test.consumer2"(%false) : (i1) -> ()
+///     ...
+///
+/// Uses nested inside regions of operations in a successor are left
+/// untouched; only operations directly in the successor are rewritten.
+struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
+  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CondBranchOp condbr,
+                                PatternRewriter &rewriter) const override {
+    bool changed = propagateIntoSuccessor(condbr, condbr.getTrueDest(),
+                                          /*value=*/true, rewriter);
+    // `|=` rather than `||` so the false edge is processed even when the
+    // true edge already produced a rewrite.
+    changed |= propagateIntoSuccessor(condbr, condbr.getFalseDest(),
+                                      /*value=*/false, rewriter);
+    return success(changed);
+  }
+
+private:
+  /// Replaces each use of `condbr`'s condition whose owner is directly in
+  /// `dest` with an i1 constant equal to `value`, creating that constant at
+  /// most once. Returns true if any use was replaced.
+  static bool propagateIntoSuccessor(CondBranchOp condbr, Block *dest,
+                                     bool value, PatternRewriter &rewriter) {
+    // The edge pins the condition's value only if it is the sole way into
+    // `dest`. A self-loop (`dest` is the cond_br's own block) is skipped: the
+    // condition may be defined in that block and recomputed on each pass,
+    // and the rewrite would also replace the cond_br's own operand.
+    if (!dest->getSinglePredecessor() || dest == condbr->getBlock())
+      return false;
+
+    Value constant = nullptr;
+    // `use.set` unlinks the use from the list being walked, hence the
+    // early-increment range.
+    for (OpOperand &use :
+         llvm::make_early_inc_range(condbr.condition().getUses())) {
+      Operation *user = use.getOwner();
+      if (user->getBlock() != dest)
+        continue;
+      if (!constant)
+        constant = rewriter.create<ConstantOp>(condbr.getLoc(),
+                                               rewriter.getI1Type(),
+                                               rewriter.getBoolAttr(value));
+      rewriter.updateRootInPlace(user, [&] { use.set(constant); });
+    }
+    return static_cast<bool>(constant);
+  }
+};
 } // end anonymous namespace
 
 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
               SimplifyCondBranchIdenticalSuccessors,
-              SimplifyCondBranchFromCondBranchOnSameCondition>(context);
+              SimplifyCondBranchFromCondBranchOnSameCondition,
+              CondBranchTruthPropagation>(context);
 }
 
 Optional<MutableOperandRange>
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -306,3 +306,32 @@
   %1 = select %0, %arg0, %arg1 : i64
   return %1 : i64
 }
+
+// -----
+
+// CHECK-LABEL: @branchCondProp
+// CHECK:       %true = constant true
+// CHECK:       %false = constant false
+// CHECK:       cond_br %arg0, ^bb1, ^bb2
+// CHECK:     ^bb1:  // pred: ^bb0
+// CHECK:       "test.consumer1"(%true) : (i1) -> ()
+// CHECK:       br ^bb3
+// CHECK:     ^bb2:  // pred: ^bb0
+// CHECK:       "test.consumer2"(%false) : (i1) -> ()
+// CHECK:       br ^bb3
+// CHECK:     ^bb3:  // 2 preds: ^bb1, ^bb2
+// CHECK:       return
+func @branchCondProp(%arg0: i1) {
+  cond_br %arg0, ^trueB, ^falseB
+
+^trueB:
+  "test.consumer1"(%arg0) : (i1) -> ()
+  br ^exit
+
+^falseB:
+  "test.consumer2"(%arg0) : (i1) -> ()
+  br ^exit
+
+^exit:
+  return
+}