diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -72,10 +72,10 @@ /// Attempt to eliminate a redundant operation. Returns success if the /// operation was marked for removal, failure otherwise. - LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op); - + LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op, + bool hasSSADominance); void simplifyBlock(ScopedMapTy &knownValues, DominanceInfo &domInfo, - Block *bb); + Block *bb, bool hasSSADominance); void simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo, Region &region); @@ -88,7 +88,8 @@ } // end anonymous namespace /// Attempt to eliminate a redundant operation. -LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op) { +LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op, + bool hasSSADominance) { // Don't simplify terminator operations. if (op->hasTrait<OpTrait::IsTerminator>()) return failure(); @@ -113,10 +114,28 @@ // Look for an existing definition for the operation. if (auto *existing = knownValues.lookup(op)) { + // If we find one then replace all uses of the current operation with the - // existing one and mark it for deletion. - op->replaceAllUsesWith(existing); - opsToErase.push_back(op); + // existing one and mark it for deletion. We can only replace an operand in + // an operation if it has not been visited yet. + if (hasSSADominance) { + // If the region has SSA dominance, then we are guaranteed to have not + // visited any use of the current operation. + op->replaceAllUsesWith(existing); + opsToErase.push_back(op); + } else { + // When the region does not have SSA dominance, we need to check if we + // have visited a use before replacing any use. 
+ for (unsigned i = 0, e = existing->getNumResults(); i != e; ++i) { + auto newResult = existing->getResult(i); + op->getResult(i).replaceUsesWithIf(newResult, [&](OpOperand &operand) { + return !knownValues.count(operand.getOwner()); + }); + } + // There may be some remaining uses of the operation. + if (op->use_empty()) + opsToErase.push_back(op); + } // If the existing operation has an unknown location and the current // operation doesn't, then set the existing op's location to that of the @@ -136,10 +155,10 @@ } void CSE::simplifyBlock(ScopedMapTy &knownValues, DominanceInfo &domInfo, - Block *bb) { + Block *bb, bool hasSSADominance) { for (auto &inst : *bb) { // If the operation is simplified, we don't process any held regions. - if (succeeded(simplifyOperation(knownValues, &inst))) + if (succeeded(simplifyOperation(knownValues, &inst, hasSSADominance))) continue; // If this operation is isolated above, we can't process nested regions with @@ -164,17 +183,19 @@ if (region.empty()) return; + bool hasSSADominance = domInfo.hasDominanceInfo(&region); + // If the region only contains one block, then simplify it directly. if (std::next(region.begin()) == region.end()) { ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(knownValues, domInfo, &region.front()); + simplifyBlock(knownValues, domInfo, &region.front(), hasSSADominance); return; } // If the region does not have dominanceInfo, then skip it. // TODO: Regions without SSA dominance should define a different // traversal order which is appropriate and can be used here. - if (!domInfo.hasDominanceInfo(&region)) + if (!hasSSADominance) return; // Note, deque is being used here because there was significant performance @@ -195,7 +216,8 @@ // Check to see if we need to process this node. 
if (!currentNode->processed) { currentNode->processed = true; - simplifyBlock(knownValues, domInfo, currentNode->node->getBlock()); + simplifyBlock(knownValues, domInfo, currentNode->node->getBlock(), + hasSSADominance); } // Otherwise, check to see if we need to process a child node. diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir --- a/mlir/test/Transforms/cse.mlir +++ b/mlir/test/Transforms/cse.mlir @@ -244,3 +244,24 @@ return %0 : i32 } + +/// This test is checking that CSE gracefully handles values in graph regions +/// where the use occurs before the def, and one of the defs could be CSE'd with +/// the other. +// CHECK-LABEL: @use_before_def +func @use_before_def() { + // CHECK-NEXT: test.graph_region + test.graph_region { + // CHECK-NEXT: addi %c1_i32, %c1_i32_0 + %0 = addi %1, %2 : i32 + + // CHECK-NEXT: constant 1 + // CHECK-NEXT: constant 1 + %1 = constant 1 : i32 + %2 = constant 1 : i32 + + // CHECK-NEXT: "foo.yield"(%0) : (i32) -> () + "foo.yield"(%0) : (i32) -> () + } + return +}