diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -1174,6 +1174,9 @@ return RangeT(iterator_range(*this)); } + /// Returns the base of this range. + const BaseT &getBase() const { return base; } + private: /// Offset the given base by the given amount. static BaseT offset_base(const BaseT &base, size_t n) { diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h --- a/mlir/include/mlir/IR/BlockSupport.h +++ b/mlir/include/mlir/IR/BlockSupport.h @@ -79,6 +79,29 @@ namespace llvm { +/// Provide support for hashing successor ranges. +template <> +struct DenseMapInfo { + static mlir::SuccessorRange getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::SuccessorRange(pointer, 0); + } + static mlir::SuccessorRange getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::SuccessorRange(pointer, 0); + } + static unsigned getHashValue(mlir::SuccessorRange value) { + return llvm::hash_combine_range(value.begin(), value.end()); + } + static bool isEqual(mlir::SuccessorRange lhs, mlir::SuccessorRange rhs) { + if (rhs.getBase() == getEmptyKey().getBase()) + return lhs.getBase() == getEmptyKey().getBase(); + if (rhs.getBase() == getTombstoneKey().getBase()) + return lhs.getBase() == getTombstoneKey().getBase(); + return lhs == rhs; + } +}; + //===----------------------------------------------------------------------===// // ilist_traits for Operation //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -554,6 +554,14 @@ [](OpResult result) { return result.use_empty(); }); } + /// Returns true if the results of this operation are used outside of the + /// given block.
+ bool isUsedOutsideOfBlock(Block *block) { + return llvm::any_of(getOpResults(), [block](OpResult result) { + return result.isUsedOutsideOfBlock(block); + }); + } + //===--------------------------------------------------------------------===// // Users //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -20,6 +20,7 @@ #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/Support/PointerLikeTypeTraits.h" #include "llvm/Support/TrailingObjects.h" @@ -617,6 +618,17 @@ ValueTypeIterator>::iterator_range; template ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {} + + /// Compare this range with another. + template + bool operator==(const OtherT &other) const { + return llvm::size(*this) == llvm::size(other) && + std::equal(this->begin(), this->end(), other.begin()); + } + template + bool operator!=(const OtherT &other) const { + return !(*this == other); + } }; template @@ -829,12 +841,29 @@ /// This class provides utilities for computing if two operations are /// equivalent. struct OperationEquivalence { + enum Flags { + None = 0, + + /// This flag signals that operands should not be considered when checking + /// for equivalence. This allows for users to implement their own + /// equivalence schemes for operand values. The number of operands is still + /// checked, just not the operands themselves. + IgnoreOperands = 1, + + LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreOperands) + }; + /// Compute a hash for the given operation.
- static llvm::hash_code computeHash(Operation *op); + static llvm::hash_code computeHash(Operation *op, Flags flags = Flags::None); /// Compare two operations and return if they are equivalent. - static bool isEquivalentTo(Operation *lhs, Operation *rhs); + static bool isEquivalentTo(Operation *lhs, Operation *rhs, + Flags flags = Flags::None); }; + +/// Enable Bitmask enums for OperationEquivalence::Flags. +LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE(); + } // end namespace mlir namespace llvm { diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -123,6 +123,9 @@ /// Return the Region in which this Value is defined. Region *getParentRegion(); + /// Return the Block in which this Value is defined. + Block *getParentBlock(); + //===--------------------------------------------------------------------===// // UseLists //===--------------------------------------------------------------------===// @@ -150,6 +153,9 @@ void replaceUsesWithIf(Value newValue, function_ref shouldReplace); + /// Returns true if the value is used outside of the given block. + bool isUsedOutsideOfBlock(Block *block); + //===--------------------------------------------------------------------===// // Uses diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -412,7 +412,7 @@ // Operation Equivalency //===----------------------------------------------------------------------===// -llvm::hash_code OperationEquivalence::computeHash(Operation *op) { +llvm::hash_code OperationEquivalence::computeHash(Operation *op, Flags flags) { // Hash operations based upon their: // - Operation Name // - Attributes @@ -438,12 +438,17 @@ } // - Operands - // TODO: Allow commutative operations to have different ordering.
- return llvm::hash_combine( - hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end())); + bool ignoreOperands = flags & Flags::IgnoreOperands; + if (!ignoreOperands) { + // TODO: Allow commutative operations to have different ordering. + hash = llvm::hash_combine( + hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end())); + } + return hash; } -bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs) { +bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs, + Flags flags) { if (lhs == rhs) return true; @@ -478,6 +483,9 @@ break; } // Compare operands. + bool ignoreOperands = flags & Flags::IgnoreOperands; + if (ignoreOperands) + return true; // TODO: Allow commutative operations to have different ordering. return std::equal(lhs->operand_begin(), lhs->operand_end(), rhs->operand_begin()); diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -87,6 +87,13 @@ return cast().getOwner()->getParent(); } +/// Return the Block in which this Value is defined. +Block *Value::getParentBlock() { + if (Operation *op = getDefiningOp()) + return op->getBlock(); + return cast().getOwner(); +} + //===----------------------------------------------------------------------===// // Value::UseLists //===----------------------------------------------------------------------===// @@ -134,6 +141,13 @@ use.set(newValue); } +/// Returns true if the value is used outside of the given block.
+bool Value::isUsedOutsideOfBlock(Block *block) { + return llvm::any_of(getUsers(), [block](Operation *user) { + return user->getBlock() != block; + }); +} + //===--------------------------------------------------------------------===// // Uses diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -367,6 +367,324 @@ return deleteDeadness(regions, liveMap); } +//===----------------------------------------------------------------------===// +// Block Merging +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// BlockEquivalenceData + +namespace { +/// This class contains the information for comparing the equivalencies of two +/// blocks. Blocks are considered equivalent if they contain the same operations +/// in the same order. The only allowed divergence is for operands that come +/// from sources outside of the parent block, i.e. the uses of values produced +/// within the block must be equivalent. +/// e.g., +/// Equivalent: +/// ^bb1(%arg0: i32) +/// return %arg0, %foo : i32, i32 +/// ^bb2(%arg1: i32) +/// return %arg1, %bar : i32, i32 +/// Not Equivalent: +/// ^bb1(%arg0: i32) +/// return %foo, %arg0 : i32, i32 +/// ^bb2(%arg1: i32) +/// return %arg1, %bar : i32, i32 +struct BlockEquivalenceData { + BlockEquivalenceData(Block *block); + + /// Return the order index for the given value that is within the block of + /// this data. + unsigned getOrderOf(Value value) const; + + /// The block this data refers to. + Block *block; + /// A hash value for this block. + llvm::hash_code hash; + /// A map of result producing operations to their relative orders within this + /// block. The order of an operation is the number of defined values that are + /// produced within the block before this operation.
+ DenseMap opOrderIndex; +}; +} // end anonymous namespace + +BlockEquivalenceData::BlockEquivalenceData(Block *block) + : block(block), hash(0) { + unsigned orderIt = block->getNumArguments(); + for (Operation &op : *block) { + if (unsigned numResults = op.getNumResults()) { + opOrderIndex.try_emplace(&op, orderIt); + orderIt += numResults; + } + auto opHash = OperationEquivalence::computeHash( + &op, OperationEquivalence::Flags::IgnoreOperands); + hash = llvm::hash_combine(hash, opHash); + } +} + +unsigned BlockEquivalenceData::getOrderOf(Value value) const { + assert(value.getParentBlock() == block && "expected value of this block"); + + // Arguments use the argument number as the order index. + if (BlockArgument arg = value.dyn_cast()) + return arg.getArgNumber(); + + // Otherwise, the result order is offset from the parent op's order. + OpResult result = value.cast(); + auto opOrderIt = opOrderIndex.find(result.getDefiningOp()); + assert(opOrderIt != opOrderIndex.end() && "expected op to have an order"); + return opOrderIt->second + result.getResultNumber(); +} + +//===----------------------------------------------------------------------===// +// BlockMergeCluster + +namespace { +/// This class represents a cluster of blocks to be merged together. +class BlockMergeCluster { +public: + BlockMergeCluster(BlockEquivalenceData &&leaderData) + : leaderData(std::move(leaderData)) {} + + /// Attempt to add the given block to this cluster. Returns success if the + /// block was added to the cluster, failure otherwise. + LogicalResult addToCluster(BlockEquivalenceData &blockData); + + /// Try to merge all of the blocks within this cluster into the leader block. + LogicalResult merge(); + +private: + /// The equivalence data for the leader of the cluster. + BlockEquivalenceData leaderData; + + /// The set of blocks that can be merged into the leader.
+ llvm::SmallSetVector blocksToMerge; + + /// A set of operand+index pairs that correspond to operands that need to be + /// replaced by arguments when the cluster gets merged. + std::set> operandsToMerge; + + /// A map of operations with external uses to a replacement within the leader + /// block. + DenseMap opsToReplace; +}; +} // end anonymous namespace + +LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) { + if (leaderData.hash != blockData.hash) + return failure(); + Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block; + if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes()) + return failure(); + + // A set of operands that mismatch between the leader and the new block. + SmallVector, 8> mismatchedOperands; + SmallVector, 2> newOpsToReplace; + auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end(); + auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end(); + for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) { + // Check that the operations are equivalent. + if (!OperationEquivalence::isEquivalentTo( + &*lhsIt, &*rhsIt, OperationEquivalence::Flags::IgnoreOperands)) + return failure(); + + // Compare the operands of the two operations. If the operand is within + // the block, it must refer to the same operation. + auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands(); + for (int operand : llvm::seq(0, lhsIt->getNumOperands())) { + Value lhsOperand = lhsOperands[operand]; + Value rhsOperand = rhsOperands[operand]; + if (lhsOperand == rhsOperand) + continue; + + // Check that these uses are both external, or both internal. + bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock; + bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock; + if (lhsIsInBlock != rhsIsInBlock) + return failure(); + // Let the operands differ if they are defined in a different block. These + // will become new arguments if the blocks get merged.
+ if (!lhsIsInBlock) { + mismatchedOperands.emplace_back(opI, operand); + continue; + } + + // Otherwise, these operands must have the same logical order within the + // parent block. + if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand)) + return failure(); + } + + // If the rhs has external uses, it will need to be replaced. + if (rhsIt->isUsedOutsideOfBlock(mergeBlock)) + newOpsToReplace.emplace_back(&*rhsIt, &*lhsIt); + } + // Make sure that the block sizes are equivalent. + if (lhsIt != lhsE || rhsIt != rhsE) + return failure(); + + // If we get here, the blocks are equivalent and can be merged. + operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end()); + opsToReplace.insert(newOpsToReplace.begin(), newOpsToReplace.end()); + blocksToMerge.insert(blockData.block); + return success(); +} + +/// Returns true if the predecessor terminators of the given block can have +/// their operands updated. +static bool ableToUpdatePredOperands(Block *block) { + for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { + auto branch = dyn_cast((*it)->getTerminator()); + if (!branch || !branch.getMutableSuccessorOperands(it.getSuccessorIndex())) + return false; + } + return true; +} + +LogicalResult BlockMergeCluster::merge() { + // Don't consider clusters that don't have blocks to merge. + if (blocksToMerge.empty()) + return failure(); + + Block *leaderBlock = leaderData.block; + if (!operandsToMerge.empty()) { + // If the cluster has operands to merge, verify that the predecessor + // terminators of each of the blocks can have their successor operands + // updated. + // TODO: We could try and sub-partition this cluster if only some blocks + // cause the mismatch. + if (!ableToUpdatePredOperands(leaderBlock) || + !llvm::all_of(blocksToMerge, ableToUpdatePredOperands)) + return failure(); +
+ // Collect the iterators for each of the blocks to merge. We will walk all + // of the iterators at once to avoid operand index invalidation. + SmallVector blockIterators; + blockIterators.reserve(blocksToMerge.size() + 1); + blockIterators.push_back(leaderBlock->begin()); + for (Block *mergeBlock : blocksToMerge) + blockIterators.push_back(mergeBlock->begin()); + + // Update each of the predecessor terminators with the new arguments. + SmallVector, 2> newArguments( + 1 + blocksToMerge.size(), + SmallVector(operandsToMerge.size())); + unsigned curOpIndex = 0; + for (auto it : llvm::enumerate(operandsToMerge)) { + unsigned nextOpOffset = it.value().first - curOpIndex; + curOpIndex = it.value().first; + + // Process the operand for each of the block iterators. + for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) { + Block::iterator &blockIter = blockIterators[i]; + std::advance(blockIter, nextOpOffset); + auto &operand = blockIter->getOpOperand(it.value().second); + newArguments[i][it.index()] = operand.get(); + + // Update the operand and insert an argument if this is the leader. + if (i == 0) + operand.set(leaderBlock->addArgument(operand.get().getType())); + } + } + // Update the predecessors for each of the blocks. + auto updatePredecessors = [&](Block *block, unsigned clusterIndex) { + for (auto predIt = block->pred_begin(), predE = block->pred_end(); + predIt != predE; ++predIt) { + auto branch = cast((*predIt)->getTerminator()); + unsigned succIndex = predIt.getSuccessorIndex(); + branch.getMutableSuccessorOperands(succIndex)->append( + newArguments[clusterIndex]); + } + }; + updatePredecessors(leaderBlock, /*clusterIndex=*/0); + for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i) + updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1); + } + + // Remap external uses of merged ops; needed even if no operands differ. + for (std::pair &it : opsToReplace) + it.first->replaceAllUsesWith(it.second); + + // Replace all uses of the merged blocks with the leader and erase them.
+ for (Block *block : blocksToMerge) { + block->replaceAllUsesWith(leaderBlock); + block->erase(); + } + return success(); +} + +/// Identify identical blocks within the given region and merge them, inserting +/// new block arguments as necessary. Returns success if any blocks were merged, +/// failure otherwise. +static LogicalResult mergeIdenticalBlocks(Region &region) { + if (region.empty() || llvm::hasSingleElement(region)) + return failure(); + + // Identify sets of blocks, other than the entry block, that branch to the + // same successors. We will use these groups to create clusters of equivalent + // blocks. + DenseMap> matchingSuccessors; + for (Block &block : llvm::drop_begin(region, 1)) + matchingSuccessors[block.getSuccessors()].push_back(&block); + + bool mergedAnyBlocks = false; + for (ArrayRef blocks : llvm::make_second_range(matchingSuccessors)) { + if (blocks.size() == 1) + continue; + + SmallVector clusters; + for (Block *block : blocks) { + BlockEquivalenceData data(block); + + // Don't allow merging if this block has any regions. + // TODO: Add support for regions if necessary. + bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) { + return llvm::any_of(op.getRegions(), + [](Region &region) { return !region.empty(); }); + }); + if (hasNonEmptyRegion) + continue; + + // Try to add this block to an existing cluster. + bool addedToCluster = false; + for (auto &cluster : clusters) + if ((addedToCluster = succeeded(cluster.addToCluster(data)))) + break; + if (!addedToCluster) + clusters.emplace_back(std::move(data)); + } + for (auto &cluster : clusters) + mergedAnyBlocks |= succeeded(cluster.merge()); + } + + return success(mergedAnyBlocks); +} + +/// Identify identical blocks within the given regions and merge them, inserting +/// new block arguments as necessary.
+static LogicalResult mergeIdenticalBlocks(MutableArrayRef regions) { + llvm::SmallSetVector worklist; + for (auto &region : regions) + worklist.insert(&region); + bool anyChanged = false; + while (!worklist.empty()) { + Region *region = worklist.pop_back_val(); + if (succeeded(mergeIdenticalBlocks(*region))) { + worklist.insert(region); + anyChanged = true; + } + + // Add any nested regions to the worklist. + for (Block &block : *region) + for (auto &op : block) + for (auto &nestedRegion : op.getRegions()) + worklist.insert(&nestedRegion); + } + + return success(anyChanged); +} + //===----------------------------------------------------------------------===// // Region Simplification //===----------------------------------------------------------------------===// @@ -376,7 +694,9 @@ /// elimination, as well as some other DCE. This function returns success if any /// of the regions were simplified, failure otherwise. LogicalResult mlir::simplifyRegions(MutableArrayRef regions) { - LogicalResult eliminatedBlocks = eraseUnreachableBlocks(regions); - LogicalResult eliminatedOpsOrArgs = runRegionDCE(regions); - return success(succeeded(eliminatedBlocks) || succeeded(eliminatedOpsOrArgs)); + bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(regions)); + bool eliminatedOpsOrArgs = succeeded(runRegionDCE(regions)); + bool mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(regions)); + return success(eliminatedBlocks || eliminatedOpsOrArgs || + mergedIdenticalBlocks); } diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir --- a/mlir/test/Dialect/SPIRV/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir @@ -559,15 +559,18 @@ // CHECK: spv.selection { spv.selection { + // CHECK: spv.BranchConditional + // CHECK-SAME: ^bb1(%[[DST_VAR_0]], %[[SRC_VALUE_0]] + // CHECK-SAME: ^bb1(%[[DST_VAR_1]], %[[SRC_VALUE_1]] spv.BranchConditional %cond, ^then, ^else ^then: - // CHECK: spv.Store "Function" %[[DST_VAR_0]],
%[[SRC_VALUE_0]] ["Aligned", 8] : vector<3xi32> + // CHECK: ^bb1(%[[ARG0:.*]]: !spv.ptr, Function>, %[[ARG1:.*]]: vector<3xi32>): + // CHECK: spv.Store "Function" %[[ARG0]], %[[ARG1]] ["Aligned", 8] : vector<3xi32> spv.Store "Function" %3, %1 ["Aligned", 8]: vector<3xi32> spv.Branch ^merge ^else: - // CHECK: spv.Store "Function" %[[DST_VAR_1]], %[[SRC_VALUE_1]] ["Aligned", 8] : vector<3xi32> spv.Store "Function" %4, %2 ["Aligned", 8] : vector<3xi32> spv.Branch ^merge diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/canonicalize-block-merge.mlir @@ -0,0 +1,232 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s + +// Check the simple case of single operation blocks with a return. + +// CHECK-LABEL: func @return_blocks( +func @return_blocks() { + // CHECK: "foo.cond_br"()[^bb1, ^bb1] + // CHECK: ^bb1: + // CHECK-NEXT: return + // CHECK-NOT: ^bb2 + + "foo.cond_br"() [^bb1, ^bb2] : () -> () + +^bb1: + return +^bb2: + return +} + +// Check the case of identical blocks with matching arguments. + +// CHECK-LABEL: func @matching_arguments( +func @matching_arguments() -> i32 { + // CHECK: "foo.cond_br"()[^bb1, ^bb1] + // CHECK: ^bb1(%{{.*}}: i32): + // CHECK-NEXT: return + // CHECK-NOT: ^bb2 + + "foo.cond_br"() [^bb1, ^bb2] : () -> () + +^bb1(%arg0 : i32): + return %arg0 : i32 +^bb2(%arg1 : i32): + return %arg1 : i32 +} + +// Check that no merging occurs if there is an operand mismatch and we can't +// update the predecessor.
+ +// CHECK-LABEL: func @mismatch_unknown_terminator +func @mismatch_unknown_terminator(%arg0 : i32, %arg1 : i32) -> i32 { + // CHECK: "foo.cond_br"()[^bb1, ^bb2] + + "foo.cond_br"() [^bb1, ^bb2] : () -> () + +^bb1: + return %arg0 : i32 +^bb2: + return %arg1 : i32 +} + +// Check that merging does occur if there is an operand mismatch and we can +// update the predecessor. + +// CHECK-LABEL: func @mismatch_operands +// CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 +func @mismatch_operands(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 { + // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]] + // CHECK: return %[[RES]] + + cond_br %cond, ^bb1, ^bb2 + +^bb1: + return %arg0 : i32 +^bb2: + return %arg1 : i32 +} + +// Check the same as above, but with pre-existing arguments. + +// CHECK-LABEL: func @mismatch_operands_matching_arguments( +// CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 +func @mismatch_operands_matching_arguments(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32, i32) { + // CHECK: %[[RES0:.*]] = select %[[COND]], %[[ARG1]], %[[ARG0]] + // CHECK: %[[RES1:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]] + // CHECK: return %[[RES1]], %[[RES0]] + + cond_br %cond, ^bb1(%arg1 : i32), ^bb2(%arg0 : i32) + +^bb1(%arg2 : i32): + return %arg0, %arg2 : i32, i32 +^bb2(%arg3 : i32): + return %arg1, %arg3 : i32, i32 +} + +// Check that merging does not occur if the uses of the arguments differ. + +// CHECK-LABEL: func @mismatch_argument_uses( +func @mismatch_argument_uses(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32, i32) { + // CHECK: cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2 + + cond_br %cond, ^bb1(%arg1 : i32), ^bb2(%arg0 : i32) + +^bb1(%arg2 : i32): + return %arg0, %arg2 : i32, i32 +^bb2(%arg3 : i32): + return %arg3, %arg1 : i32, i32 +} + +// Check that merging does not occur if the types of the arguments differ.
+ +// CHECK-LABEL: func @mismatch_argument_types( +func @mismatch_argument_types(%cond : i1, %arg0 : i32, %arg1 : i16) { + // CHECK: cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2 + + cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg1 : i16) + +^bb1(%arg2 : i32): + "foo.return"(%arg2) : (i32) -> () +^bb2(%arg3 : i16): + "foo.return"(%arg3) : (i16) -> () +} + +// Check that merging does not occur if the number of arguments differs. + +// CHECK-LABEL: func @mismatch_argument_count( +func @mismatch_argument_count(%cond : i1, %arg0 : i32) { + // CHECK: cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2 + + cond_br %cond, ^bb1(%arg0 : i32), ^bb2 + +^bb1(%arg2 : i32): + "foo.return"(%arg2) : (i32) -> () +^bb2: + "foo.return"() : () -> () +} + +// Check that merging does not occur if the operations differ. + +// CHECK-LABEL: func @mismatch_operations( +func @mismatch_operations(%cond : i1) { + // CHECK: cond_br %{{.*}}, ^bb1, ^bb2 + + cond_br %cond, ^bb1, ^bb2 + +^bb1: + "foo.return"() : () -> () +^bb2: + return +} + +// Check that merging does not occur if the number of operations differs. + +// CHECK-LABEL: func @mismatch_operation_count( +func @mismatch_operation_count(%cond : i1) { + // CHECK: cond_br %{{.*}}, ^bb1, ^bb2 + + cond_br %cond, ^bb1, ^bb2 + +^bb1: + "foo.op"() : () -> () + return +^bb2: + return +} + +// Check that merging does not occur if the blocks contain regions. + +// CHECK-LABEL: func @contains_regions( +func @contains_regions(%cond : i1) { + // CHECK: cond_br %{{.*}}, ^bb1, ^bb2 + + cond_br %cond, ^bb1, ^bb2 + +^bb1: + loop.if %cond { + "foo.op"() : () -> () + } + return +^bb2: + loop.if %cond { + "foo.op"() : () -> () + } + return +} + +// Check that this properly handles back edges and the case where a value from +// one block is used in another.
+ +// CHECK-LABEL: func @mismatch_loop( +// CHECK-SAME: %[[ARG:.*]]: i1 +func @mismatch_loop(%cond : i1) { + // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG]] : i1), ^bb2 + + cond_br %cond, ^bb2, ^bb3 + +^bb1: + // CHECK: ^bb1(%[[ARG2:.*]]: i1): + // CHECK-NEXT: %[[LOOP_CARRY:.*]] = "foo.op" + // CHECK-NEXT: cond_br %[[ARG2]], ^bb1(%[[LOOP_CARRY]] : i1), ^bb2 + + %ignored = "foo.op"() : () -> (i1) + cond_br %cond2, ^bb1, ^bb3 + +^bb2: + %cond2 = "foo.op"() : () -> (i1) + cond_br %cond, ^bb1, ^bb3 + +^bb3: + // CHECK: ^bb2: + // CHECK-NEXT: return + + return +} + +// Check that an operation with uses outside of its block is remapped to the +// leader even when no operands need to be turned into new block arguments. + +// CHECK-LABEL: func @external_use_no_operand_mismatch( +func @external_use_no_operand_mismatch(%cond : i1) { + // CHECK: cond_br %{{.*}}, ^bb1, ^bb2 + + cond_br %cond, ^bb2, ^bb3 + +^bb1: + // CHECK: ^bb1: + // CHECK-NEXT: %[[RES:.*]] = "foo.op" + // CHECK-NEXT: cond_br %[[RES]], ^bb1, ^bb2 + + %ignored = "foo.op"() : () -> (i1) + cond_br %cond2, ^bb1, ^bb3 + +^bb2: + %cond2 = "foo.op"() : () -> (i1) + cond_br %cond2, ^bb1, ^bb3 + +^bb3: + // CHECK: ^bb2: + // CHECK-NEXT: return + + return +} diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir --- a/mlir/test/Transforms/canonicalize-dce.mlir +++ b/mlir/test/Transforms/canonicalize-dce.mlir @@ -62,10 +62,6 @@ // Test case: Delete block arguments for cond_br. // CHECK: func @f(%arg0: f32, %arg1: i1) -// CHECK-NEXT: cond_br %arg1, ^bb1, ^bb2 -// CHECK-NEXT: ^bb1: -// CHECK-NEXT: return -// CHECK-NEXT: ^bb2: // CHECK-NEXT: return func @f(%arg0: f32, %pred: i1) { diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s -dump-input-on-failure // CHECK-LABEL: func @test_subi_zero func @test_subi_zero(%arg0: i32) -> i32 { @@ -361,19 +361,15 @@ // CHECK-LABEL: func @dead_dealloc_fold_multi_use func @dead_dealloc_fold_multi_use(%cond : i1) { - // CHECK-NEXT: cond_br + // CHECK-NEXT: return %a = alloc() : memref<4xf32> cond_br %cond, ^bb1, ^bb2 - // CHECK-LABEL: bb1: ^bb1: - // CHECK-NEXT: return dealloc %a: memref<4xf32> return - // CHECK-LABEL: bb2: ^bb2: - // CHECK-NEXT: return dealloc %a: memref<4xf32> return }