diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -930,13 +930,88 @@
     return success();
   }
 };
+
+/// Propagates the known value of a `cond_br` condition into its successors:
+/// a successor that can only be entered through one edge of the branch sees
+/// the condition as `true` (true edge) or `false` (false edge), so direct uses
+/// of the condition in that block are replaced by the matching constant.
+///
+/// cond_br %arg0, ^trueB, ^falseB
+///
+/// ^trueB:
+///   "test.consumer1"(%arg0) : (i1) -> ()
+///   ...
+///
+/// ^falseB:
+///   "test.consumer2"(%arg0) : (i1) -> ()
+///   ...
+///
+/// ->
+///
+/// cond_br %arg0, ^trueB, ^falseB
+/// ^trueB:
+///   "test.consumer1"(%true) : (i1) -> ()
+///   ...
+///
+/// ^falseB:
+///   "test.consumer2"(%false) : (i1) -> ()
+///   ...
+///
+/// Only uses whose owner sits directly in the successor block are rewritten;
+/// uses nested in regions of ops inside that block are left untouched.
+struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
+  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CondBranchOp condbr,
+                                PatternRewriter &rewriter) const override {
+    Value condition = condbr.condition();
+    Type i1Type = rewriter.getI1Type();
+
+    // Materialize constants right before `condbr`: its block is the sole
+    // predecessor of every successor we rewrite, so it dominates their uses.
+    rewriter.setInsertionPoint(condbr);
+
+    // Replaces the uses of `condition` owned by operations directly in `dest`
+    // with a constant holding `value`, creating that constant at most once.
+    // Returns true if any use was replaced.
+    auto propagate = [&](Block *dest, bool value) {
+      // getSinglePredecessor() counts edges, so it returns null when `dest`
+      // is reached from several edges, including when it is both the true
+      // and the false successor. A block that is its own sole predecessor is
+      // skipped: there `condbr`'s own operand is a use in `dest`, and
+      // rewriting it would change the branch.
+      if (dest == condbr->getBlock() || !dest->getSinglePredecessor())
+        return false;
+
+      Value constant;
+      bool changed = false;
+      for (OpOperand &use : llvm::make_early_inc_range(condition.getUses())) {
+        Operation *user = use.getOwner();
+        if (user->getBlock() != dest)
+          continue;
+        if (!constant)
+          constant = rewriter.create<mlir::ConstantOp>(
+              condbr.getLoc(), i1Type, rewriter.getBoolAttr(value));
+        rewriter.updateRootInPlace(user, [&] { use.set(constant); });
+        changed = true;
+      }
+      return changed;
+    };
+
+    // Evaluate both sides unconditionally; `||` would skip the false edge.
+    bool replacedTrue = propagate(condbr.getTrueDest(), /*value=*/true);
+    bool replacedFalse = propagate(condbr.getFalseDest(), /*value=*/false);
+    return success(replacedTrue || replacedFalse);
+  }
+};
 } // end anonymous namespace
 
 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
               SimplifyCondBranchIdenticalSuccessors,
-              SimplifyCondBranchFromCondBranchOnSameCondition>(context);
+              SimplifyCondBranchFromCondBranchOnSameCondition,
+              CondBranchTruthPropagation>(context);
 }
 
 Optional<MutableOperandRange>
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -306,3 +306,32 @@
   %1 = select %0, %arg0, %arg1 : i64
   return %1 : i64
 }
+
+// -----
+
+// CHECK-LABEL: @branchCondProp
+// CHECK: %true = constant true
+// CHECK: %false = constant false
+// CHECK: cond_br %arg0, ^bb1, ^bb2
+// CHECK: ^bb1: // pred: ^bb0
+// CHECK: "test.consumer1"(%true) : (i1) -> ()
+// CHECK: br ^bb3
+// CHECK: ^bb2: // pred: ^bb0
+// CHECK: "test.consumer2"(%false) : (i1) -> ()
+// CHECK: br ^bb3
+// CHECK: ^bb3: // 2 preds: ^bb1, ^bb2
+// CHECK: return
+func @branchCondProp(%arg0: i1) {
+  cond_br %arg0, ^trueB, ^falseB
+
+^trueB:
+  "test.consumer1"(%arg0) : (i1) -> ()
+  br ^exit
+
+^falseB:
+  "test.consumer2"(%arg0) : (i1) -> ()
+  br ^exit
+
+^exit:
+  return
+}