diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -28,6 +28,7 @@
 
 std::unique_ptr<mlir::Pass> createAbstractResultOptPass();
 std::unique_ptr<mlir::Pass> createAffineDemotionPass();
+std::unique_ptr<mlir::Pass> createCSEPass();
 std::unique_ptr<mlir::Pass> createFirToCfgPass();
 std::unique_ptr<mlir::Pass> createCharacterConversionPass();
 std::unique_ptr<mlir::Pass> createExternalNameConversionPass();
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -74,6 +74,18 @@
   ];
 }
 
+def BasicCSE : FunctionPass<"basic-cse"> {
+  let summary = "Basic common sub-expression elimination.";
+  let description = [{
+    Perform common subexpression elimination on FIR operations. This pass
+    differs from the MLIR CSE pass in that it is FIR/Fortran semantics aware.
+  }];
+  let constructor = "::fir::createCSEPass()";
+  let dependentDialects = [
+    "fir::FIROpsDialect", "mlir::StandardOpsDialect"
+  ];
+}
+
 def CharacterConversion : Pass<"character-conversion"> {
   let summary = "Convert CHARACTER entities with different KINDs";
   let description = [{
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   AffinePromotion.cpp
   AffineDemotion.cpp
   CharacterConversion.cpp
+  CSE.cpp
   Inliner.cpp
   ExternalNameConversion.cpp
   RewriteLoop.cpp
diff --git a/flang/lib/Optimizer/Transforms/CSE.cpp b/flang/lib/Optimizer/Transforms/CSE.cpp
new file mode 100644
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CSE.cpp
@@ -0,0 +1,344 @@
+//===-- CSE.cpp -- common subexpression elimination -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This transformation pass performs a simple common sub-expression elimination
+/// algorithm on operations within a function.
+///
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/RecyclingAllocator.h"
+#include <deque>
+
+using namespace mlir;
+
+/// Hidden debugging option: when set, the "effects_token" attributes that
+/// simplifyBlock attaches to loads and pure calls are left on the operations.
+static llvm::cl::opt<bool>
+    leaveEffects("keep-effects",
+                 llvm::cl::desc("disable cleaning up effects attributes"),
+                 llvm::cl::init(false), llvm::cl::Hidden);
+
+namespace {
+
+/// Hashing and equality traits for the scoped hash table of known operations.
+/// Operations are keyed on their name, attributes, result types and operands;
+/// the operands of commutative operations are compared irrespective of order.
+struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
+
+  /// Compute the hash value of an Operation
+  static unsigned getHashValue(const Operation *opC) {
+    auto *op = const_cast<Operation *>(opC);
+    // Hash the operations based upon their:
+    //   - Operation Name
+    //   - Attributes
+    //   - Result Types
+    //   - Operands
+    unsigned hashOps;
+    if (op->hasTrait<OpTrait::IsCommutative>()) {
+      std::vector<void *> vec;
+      for (auto i = op->operand_begin(), e = op->operand_end(); i != e; ++i)
+        vec.push_back((*i).getAsOpaquePointer());
+      llvm::sort(vec.begin(), vec.end());
+      hashOps = llvm::hash_combine_range(vec.begin(), vec.end());
+    } else {
+      hashOps =
+          llvm::hash_combine_range(op->operand_begin(), op->operand_end());
+    }
+    auto hashResTys{llvm::hash_combine_range(op->result_type_begin(),
+                                             op->result_type_end())};
+    return llvm::hash_combine(op->getName(), op->getAttrs(), hashResTys,
+                              hashOps);
+  }
+
+  /// Return true if both operations are structurally identical.
+  static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
+    auto *lhs = const_cast<Operation *>(lhsC);
+    auto *rhs = const_cast<Operation *>(rhsC);
+    if (lhs == rhs)
+      return true;
+    if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
+        rhs == getTombstoneKey() || rhs == getEmptyKey())
+      return false;
+
+    // Compare the operation name.
+    if (lhs->getName() != rhs->getName())
+      return false;
+    // Check operand and result type counts.
+    if (lhs->getNumOperands() != rhs->getNumOperands() ||
+        lhs->getNumResults() != rhs->getNumResults())
+      return false;
+    // Compare attributes.
+    if (lhs->getAttrs() != rhs->getAttrs())
+      return false;
+    // Compare operands.
+    if (lhs->hasTrait<OpTrait::IsCommutative>()) {
+      SmallVector<void *> lops;
+      for (auto lod : lhs->getOperands())
+        lops.push_back(lod.getAsOpaquePointer());
+      llvm::sort(lops.begin(), lops.end());
+      SmallVector<void *> rops;
+      for (auto rod : rhs->getOperands())
+        rops.push_back(rod.getAsOpaquePointer());
+      llvm::sort(rops.begin(), rops.end());
+      if (!std::equal(lops.begin(), lops.end(), rops.begin()))
+        return false;
+    } else {
+      if (!std::equal(lhs->operand_begin(), lhs->operand_end(),
+                      rhs->operand_begin()))
+        return false;
+    }
+    // Compare result types.
+    return std::equal(lhs->result_type_begin(), lhs->result_type_end(),
+                      rhs->result_type_begin());
+  }
+};
+
+/// Basic common sub-expression elimination.
+struct BasicCSE : public fir::BasicCSEBase<BasicCSE> {
+  /// Shared implementation of operation elimination and scoped map definitions.
+  using AllocatorTy = llvm::RecyclingAllocator<
+      llvm::BumpPtrAllocator,
+      llvm::ScopedHashTableVal<Operation *, Operation *>>;
+  using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
+                                            SimpleOperationInfo, AllocatorTy>;
+
+  /// Represents a single entry in the depth first traversal of a CFG.
+  struct CFGStackNode {
+    CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
+        : scope(knownValues), node(node), childIterator(node->begin()),
+          processed(false) {}
+
+    /// Scope for the known values.
+    ScopedMapTy::ScopeTy scope;
+
+    DominanceInfoNode *node;
+    DominanceInfoNode::iterator childIterator;
+
+    /// If this node has been fully processed yet or not.
+    bool processed;
+  };
+
+  /// Attempt to eliminate a redundant operation. Returns success if the
+  /// operation was marked for removal, failure otherwise.
+  LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op);
+
+  void simplifyBlock(ScopedMapTy &knownValues, DominanceInfo &domInfo,
+                     Block *bb);
+  void simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo,
+                      Region &region);
+
+  /// Remove the "effects_token" attribute from every non-volatile load and
+  /// pure call in `bb`, recursing into the regions of all other operations.
+  void cleanupBlock(Block *bb) {
+    for (auto &inst : *bb) {
+      if (fir::nonVolatileLoad(&inst) || fir::pureCall(&inst)) {
+        inst.removeAttr(Identifier::get("effects_token", inst.getContext()));
+      } else if (inst.getNumRegions()) {
+        for (auto &region : inst.getRegions())
+          cleanupRegion(region);
+      }
+    }
+  }
+  /// Apply cleanupBlock to each block of `region`.
+  void cleanupRegion(Region &region) {
+    for (auto &block : region)
+      cleanupBlock(&block);
+  }
+
+  /// Pass entry point: run CSE on the body of the current function.
+  void runOnFunction() override final;
+
+private:
+  /// Operations marked as dead and to be erased.
+  std::vector<Operation *> opsToErase;
+};
+
+/// Attempt to eliminate a redundant operation.
+LogicalResult BasicCSE::simplifyOperation(ScopedMapTy &knownValues,
+                                          Operation *op) {
+  if (op->hasTrait<OpTrait::IsTerminator>())
+    return failure();
+
+  if (isOpTriviallyDead(op)) {
+    opsToErase.push_back(op);
+    return success();
+  }
+
+  // Don't simplify operations with nested blocks. We don't currently model
+  // equality comparisons correctly among other things. It is also unclear
+  // whether we would want to CSE such operations.
+  if (op->getNumRegions() != 0)
+    return failure();
+
+  if (!MemoryEffectOpInterface::hasNoEffect(op) && !fir::nonVolatileLoad(op) &&
+      !fir::pureCall(op))
+    return failure();
+
+  // Look for an existing definition for the operation.
+  if (auto *existing = knownValues.lookup(op)) {
+    // If we find one then replace all uses of the current operation with the
+    // existing one and mark it for deletion.
+    op->replaceAllUsesWith(existing);
+    // NOTE(review): the template argument of this hasTrait<> call was lost
+    // when this patch was extracted; restore it from the original revision.
+    if (!op->hasTrait())
+      opsToErase.push_back(op);
+
+    // If the existing operation has an unknown location and the current
+    // operation doesn't, then set the existing op's location to that of the
+    // current op.
+    if (existing->getLoc().isa<UnknownLoc>() &&
+        !op->getLoc().isa<UnknownLoc>()) {
+      existing->setLoc(op->getLoc());
+    }
+    return success();
+  }
+
+  // Otherwise, we add this operation to the known values map.
+  knownValues.insert(op, op);
+  return failure();
+}
+
+/// Simplify the operations of `bb` in two walks. The first tags every
+/// non-volatile load and pure call with an "effects_token" attribute whose
+/// value changes after each store, impure call, or operation with regions, so
+/// loads (or pure calls) separated by such an operation never compare equal.
+/// The second performs the elimination and recurses into nested regions.
+void BasicCSE::simplifyBlock(ScopedMapTy &knownValues, DominanceInfo &domInfo,
+                             Block *bb) {
+  std::intptr_t token = reinterpret_cast<std::intptr_t>(bb);
+  for (auto &inst : *bb) {
+    if (fir::nonVolatileLoad(&inst) || fir::pureCall(&inst))
+      inst.setAttr("effects_token",
+                   IntegerAttr::get(IndexType::get(inst.getContext()), token));
+    // NOTE(review): `fir::StoreOp` is inferred from the @foo test, where a
+    // fir.store must separate loads; confirm against the original revision.
+    if (isa<fir::StoreOp>(&inst) || fir::impureCall(&inst) ||
+        inst.getNumRegions() != 0)
+      token = reinterpret_cast<std::intptr_t>(&inst);
+  }
+  for (auto &inst : *bb) {
+    // If the operation is simplified, we don't process any held regions.
+    if (succeeded(simplifyOperation(knownValues, &inst)))
+      continue;
+
+    // If this operation is isolated above, we can't process nested regions with
+    // the given 'knownValues' map. This would cause the insertion of implicit
+    // captures in explicit capture only regions.
+    if (!inst.isRegistered() ||
+        inst.hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+      ScopedMapTy nestedKnownValues;
+      for (auto &region : inst.getRegions())
+        simplifyRegion(nestedKnownValues, domInfo, region);
+      continue;
+    }
+
+    // Otherwise, process nested regions normally.
+    for (auto &region : inst.getRegions())
+      simplifyRegion(knownValues, domInfo, region);
+  }
+}
+
+/// Simplify every block of `region`, visiting the blocks in dominator-tree
+/// order so that each block sees the known values of the blocks dominating it.
+void BasicCSE::simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo,
+                              Region &region) {
+  // If the region is empty there is nothing to do.
+  if (region.empty())
+    return;
+
+  // If the region only contains one block, then simplify it directly.
+  if (std::next(region.begin()) == region.end()) {
+    ScopedMapTy::ScopeTy scope(knownValues);
+    simplifyBlock(knownValues, domInfo, &region.front());
+    return;
+  }
+
+  // Note, deque is being used here because there were significant performance
+  // gains over vector when the container becomes very large due to the
+  // specific access patterns. If/when these performance issues are no
+  // longer a problem we can change this to vector. For more information see
+  // the llvm mailing list discussion on this:
+  // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
+  std::deque<std::unique_ptr<CFGStackNode>> stack;
+
+  // Process the nodes of the dom tree for this region.
+  stack.emplace_back(std::make_unique<CFGStackNode>(
+      knownValues, domInfo.getRootNode(&region)));
+
+  while (!stack.empty()) {
+    auto &currentNode = stack.back();
+
+    // Check to see if we need to process this node.
+    if (!currentNode->processed) {
+      currentNode->processed = true;
+      simplifyBlock(knownValues, domInfo, currentNode->node->getBlock());
+    }
+
+    // Otherwise, check to see if we need to process a child node.
+    if (currentNode->childIterator != currentNode->node->end()) {
+      auto *childNode = *(currentNode->childIterator++);
+      stack.emplace_back(
+          std::make_unique<CFGStackNode>(knownValues, childNode));
+    } else {
+      // Finally, if the node and all of its children have been processed
+      // then we delete the node.
+      stack.pop_back();
+    }
+  }
+}
+
+/// Run CSE over the function body, strip the temporary "effects_token"
+/// attributes unless -keep-effects was given, then erase the dead operations.
+void BasicCSE::runOnFunction() {
+  /// A scoped hash table of defining operations within a function.
+  {
+    ScopedMapTy knownValues;
+    simplifyRegion(knownValues, getAnalysis<DominanceInfo>(),
+                   getFunction().getBody());
+  }
+  if (!leaveEffects)
+    cleanupRegion(getFunction().getBody());
+
+  // If no operations were erased, then we mark all analyses as preserved.
+  if (opsToErase.empty())
+    return markAllAnalysesPreserved();
+
+  /// Erase any operations that were marked as dead during simplification.
+  for (auto *op : opsToErase)
+    op->erase();
+  opsToErase.clear();
+
+  // We currently don't remove region operations, so mark dominance as
+  // preserved.
+  markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
+}
+
+} // end anonymous namespace
+
+/// Create an instance of the FIR-aware CSE pass.
+std::unique_ptr<mlir::Pass> fir::createCSEPass() {
+  return std::make_unique<BasicCSE>();
+}
diff --git a/flang/test/Fir/cse.fir b/flang/test/Fir/cse.fir
new file mode 100644
--- /dev/null
+++ b/flang/test/Fir/cse.fir
@@ -0,0 +1,77 @@
+// Test CSE pass
+
+// RUN: fir-opt --basic-cse %s | FileCheck %s
+
+// CHECK-LABEL: @fun
+func @fun(%a : !fir.ref<i64>) -> i64 {
+  // CHECK: %{{.*}} = fir.load %{{.*}} : !fir.ref<i64>
+  %1 = fir.load %a : !fir.ref<i64>
+  // CHECK-NOT: fir.load %{{.*}} : !fir.ref<i64>
+  %2 = fir.load %a : !fir.ref<i64>
+  // CHECK-COUNT-6: arith.addi %{{.*}}, %{{.*}} : i64
+  %3 = arith.addi %1, %2 : i64
+  %4 = fir.load %a : !fir.ref<i64>
+  %5 = arith.addi %3, %4 : i64
+  %6 = fir.load %a : !fir.ref<i64>
+  %7 = arith.addi %5, %6 : i64
+  %8 = fir.load %a : !fir.ref<i64>
+  %9 = arith.addi %7, %8 : i64
+  %10 = fir.load %a : !fir.ref<i64>
+  %11 = arith.addi %10, %9 : i64
+  %12 = fir.load %a : !fir.ref<i64>
+  %13 = arith.addi %11, %12 : i64
+  // CHECK-NEXT: return %{{.*}} : i64
+  return %13 : i64
+}
+
+// CHECK-LABEL: @bar
+func private @bar(%a : !fir.ref<i64>) -> i64
+
+// CHECK-LABEL: @fun2
+func @fun2(%a : !fir.ref<i64>) -> i64 {
+  // CHECK: %{{.*}} = fir.load %{{.*}} : !fir.ref<i64>
+  %1 = fir.load %a : !fir.ref<i64>
+  // CHECK-NEXT: fir.call @bar
+  %2 = fir.call @bar(%a) { pure = true } : (!fir.ref<i64>) -> i64
+  // CHECK-COUNT-6: arith.addi %{{.*}}, %{{.*}} : i64
+  %3 = arith.addi %1, %2 : i64
+  %4 = fir.call @bar(%a) { pure = true } : (!fir.ref<i64>) -> i64
+  %5 = arith.addi %3, %4 : i64
+  %6 = fir.call @bar(%a) { pure = true } : (!fir.ref<i64>) -> i64
+  %7 = arith.addi %5, %6 : i64
+  %8 = fir.call @bar(%a) { pure = true } : (!fir.ref<i64>) -> i64
+  %9 = arith.addi %7, %8 : i64
+  %10 = fir.call @bar(%a) { pure = true } : (!fir.ref<i64>) -> i64
+  %11 = arith.addi %10, %9 : i64
+  %12 = fir.call @bar(%a) { pure = true } : (!fir.ref<i64>) -> i64
+  %13 = arith.addi %11, %12 : i64
+  // CHECK-NEXT: return %{{.*}} : i64
+  return %13 : i64
+}
+
+// Negative test: do not merge loads when an op with regions is between
+// CHECK-LABEL: @foo
+func @foo(%arg0: !fir.ref<f32>) -> f32 {
+  // CHECK: %{{.*}} = fir.alloca f32 {name = "x"}
+  %0 = fir.alloca f32 {name = "x"}
+  %1 = fir.load %arg0 : !fir.ref<f32>
+  fir.store %1 to %0 : !fir.ref<f32>
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: %{{.*}} = fir.load %{{.*}} : !fir.ref<f32>
+  %2 = fir.load %0 : !fir.ref<f32>
+  %3 = arith.cmpf olt, %2, %cst : f32
+  fir.if %3 {
+    // CHECK: %{{.*}} = fir.load %{{.*}} : !fir.ref<f32>
+    %7 = fir.load %0 : !fir.ref<f32>
+    %8 = arith.negf %7 : f32
+    fir.store %8 to %0 : !fir.ref<f32>
+  }
+  %cst_0 = arith.constant 1.000000e+00 : f32
+  // CHECK: %{{.*}} = fir.load %{{.*}} : !fir.ref<f32>
+  %4 = fir.load %0 : !fir.ref<f32>
+  %5 = arith.addf %4, %cst_0 : f32
+  fir.store %5 to %0 : !fir.ref<f32>
+  // CHECK: %{{.*}} = fir.load %{{.*}} : !fir.ref<f32>
+  %6 = fir.load %0 : !fir.ref<f32>
+  return %6 : f32
+}