diff --git a/mlir/include/mlir/IR/Dominance.h b/mlir/include/mlir/IR/Dominance.h
--- a/mlir/include/mlir/IR/Dominance.h
+++ b/mlir/include/mlir/IR/Dominance.h
@@ -19,6 +19,12 @@
 using DominanceInfoNode = llvm::DomTreeNodeBase<Block>;
 class Operation;
 
+/// Return true if the region with the given index inside the operation
+/// has SSA dominance. Unregistered operations are treated as having no SSA
+/// dominance in any region; registered operations that do not implement
+/// RegionKindInterface are treated as having SSA dominance in every region.
+bool hasSSADominance(Operation *op, unsigned index);
+
 namespace detail {
 template <bool IsPostDom> class DominanceInfoBase {
   using base = llvm::DominatorTreeBase<Block, IsPostDom>;
diff --git a/mlir/lib/IR/Dominance.cpp b/mlir/lib/IR/Dominance.cpp
--- a/mlir/lib/IR/Dominance.cpp
+++ b/mlir/lib/IR/Dominance.cpp
@@ -24,9 +24,7 @@
 template class llvm::DominatorTreeBase<Block, /*IsPostDom=*/true>;
 template class llvm::DomTreeNodeBase<Block>;
 
-/// Return true if the region with the given index inside the operation
-/// has SSA dominance.
-static bool hasSSADominance(Operation *op, unsigned index) {
+bool mlir::hasSSADominance(Operation *op, unsigned index) {
   auto kindInterface = dyn_cast<RegionKindInterface>(op);
   return op->isRegistered() &&
          (!kindInterface || kindInterface.hasSSADominance(index));
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -113,10 +113,28 @@
 
   // Look for an existing definition for the operation.
   if (auto *existing = knownValues.lookup(op)) {
     // If we find one then replace all uses of the current operation with the
-    // existing one and mark it for deletion.
-    op->replaceAllUsesWith(existing);
-    opsToErase.push_back(op);
+    // existing one and mark it for deletion. We can only replace an operand in
+    // an operation if it has not been visited yet.
+    if (hasSSADominance(op->getParentOp(),
+                        op->getParentRegion()->getRegionNumber())) {
+      // If the region has SSA dominance, then we are guaranteed to have not
+      // visited any use of the current operation.
+      op->replaceAllUsesWith(existing);
+      opsToErase.push_back(op);
+    } else {
+      // When the region does not have SSA dominance, only replace the uses
+      // whose owner has not been visited yet (i.e. is not in `knownValues`).
+      for (unsigned i = 0, e = existing->getNumResults(); i != e; ++i) {
+        auto newResult = existing->getResult(i);
+        op->getResult(i).replaceUsesWithIf(newResult, [&](OpOperand &operand) {
+          return !knownValues.count(operand.getOwner());
+        });
+      }
+      // Keep the operation if an already-visited user still refers to it.
+      if (op->use_empty())
+        opsToErase.push_back(op);
+    }
 
     // If the existing operation has an unknown location and the current
     // operation doesn't, then set the existing op's location to that of the
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -244,3 +244,24 @@
 
   return %0 : i32
 }
+
+/// This test is checking that CSE gracefully handles values in graph regions
+/// where the use occurs before the def, and one of the defs could be CSE'd with
+/// the other.
+// CHECK-LABEL: @use_before_def
+func @use_before_def() {
+  // CHECK-NEXT: test.graph_region
+  test.graph_region {
+    // CHECK-NEXT: addi %c1_i32, %c1_i32_0
+    %0 = addi %1, %2 : i32
+
+    // CHECK-NEXT: constant 1
+    // CHECK-NEXT: constant 1
+    %1 = constant 1 : i32
+    %2 = constant 1 : i32
+
+    // CHECK-NEXT: "foo.yield"(%0) : (i32) -> ()
+    "foo.yield"(%0) : (i32) -> ()
+  }
+  return
+}