diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -930,13 +930,87 @@ return success(); } }; + +/// cond_br %arg0, ^trueB, ^falseB +/// +/// ^trueB: +/// "test.consumer1"(%arg0) : (i1) -> () +/// ... +/// +/// ^falseB: +/// "test.consumer2"(%arg0) : (i1) -> () +/// ... +/// +/// -> +/// +/// cond_br %arg0, ^trueB, ^falseB +/// ^trueB: +/// "test.consumer1"(%true) : (i1) -> () +/// ... +/// +/// ^falseB: +/// "test.consumer2"(%false) : (i1) -> () +/// ... +struct CondBranchTruthPropagation : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CondBranchOp condbr, + PatternRewriter &rewriter) const override { + // Set to true once any use of the condition has been rewritten. + bool replaced = false; + Type ty = rewriter.getI1Type(); + + // These variables cache the constant true/false values so that each + // one is created at most once, and only when it is needed. + Value constantTrue = nullptr; + Value constantFalse = nullptr; + + // TODO: These checks can be expanded to encompass any use with only + // either the true or false edge as a predecessor. For now, we fall + // back to checking the single predecessor is given by the true/false + // destination, thereby ensuring that only that edge can reach the + // op. 
+ if (condbr.getTrueDest()->getSinglePredecessor()) { + for (OpOperand &use : + llvm::make_early_inc_range(condbr.condition().getUses())) { + if (use.getOwner()->getBlock() == condbr.getTrueDest()) { + replaced = true; + + if (!constantTrue) + constantTrue = rewriter.create( + condbr.getLoc(), ty, rewriter.getBoolAttr(true)); + + rewriter.updateRootInPlace(use.getOwner(), + [&] { use.set(constantTrue); }); + } + } + } + if (condbr.getFalseDest()->getSinglePredecessor()) { + for (OpOperand &use : + llvm::make_early_inc_range(condbr.condition().getUses())) { + if (use.getOwner()->getBlock() == condbr.getFalseDest()) { + replaced = true; + + if (!constantFalse) + constantFalse = rewriter.create( + condbr.getLoc(), ty, rewriter.getBoolAttr(false)); + + rewriter.updateRootInPlace(use.getOwner(), + [&] { use.set(constantFalse); }); + } + } + } + return success(replaced); + } +}; } // end anonymous namespace void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); + SimplifyCondBranchFromCondBranchOnSameCondition, + CondBranchTruthPropagation>(context); } Optional diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -399,3 +399,25 @@ %1 = select %0, %arg0, %arg1 : i64 return %1 : i64 } + +// ----- + +// CHECK-LABEL: @branchCondProp +// CHECK: %[[trueval:.+]] = constant true +// CHECK: %[[falseval:.+]] = constant false +// CHECK: "test.consumer1"(%[[trueval]]) : (i1) -> () +// CHECK: "test.consumer2"(%[[falseval]]) : (i1) -> () +func @branchCondProp(%arg0: i1) { + cond_br %arg0, ^trueB, ^falseB + +^trueB: + "test.consumer1"(%arg0) : (i1) -> () + br ^exit + +^falseB: + "test.consumer2"(%arg0) : (i1) -> () + br ^exit + +^exit: + return +}