diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -464,10 +464,6 @@ /// A set of operand+index pairs that correspond to operands that need to be /// replaced by arguments when the cluster gets merged. std::set> operandsToMerge; - - /// A map of operations with external uses to a replacement within the leader - /// block. - DenseMap opsToReplace; }; } // end anonymous namespace @@ -480,7 +476,6 @@ // A set of operands that mismatch between the leader and the new block. SmallVector, 8> mismatchedOperands; - SmallVector, 2> newOpsToReplace; auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end(); auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end(); for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) { @@ -519,9 +514,16 @@ return failure(); } - // If the rhs has external uses, it will need to be replaced. - if (rhsIt->isUsedOutsideOfBlock(mergeBlock)) - newOpsToReplace.emplace_back(&*rhsIt, &*lhsIt); + // If the lhs or rhs has external uses, the blocks cannot be merged as the + // merged version of this operation will not be either the lhs or rhs + // alone (thus semantically incorrect), but some mix depending on which + // block preceded this. + // TODO allow merging of operations when one block does not dominate the + // other + if (rhsIt->isUsedOutsideOfBlock(mergeBlock) || + lhsIt->isUsedOutsideOfBlock(leaderBlock)) { + return failure(); + } } // Make sure that the block sizes are equivalent. if (lhsIt != lhsE || rhsIt != rhsE) @@ -529,7 +531,6 @@ // If we get here, the blocks are equivalent and can be merged. 
operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end()); - opsToReplace.insert(newOpsToReplace.begin(), newOpsToReplace.end()); blocksToMerge.insert(blockData.block); return success(); } @@ -561,10 +562,6 @@ !llvm::all_of(blocksToMerge, ableToUpdatePredOperands)) return failure(); - // Replace any necessary operations. - for (std::pair &it : opsToReplace) - it.first->replaceAllUsesWith(it.second); - // Collect the iterators for each of the blocks to merge. We will walk all // of the iterators at once to avoid operand index invalidation. SmallVector blockIterators; diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir --- a/mlir/test/Transforms/canonicalize-block-merge.mlir +++ b/mlir/test/Transforms/canonicalize-block-merge.mlir @@ -174,26 +174,24 @@ return } -// Check that properly handles back edges and the case where a value from one -// block is used in another. +// Check that back edges are properly handled. // CHECK-LABEL: func @mismatch_loop( // CHECK-SAME: %[[ARG:.*]]: i1, %[[ARG2:.*]]: i1 func @mismatch_loop(%cond : i1, %cond2 : i1) { + // CHECK-NEXT: %[[LOOP_CARRY:.*]] = "foo.op" // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG2]] : i1), ^bb2 + %cond3 = "foo.op"() : () -> (i1) cond_br %cond, ^bb2, ^bb3 ^bb1: // CHECK: ^bb1(%[[ARG3:.*]]: i1): - // CHECK-NEXT: %[[LOOP_CARRY:.*]] = "foo.op" // CHECK-NEXT: cond_br %[[ARG3]], ^bb1(%[[LOOP_CARRY]] : i1), ^bb2 - %ignored = "foo.op"() : () -> (i1) cond_br %cond3, ^bb1, ^bb3 ^bb2: - %cond3 = "foo.op"() : () -> (i1) cond_br %cond2, ^bb1, ^bb3 ^bb3: @@ -224,3 +222,32 @@ store %true, %arg2[] : memref br ^bb1 } + +// Check that it is illegal to merge blocks containing an operation +// with an external user. Incorrectly performing the optimization +// anyway will result in print(merged, merged) rather than +// distinct operands. 
+func private @print(%arg0: i32, %arg1: i32) +// CHECK-LABEL: @nomerge +func @nomerge(%arg0: i32, %i: i32) { + %c1_i32 = constant 1 : i32 + %icmp = cmpi "slt", %i, %arg0 : i32 + cond_br %icmp, ^bb2, ^bb3 + +^bb2: // pred: ^bb1 + %ip1 = addi %i, %c1_i32 : i32 + br ^bb4(%ip1 : i32) + +^bb7: // pred: ^bb5 + %jp1 = addi %j, %c1_i32 : i32 + br ^bb4(%jp1 : i32) + +^bb4(%j: i32): // 2 preds: ^bb2, ^bb7 + %jcmp = cmpi "slt", %j, %arg0 : i32 +// CHECK-NOT: call @print(%[[arg1:.+]], %[[arg1]]) + call @print(%j, %ip1) : (i32, i32) -> () + cond_br %jcmp, ^bb7, ^bb3 + +^bb3: // pred: ^bb1 + return +}