diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -253,7 +253,7 @@
   let assemblyFormat = "$heapref attr-dict `:` qualified(type($heapref))";
 }
 
-def fir_LoadOp : fir_OneResultOp<"load"> {
+def fir_LoadOp : fir_OneResultOp<"load", [MemoryEffects<[MemRead]>]> {
   let summary = "load a value from a memory reference";
   let description = [{
     Load a value from a memory reference into an ssa-value (virtual register).
@@ -320,7 +320,7 @@
   let hasVerifier = 1;
 }
 
-def fir_StoreOp : fir_Op<"store", []> {
+def fir_StoreOp : fir_Op<"store", [MemoryEffects<[MemWrite]>]> {
   let summary = "store an SSA-value to a memory location";
 
   let description = [{
diff --git a/flang/test/Fir/cse.fir b/flang/test/Fir/cse.fir
new file mode 100644
--- /dev/null
+++ b/flang/test/Fir/cse.fir
@@ -0,0 +1,67 @@
+// RUN: fir-opt --cse -split-input-file %s | FileCheck %s
+
+// Check that the redundant fir.load is removed.
+func @fun(%arg0: !fir.ref<i64>) -> i64 {
+  %0 = fir.load %arg0 : !fir.ref<i64>
+  %1 = fir.load %arg0 : !fir.ref<i64>
+  %2 = arith.addi %0, %1 : i64
+  return %2 : i64
+}
+
+// CHECK-LABEL: func @fun
+// CHECK-NEXT: %[[LOAD:.*]] = fir.load %{{.*}} : !fir.ref<i64>
+// CHECK-NEXT: %{{.*}} = arith.addi %[[LOAD]], %[[LOAD]] : i64
+
+// -----
+
+// CHECK-LABEL: func @fun(
+// CHECK-SAME: %[[A:.*]]: !fir.ref<i64>
+func @fun(%a : !fir.ref<i64>) -> i64 {
+  // CHECK: %[[LOAD:.*]] = fir.load %[[A]] : !fir.ref<i64>
+  %1 = fir.load %a : !fir.ref<i64>
+  %2 = fir.load %a : !fir.ref<i64>
+  // CHECK-NEXT: %{{.*}} = arith.addi %[[LOAD]], %[[LOAD]] : i64
+  %3 = arith.addi %1, %2 : i64
+  %4 = fir.load %a : !fir.ref<i64>
+  // CHECK-NEXT: %{{.*}} = arith.addi
+  %5 = arith.addi %3, %4 : i64
+  %6 = fir.load %a : !fir.ref<i64>
+  // CHECK-NEXT: %{{.*}} = arith.addi
+  %7 = arith.addi %5, %6 : i64
+  %8 = fir.load %a : !fir.ref<i64>
+  // CHECK-NEXT: %{{.*}} = arith.addi
+  %9 = arith.addi %7, %8 : i64
+  %10 = fir.load %a : !fir.ref<i64>
+  // CHECK-NEXT: %{{.*}} = arith.addi
+  %11 = arith.addi %10, %9 : i64
+  %12 = fir.load %a : !fir.ref<i64>
+  // CHECK-NEXT: %{{.*}} = arith.addi
+  %13 = arith.addi %11, %12 : i64
+  // CHECK-NEXT: return %{{.*}} : i64
+  return %13 : i64
+}
+
+// -----
+
+// Check that the two loads inside ^bb1 are merged, but that the load in ^bb2
+// is kept: operations with memory effects are only merged within one block.
+// CHECK-LABEL: func @fun(
+// CHECK-SAME: %[[A:.*]]: !fir.ref<i64>
+// CHECK: ^bb1:
+// CHECK-NEXT: %[[LOAD1:.*]] = fir.load %[[A]] : !fir.ref<i64>
+// CHECK-NOT: fir.load
+// CHECK: ^bb2:
+// CHECK-NEXT: %[[LOAD2:.*]] = fir.load %[[A]] : !fir.ref<i64>
+// CHECK-NEXT: %{{.*}} = arith.subi %[[LOAD2]], %[[LOAD2]] : i64
+func @fun(%a : !fir.ref<i64>) -> i64 {
+  cf.br ^bb1
+^bb1:
+  %1 = fir.load %a : !fir.ref<i64>
+  %2 = fir.load %a : !fir.ref<i64>
+  %3 = arith.addi %1, %2 : i64
+  cf.br ^bb2
+^bb2:
+  %4 = fir.load %a : !fir.ref<i64>
+  %5 = arith.subi %4, %4 : i64
+  return %5 : i64
+}
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -60,6 +60,14 @@
   using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
                                             SimpleOperationInfo, AllocatorTy>;
 
+  /// Cache of memory-effect queries between two operations of the same block.
+  /// The key is the first operation. The value pairs a later operation with
+  /// an effect: a null effect means no operation with MemoryEffects::Write
+  /// lies between the key and that operation; MemoryEffects::Write means the
+  /// paired operation is one found after the key that may write to memory.
+  using MemEffectsCache =
+      DenseMap<Operation *, std::pair<Operation *, MemoryEffects::Effect *>>;
+
   /// Represents a single entry in the depth first traversal of a CFG.
   struct CFGStackNode {
     CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
@@ -85,12 +93,94 @@
   void runOnOperation() override;
 
 private:
+  /// Replace uses of `op` by `existing` and queue `op` for erasure if unused.
+  void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+                            Operation *existing, bool hasSSADominance);
+  /// Check if there are operations that may write to memory between `fromOp`
+  /// and `toOp`, which must be in the same block and both read memory.
+  bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
+
   /// Operations marked as dead and to be erased.
   std::vector<Operation *> opsToErase;
   DominanceInfo *domInfo = nullptr;
+  MemEffectsCache memEffectsCache;
 };
 } // namespace
 
+void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+                               Operation *existing, bool hasSSADominance) {
+  // If we find one then replace all uses of the current operation with the
+  // existing one and mark it for deletion. We can only replace an operand in
+  // an operation if it has not been visited yet.
+  if (hasSSADominance) {
+    // If the region has SSA dominance, then we are guaranteed to have not
+    // visited any use of the current operation.
+    op->replaceAllUsesWith(existing);
+    opsToErase.push_back(op);
+  } else {
+    // When the region does not have SSA dominance, we need to check if we
+    // have visited a use before replacing any use.
+    for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
+      std::get<0>(it).replaceUsesWithIf(
+          std::get<1>(it), [&](OpOperand &operand) {
+            return !knownValues.count(operand.getOwner());
+          });
+    }
+
+    // There may be some remaining uses of the operation.
+    if (op->use_empty())
+      opsToErase.push_back(op);
+  }
+
+  // If the existing operation has an unknown location and the current
+  // operation doesn't, then set the existing op's location to that of the
+  // current op.
+  if (existing->getLoc().isa<UnknownLoc>() && !op->getLoc().isa<UnknownLoc>())
+    existing->setLoc(op->getLoc());
+
+  ++numCSE;
+}
+
+bool CSE::hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp) {
+  assert(fromOp->getBlock() == toOp->getBlock());
+  assert(
+      isa<MemoryEffectOpInterface>(fromOp) &&
+      cast<MemoryEffectOpInterface>(fromOp).hasEffect<MemoryEffects::Read>() &&
+      isa<MemoryEffectOpInterface>(toOp) &&
+      cast<MemoryEffectOpInterface>(toOp).hasEffect<MemoryEffects::Read>());
+  Operation *nextOp = fromOp->getNextNode();
+  auto result =
+      memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr));
+  if (!result.second) {
+    // Cache hit: a previous query starting at `fromOp` can be reused.
+    auto memEffectsCachePair = result.first->second;
+    if (memEffectsCachePair.second == nullptr) {
+      // No MemoryEffects::Write has been detected until the cached operation.
+      // Continue looking from the cached operation to toOp.
+      nextOp = memEffectsCachePair.first;
+    } else {
+      // MemoryEffects::Write has been detected before so there is no need to
+      // check further.
+      return true;
+    }
+  }
+  while (nextOp && nextOp != toOp) {
+    auto nextOpMemEffects = dyn_cast<MemoryEffectOpInterface>(nextOp);
+    // TODO: Do we need to handle other effects generically?
+    // If the operation does not implement the MemoryEffectOpInterface we
+    // conservatively assume it writes.
+    if (!nextOpMemEffects ||
+        nextOpMemEffects.hasEffect<MemoryEffects::Write>()) {
+      result.first->second =
+          std::make_pair(nextOp, MemoryEffects::Write::get());
+      return true;
+    }
+    nextOp = nextOp->getNextNode();
+  }
+  result.first->second = std::make_pair(toOp, nullptr);
+  return false;
+}
+
 /// Attempt to eliminate a redundant operation.
 LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
                                      bool hasSSADominance) {
@@ -111,45 +201,34 @@
   if (op->getNumRegions() != 0)
     return failure();
 
-  // TODO: We currently only eliminate non side-effecting
-  // operations.
-  if (!MemoryEffectOpInterface::hasNoEffect(op))
+  // Some simple use cases of operations with memory side-effects are handled
+  // here. Operations with no side-effect are handled afterwards.
+  if (!MemoryEffectOpInterface::hasNoEffect(op)) {
+    auto memEffects = dyn_cast<MemoryEffectOpInterface>(op);
+    // TODO: Only basic use case for operations with MemoryEffects::Read can be
+    // eliminated now. More work needs to be done for more complicated patterns
+    // and other side-effects.
+    if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>())
+      return failure();
+
+    // Look for an existing definition for the operation.
+    if (auto *existing = knownValues.lookup(op)) {
+      if (existing->getBlock() == op->getBlock() &&
+          !hasOtherSideEffectingOpInBetween(existing, op)) {
+        // The operation that can be deleted has been reached with no
+        // side-effecting operations in between the existing operation and
+        // this one so we can remove the duplicate.
+        replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
+        return success();
+      }
+    }
+    knownValues.insert(op, op);
     return failure();
+  }
 
   // Look for an existing definition for the operation.
   if (auto *existing = knownValues.lookup(op)) {
-
-    // If we find one then replace all uses of the current operation with the
-    // existing one and mark it for deletion. We can only replace an operand in
-    // an operation if it has not been visited yet.
-    if (hasSSADominance) {
-      // If the region has SSA dominance, then we are guaranteed to have not
-      // visited any use of the current operation.
-      op->replaceAllUsesWith(existing);
-      opsToErase.push_back(op);
-    } else {
-      // When the region does not have SSA dominance, we need to check if we
-      // have visited a use before replacing any use.
-      for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
-        std::get<0>(it).replaceUsesWithIf(
-            std::get<1>(it), [&](OpOperand &operand) {
-              return !knownValues.count(operand.getOwner());
-            });
-      }
-
-      // There may be some remaining uses of the operation.
-      if (op->use_empty())
-        opsToErase.push_back(op);
-    }
-
-    // If the existing operation has an unknown location and the current
-    // operation doesn't, then set the existing op's location to that of the
-    // current op.
-    if (existing->getLoc().isa<UnknownLoc>() &&
-        !op->getLoc().isa<UnknownLoc>()) {
-      existing->setLoc(op->getLoc());
-    }
-
-    ++numCSE;
+    // replaceUsesAndDelete already updates the numCSE statistic.
+    replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
     return success();
   }
@@ -184,6 +263,8 @@
     for (auto &region : op.getRegions())
       simplifyRegion(knownValues, region);
   }
+  // Cache entries are only valid within a single block; drop them now.
+  memEffectsCache.clear();
 }
 
 void CSE::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
diff --git a/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir b/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir
--- a/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir
+++ b/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir
@@ -32,8 +32,7 @@
 // CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 {
 // CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 {
 // CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64
+// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64
 // CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
 // CHECK: toy.print [[VAL_6]] : memref<3x2xf64>
 // CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64>
diff --git a/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir b/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir
--- a/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir
+++ b/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir
@@ -32,8 +32,7 @@
 // CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 {
 // CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 {
 // CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64
+// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64
 // CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
 // CHECK: toy.print [[VAL_6]] : memref<3x2xf64>
 // CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64>
diff --git a/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir b/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir
--- a/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir
+++ b/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir
@@ -32,8 +32,7 @@
 // CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 {
 // CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 {
 // CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64
+// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64
 // CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
 // CHECK: toy.print [[VAL_6]] : memref<3x2xf64>
 // CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64>
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -265,3 +265,48 @@
   }
   return
 }
+
+/// This test checks that CSE removes a duplicated read op that directly
+/// follows an identical one.
+// CHECK-LABEL: @remove_direct_duplicated_read_op
+func @remove_direct_duplicated_read_op() -> i32 {
+  // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
+  %0 = "test.op_with_memread"() : () -> (i32)
+  %1 = "test.op_with_memread"() : () -> (i32)
+  // CHECK-NEXT: %{{.*}} = arith.addi %[[READ_VALUE]], %[[READ_VALUE]] : i32
+  %2 = arith.addi %0, %1 : i32
+  return %2 : i32
+}
+
+/// This test checks that CSE replaces every later duplicated read op by the
+/// first one when no write op lies in between.
+// CHECK-LABEL: @remove_multiple_duplicated_read_op
+func @remove_multiple_duplicated_read_op() -> i64 {
+  // CHECK: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i64
+  %0 = "test.op_with_memread"() : () -> (i64)
+  %1 = "test.op_with_memread"() : () -> (i64)
+  // CHECK-NEXT: %{{.*}} = arith.addi %[[READ_VALUE]], %[[READ_VALUE]] : i64
+  %2 = arith.addi %0, %1 : i64
+  %3 = "test.op_with_memread"() : () -> (i64)
+  // CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %[[READ_VALUE]] : i64
+  %4 = arith.addi %2, %3 : i64
+  %5 = "test.op_with_memread"() : () -> (i64)
+  // CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %[[READ_VALUE]] : i64
+  %6 = arith.addi %4, %5 : i64
+  // CHECK-NEXT: return %{{.*}} : i64
+  return %6 : i64
+}
+
+/// CSE must not remove a duplicated read op that has a write op in between.
+// CHECK-LABEL: @dont_remove_duplicated_read_op_with_sideeffecting
+func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
+  // CHECK-NEXT: %[[READ_VALUE0:.*]] = "test.op_with_memread"() : () -> i32
+  %0 = "test.op_with_memread"() : () -> (i32)
+  // CHECK-NEXT: "test.op_with_memwrite"() : () -> ()
+  "test.op_with_memwrite"() : () -> ()
+  // CHECK-NEXT: %[[READ_VALUE1:.*]] = "test.op_with_memread"() : () -> i32
+  %1 = "test.op_with_memread"() : () -> (i32)
+  // CHECK-NEXT: %{{.*}} = arith.addi %[[READ_VALUE0]], %[[READ_VALUE1]] : i32
+  %2 = arith.addi %0, %1 : i32
+  return %2 : i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2761,4 +2761,12 @@
 def TestEffectsOpB : TEST_Op<"op_with_effects_b",
     [MemoryEffects<[MemWrite]>]>;
 
+def TestEffectsRead : TEST_Op<"op_with_memread",
+    [MemoryEffects<[MemRead]>]> {
+  let results = (outs AnyInteger);
+}
+
+def TestEffectsWrite : TEST_Op<"op_with_memwrite",
+    [MemoryEffects<[MemWrite]>]>;
+
 #endif // TEST_OPS