Canonicalize cond_br with identical successors (review notes; this preamble is not part of the patch content).

Summary of the changes below, as shown by the hunks:
  * llvm/ADT/STLExtras.h: the range comparison operator== becomes a const
    member function, and a matching operator!= (defined as
    !(*this == other)) is added.
  * StandardOps Ops.td: SelectOp's true_value, false_value and result are
    relaxed from SignlessIntegerOrFloatLike to AnyType.
  * mlir/IR/Block.h, mlir/lib/IR/Block.cpp: new Block::getUniquePredecessor().
    It returns the predecessor when every incoming edge comes from the same
    block (repeated edges from that block are allowed). It returns null when
    there are no predecessors or when two distinct predecessors are seen.
  * StandardOps Ops.cpp:
    - getCheckedI1SameShape is folded into getI1SameShape. The new version
      returns a tensor/vector of i1 for shaped types and plain i1 for every
      other type; the assert on unsupported types is removed.
    - A new pattern, SimplifyCondBranchIdenticalSuccessors, fires when
      trueDest == falseDest. If both operand lists are equal, the cond_br
      becomes a br. Otherwise the rewrite is attempted only when the
      branch's own block is the destination's unique predecessor, and it is
      skipped for the operand types named in the TODO (tensor/vector). In
      that case, operands that differ are merged with select, and a br
      carrying the merged operands replaces the cond_br.
    - The pattern is registered in
      CondBranchOp::getCanonicalizationPatterns.
  * Tests: canonicalize-cf.mlir gains cases for identical successors with
    equal operands, differing operands (select inserted), and tensor/vector
    operands (no select). The cond_br_pass_through expectations now check
    for selects. invalid_select_shape is removed from invalid-ops.mlir,
    since its expected "'result' must be signless-integer-like or
    floating-point-like" diagnostic no longer applies once select uses
    AnyType.

NOTE(review): in this copy the line breaks have been collapsed and the
contents of template argument lists appear to be stripped (for example
"template bool operator==", "public OpRewritePattern {",
"SmallVector mergedOperands"). The hunks below presumably will not apply
as-is; regenerate the patch from the source tree before applying.
NOTE(review): operator== compares size() (size_t) with std::distance (a
signed difference_type), which may trigger -Wsign-compare. The C++14
overload std::equal(begin(), end(), other.begin(), other.end()) would
fold the length check into the element comparison.
NOTE(review): the lambda name doesntSupportsScalarI1 reads awkwardly;
doesNotSupportScalarI1 would be clearer.

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -1123,10 +1123,13 @@ } /// Compare this range with another. - template bool operator==(const OtherT &other) { + template bool operator==(const OtherT &other) const { return size() == std::distance(other.begin(), other.end()) && std::equal(begin(), end(), other.begin()); } + template bool operator!=(const OtherT &other) const { + return !(*this == other); + } /// Return the size of this range. size_t size() const { return count; } diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1951,9 +1951,9 @@ }]; let arguments = (ins BoolLike:$condition, - SignlessIntegerOrFloatLike:$true_value, - SignlessIntegerOrFloatLike:$false_value); - let results = (outs SignlessIntegerOrFloatLike:$result); + AnyType:$true_value, + AnyType:$false_value); + let results = (outs AnyType:$result); let verifier = ?; let builders = [OpBuilder< diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -248,6 +248,10 @@ /// destinations) is not considered to be a single predecessor. Block *getSinglePredecessor(); + /// If this block has a unique predecessor, i.e., all incoming edges originate + /// from one block, return it. Otherwise, return null. + Block *getUniquePredecessor(); + // Indexed successor access. 
unsigned getNumSuccessors(); Block *getSuccessor(unsigned i); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -684,23 +684,15 @@ //===----------------------------------------------------------------------===// // Return the type of the same shape (scalar, vector or tensor) containing i1. -static Type getCheckedI1SameShape(Type type) { +static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(1, type.getContext()); - if (type.isSignlessIntOrIndexOrFloat()) - return i1Type; if (auto tensorType = type.dyn_cast()) return RankedTensorType::get(tensorType.getShape(), i1Type); if (type.isa()) return UnrankedTensorType::get(i1Type); if (auto vectorType = type.dyn_cast()) return VectorType::get(vectorType.getShape(), i1Type); - return Type(); -} - -static Type getI1SameShape(Type type) { - Type res = getCheckedI1SameShape(type); - assert(res && "expected type with valid i1 shape"); - return res; + return i1Type; } //===----------------------------------------------------------------------===// @@ -840,8 +832,10 @@ //===----------------------------------------------------------------------===// namespace { -/// cond_br true, ^bb1, ^bb2 -> br ^bb1 -/// cond_br false, ^bb1, ^bb2 -> br ^bb2 +/// cond_br true, ^bb1, ^bb2 +/// -> br ^bb1 +/// cond_br false, ^bb1, ^bb2 +/// -> br ^bb2 /// struct SimplifyConstCondBranchPred : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -869,7 +863,7 @@ /// ^bb2 /// br ^bbK(...) /// -/// cond_br %cond, ^bbN(...), ^bbK(...) +/// -> cond_br %cond, ^bbN(...), ^bbK(...) 
/// struct SimplifyPassThroughCondBranch : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -943,12 +937,70 @@ return success(); } }; + +/// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) +/// -> br ^bb1(A, ..., N) +/// +/// cond_br %cond, ^bb1(A), ^bb1(B) +/// -> %select = select %cond, A, B +/// br ^bb1(%select) +/// +struct SimplifyCondBranchIdenticalSuccessors + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CondBranchOp condbr, + PatternRewriter &rewriter) const override { + // Check that the true and false destinations are the same and have the same + // operands. + Block *trueDest = condbr.trueDest(); + if (trueDest != condbr.falseDest()) + return failure(); + + // If all of the operands match, no selects need to be generated. + OperandRange trueOperands = condbr.getTrueOperands(); + OperandRange falseOperands = condbr.getFalseOperands(); + if (trueOperands == falseOperands) { + rewriter.replaceOpWithNewOp(condbr, trueDest, trueOperands); + return success(); + } + + // Otherwise, if the current block is the only predecessor insert selects + // for any mismatched branch operands. + if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock()) + return failure(); + + // TODO: ATM Tensor/Vector SelectOp requires that the condition has the same + // shape as the operands. We should relax that to allow an i1 to signify + // that everything is selected. + auto doesntSupportsScalarI1 = [](Type type) { + return type.isa() || type.isa(); + }; + if (llvm::any_of(trueOperands.getTypes(), doesntSupportsScalarI1)) + return failure(); + + // Generate a select for any operands that differ between the two. 
+ SmallVector mergedOperands; + mergedOperands.reserve(trueOperands.size()); + Value condition = condbr.getCondition(); + for (auto it : llvm::zip(trueOperands, falseOperands)) { + if (std::get<0>(it) == std::get<1>(it)) + mergedOperands.push_back(std::get<0>(it)); + else + mergedOperands.push_back(rewriter.create( + condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it))); + } + + rewriter.replaceOpWithNewOp(condbr, trueDest, mergedOperands); + return success(); + } +}; } // end anonymous namespace void CondBranchOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.insert( - context); + results.insert(context); } Optional CondBranchOp::getSuccessorOperands(unsigned index) { diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -229,6 +229,21 @@ return it == pred_end() ? firstPred : nullptr; } +/// If this block has a unique predecessor, i.e., all incoming edges originate +/// from one block, return it. Otherwise, return null. +Block *Block::getUniquePredecessor() { + auto it = pred_begin(), e = pred_end(); + if (it == e) + return nullptr; + + // Check for any conflicting predecessors. + auto *firstPred = *it; + for (++it; it != e; ++it) + if (*it != firstPred) + return nullptr; + return firstPred; +} + //===----------------------------------------------------------------------===// // Other //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir --- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir +++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s -// Test the folding of BranchOp. +/// Test the folding of BranchOp. 
// CHECK-LABEL: func @br_folding( func @br_folding() -> i32 { @@ -12,11 +12,11 @@ return %x : i32 } -// Test the folding of CondBranchOp with a constant condition. +/// Test the folding of CondBranchOp with a constant condition. // CHECK-LABEL: func @cond_br_folding( func @cond_br_folding(%cond : i1, %a : i32) { - // CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb1 + // CHECK-NEXT: return %false_cond = constant 0 : i1 %true_cond = constant 1 : i1 @@ -29,13 +29,62 @@ cond_br %false_cond, ^bb2(%x : i32), ^bb3 ^bb3: - // CHECK: ^bb1: + return +} + +/// Test the folding of CondBranchOp when the successors are identical. + +// CHECK-LABEL: func @cond_br_same_successor( +func @cond_br_same_successor(%cond : i1, %a : i32) { // CHECK-NEXT: return + cond_br %cond, ^bb1(%a : i32), ^bb1(%a : i32) + +^bb1(%result : i32): return } -// Test the compound folding of BranchOp and CondBranchOp. +/// Test the folding of CondBranchOp when the successors are identical, but the +/// arguments are different. + +// CHECK-LABEL: func @cond_br_same_successor_insert_select( +// CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 +func @cond_br_same_successor_insert_select(%cond : i1, %a : i32, %b : i32) -> i32 { + // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]] + // CHECK: return %[[RES]] + + cond_br %cond, ^bb1(%a : i32), ^bb1(%b : i32) + +^bb1(%result : i32): + return %result : i32 +} + +/// Check that we don't generate a select if the type requires a splat. +/// TODO: SelectOp should allow for matching a vector/tensor with i1. 
+ +// CHECK-LABEL: func @cond_br_same_successor_no_select_tensor( +func @cond_br_same_successor_no_select_tensor(%cond : i1, %a : tensor<2xi32>, + %b : tensor<2xi32>) -> tensor<2xi32>{ + // CHECK: cond_br + + cond_br %cond, ^bb1(%a : tensor<2xi32>), ^bb1(%b : tensor<2xi32>) + +^bb1(%result : tensor<2xi32>): + return %result : tensor<2xi32> +} + +// CHECK-LABEL: func @cond_br_same_successor_no_select_vector( +func @cond_br_same_successor_no_select_vector(%cond : i1, %a : vector<2xi32>, + %b : vector<2xi32>) -> vector<2xi32> { + // CHECK: cond_br + + cond_br %cond, ^bb1(%a : vector<2xi32>), ^bb1(%b : vector<2xi32>) + +^bb1(%result : vector<2xi32>): + return %result : vector<2xi32> +} + +/// Test the compound folding of BranchOp and CondBranchOp. // CHECK-LABEL: func @cond_br_and_br_folding( func @cond_br_and_br_folding(%a : i32) { @@ -55,9 +104,11 @@ /// Test that pass-through successors of CondBranchOp get folded. // CHECK-LABEL: func @cond_br_pass_through( -// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32 +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1 func @cond_br_pass_through(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) { - // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG0]], %[[ARG1]] : i32, i32), ^bb1(%[[ARG2]], %[[ARG2]] : i32, i32) + // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG2]] + // CHECK: %[[RES2:.*]] = select %[[COND]], %[[ARG1]], %[[ARG2]] + // CHECK: return %[[RES]], %[[RES2]] cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg2, %arg2 : i32, i32) @@ -65,9 +116,6 @@ br ^bb2(%arg3, %arg1 : i32, i32) ^bb2(%arg4: i32, %arg5: i32): - // CHECK: ^bb1(%[[RET0:.*]]: i32, %[[RET1:.*]]: i32): - // CHECK-NEXT: return %[[RET0]], %[[RET1]] - return %arg4, %arg5 : i32, i32 } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -297,12 +297,6 @@ // ----- -func @invalid_select_shape(%cond : 
i1, %idx : () -> ()) { - // expected-error@+1 {{'result' must be signless-integer-like or floating-point-like, but got '() -> ()'}} - %sel = select %cond, %idx, %idx : () -> () - -// ----- - func @invalid_cmp_shape(%idx : () -> ()) { // expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}} %cmp = cmpi "eq", %idx, %idx : () -> ()