diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -31,6 +31,7 @@
 std::unique_ptr<mlir::Pass> createFirToCfgPass();
 std::unique_ptr<mlir::Pass> createCharacterConversionPass();
 std::unique_ptr<mlir::Pass> createExternalNameConversionPass();
+std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
 std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
 
 /// Support for inlining on FIR.
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -120,4 +120,17 @@
   let constructor = "::fir::createExternalNameConversionPass()";
 }
 
+def MemRefDataFlowOpt : FunctionPass<"fir-memref-dataflow-opt"> {
+  let summary =
+    "Perform store/load forwarding and potentially removing dead stores.";
+  let description = [{
+    This pass performs store to load forwarding to eliminate memory accesses and
+    potentially the entire allocation if all the accesses are forwarded.
+  }];
+  let constructor = "::fir::createMemDataFlowOptPass()";
+  let dependentDialects = [
+    "fir::FIROpsDialect", "mlir::StandardOpsDialect"
+  ];
+}
+
 #endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@
   CharacterConversion.cpp
   Inliner.cpp
   ExternalNameConversion.cpp
+  MemRefDataFlowOpt.cpp
   RewriteLoop.cpp
 
   DEPENDS
diff --git a/flang/lib/Optimizer/Transforms/MemRefDataFlowOpt.cpp b/flang/lib/Optimizer/Transforms/MemRefDataFlowOpt.cpp
new file mode 100644
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/MemRefDataFlowOpt.cpp
@@ -0,0 +1,190 @@
+//===- MemRefDataFlowOpt.cpp - Memory DataFlow Optimization pass ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include <vector>
+
+#define DEBUG_TYPE "fir-memref-dataflow-opt"
+
+namespace {
+
+/// Collect the users of \p v that are operations of type \p OpTy.
+template <typename OpTy>
+static std::vector<OpTy> getSpecificUsers(mlir::Value v) {
+  std::vector<OpTy> ops;
+  for (mlir::Operation *user : v.getUsers())
+    if (auto op = mlir::dyn_cast<OpTy>(user))
+      ops.push_back(op);
+  return ops;
+}
+
+/// Return true if \p user is an operation of type \p WriteOp that writes to
+/// the address \p addr (as opposed to storing \p addr itself somewhere).
+template <typename WriteOp>
+static bool isStoreTo(mlir::Operation *user, mlir::Value addr) {
+  auto store = mlir::dyn_cast<WriteOp>(user);
+  return store && store.memref() == addr;
+}
+
+/// Store to load forwarding driven by dominance. This is based on MLIR's
+/// MemRefDataFlowOpt which is specialized on AffineRead and AffineWrite
+/// interface. Forwarding is deliberately conservative, see isSafeToForward.
+template <typename ReadOp, typename WriteOp>
+class LoadStoreForwarding {
+public:
+  LoadStoreForwarding(mlir::DominanceInfo *di) : domInfo(di) {}
+
+  /// Return true if \p addr is the result of a fir.alloca whose only users
+  /// are loads from it and stores to it, i.e. the address is never passed to
+  /// a call, converted, indexed or stored, so only those direct accesses can
+  /// touch the memory.
+  static bool isPrivateAlloca(mlir::Value addr) {
+    if (!addr.getDefiningOp<fir::AllocaOp>())
+      return false;
+    return llvm::all_of(addr.getUsers(), [&](mlir::Operation *user) {
+      return mlir::isa<ReadOp>(user) || isStoreTo<WriteOp>(user, addr);
+    });
+  }
+
+  /// Return true if \p loadOp is guaranteed to read the value written by
+  /// \p storeOp: the address is a private alloca, the store precedes the load
+  /// in the same block, and no other store to that address sits between
+  /// them, either directly in the block or nested in an operation in between.
+  bool isSafeToForward(WriteOp storeOp, ReadOp loadOp) {
+    mlir::Value addr = loadOp.memref();
+    if (storeOp.memref() != addr || !isPrivateAlloca(addr))
+      return false;
+    mlir::Operation *store = storeOp.getOperation();
+    mlir::Operation *load = loadOp.getOperation();
+    mlir::Block *block = load->getBlock();
+    if (store->getBlock() != block || !store->isBeforeInBlock(load))
+      return false;
+    for (mlir::Operation *user : addr.getUsers()) {
+      if (user == store || !mlir::isa<WriteOp>(user))
+        continue;
+      // A store nested in a region is represented by its ancestor in `block`.
+      mlir::Operation *ancestor = block->findAncestorOpInBlock(*user);
+      if (ancestor && store->isBeforeInBlock(ancestor) &&
+          ancestor->isBeforeInBlock(load))
+        return false;
+    }
+    return true;
+  }
+
+  /// Find, among \p storeOps (the stores to the address read by \p loadOp),
+  /// the store whose value can replace the result of \p loadOp. The nearest
+  /// store dominating the load is selected, and it is returned only if
+  /// isSafeToForward accepts it; otherwise the result is empty.
+  llvm::Optional<WriteOp> findStoreToForward(ReadOp loadOp,
+                                             std::vector<WriteOp> &&storeOps) {
+    llvm::SmallVector<WriteOp> candidateSet;
+
+    for (auto storeOp : storeOps)
+      if (domInfo->dominates(storeOp, loadOp))
+        candidateSet.push_back(storeOp);
+
+    if (candidateSet.empty())
+      return {};
+
+    llvm::Optional<WriteOp> nearestStore;
+    for (auto candidate : candidateSet) {
+      auto nearerThan = [&](WriteOp otherStore) {
+        if (candidate == otherStore)
+          return false;
+        bool rv = domInfo->properlyDominates(candidate, otherStore);
+        if (rv) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "candidate " << candidate << " is not the nearest to "
+                     << loadOp << " because " << otherStore << " is closer\n");
+        }
+        return rv;
+      };
+      if (!llvm::any_of(candidateSet, nearerThan)) {
+        nearestStore = candidate;
+        break;
+      }
+    }
+    if (!nearestStore) {
+      LLVM_DEBUG(
+          llvm::dbgs()
+          << "load " << loadOp << " has " << candidateSet.size()
+          << " store candidates, but this algorithm can't find a best.\n");
+      return {};
+    }
+    if (!isSafeToForward(*nearestStore, loadOp)) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "store " << *nearestStore << " is not forwarded to "
+                 << loadOp << " because the value may change in between\n");
+      return {};
+    }
+    return nearestStore;
+  }
+
+private:
+  mlir::DominanceInfo *domInfo;
+};
+
+class MemDataFlowOpt : public fir::MemRefDataFlowOptBase<MemDataFlowOpt> {
+public:
+  void runOnFunction() override {
+    mlir::FuncOp f = getFunction();
+
+    auto *domInfo = &getAnalysis<mlir::DominanceInfo>();
+    LoadStoreForwarding<fir::LoadOp, fir::StoreOp> lsf(domInfo);
+    f.walk([&](fir::LoadOp loadOp) {
+      auto maybeStore = lsf.findStoreToForward(
+          loadOp, getSpecificUsers<fir::StoreOp>(loadOp.memref()));
+      if (maybeStore) {
+        auto storeOp = maybeStore.getValue();
+        LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
+                                << " erasing load " << loadOp
+                                << " with value from " << storeOp << '\n');
+        loadOp.getResult().replaceAllUsesWith(storeOp.value());
+        loadOp.erase();
+      }
+    });
+
+    // Stores are dead only when nothing can read the memory they write, that
+    // is when every user of the alloca is a store to it. They are collected
+    // first and erased after the walk: erasing operations other than the one
+    // being visited while the walk is running can invalidate its iteration.
+    llvm::SmallVector<fir::StoreOp> deadStores;
+    f.walk([&](fir::AllocaOp alloca) {
+      mlir::Value addr = alloca.getResult();
+      auto writesToAlloca = [&](mlir::Operation *user) {
+        return isStoreTo<fir::StoreOp>(user, addr);
+      };
+      if (!llvm::all_of(addr.getUsers(), writesToAlloca))
+        return;
+      for (fir::StoreOp storeOp : getSpecificUsers<fir::StoreOp>(addr))
+        deadStores.push_back(storeOp);
+    });
+    for (fir::StoreOp storeOp : deadStores) {
+      LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
+                              << " erasing store " << storeOp << '\n');
+      storeOp.erase();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createMemDataFlowOptPass() {
+  return std::make_unique<MemDataFlowOpt>();
+}
diff --git a/flang/test/Fir/memref-data-flow.fir b/flang/test/Fir/memref-data-flow.fir
new file mode 100644
--- /dev/null
+++ b/flang/test/Fir/memref-data-flow.fir
@@ -0,0 +1,102 @@
+// RUN: fir-opt --split-input-file --fir-memref-dataflow-opt %s | FileCheck %s
+
+// Test that all load-store chains are removed
+
+func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.ref<!fir.array<60xi32>>, %arg2: !fir.ref<!fir.array<60xi32>>) {
+  %c1_i64 = arith.constant 1 : i64
+  %c60 = arith.constant 60 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFf1dcEi"}
+  %1 = fir.alloca !fir.array<60xi32> {bindc_name = "t1", uniq_name = "_QFf1dcEt1"}
+  br ^bb1(%c1, %c60 : index, index)
+^bb1(%2: index, %3: index):  // 2 preds: ^bb0, ^bb2
+  %4 = arith.cmpi sgt, %3, %c0 : index
+  cond_br %4, ^bb2, ^bb3
+^bb2:  // pred: ^bb1
+  %5 = fir.convert %2 : (index) -> i32
+  fir.store %5 to %0 : !fir.ref<i32>
+  %6 = fir.load %0 : !fir.ref<i32>
+  %7 = fir.convert %6 : (i32) -> i64
+  %8 = arith.subi %7, %c1_i64 : i64
+  %9 = fir.coordinate_of %arg0, %8 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
+  %10 = fir.load %9 : !fir.ref<i32>
+  %11 = arith.addi %10, %10 : i32
+  %12 = fir.coordinate_of %1, %8 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
+  fir.store %11 to %12 : !fir.ref<i32>
+  %13 = arith.addi %2, %c1 : index
+  %14 = arith.subi %3, %c1 : index
+  br ^bb1(%13, %14 : index, index)
+^bb3:  // pred: ^bb1
+  %15 = fir.convert %2 : (index) -> i32
+  fir.store %15 to %0 : !fir.ref<i32>
+  br ^bb4(%c1, %c60 : index, index)
+^bb4(%16: index, %17: index):  // 2 preds: ^bb3, ^bb5
+  %18 = arith.cmpi sgt, %17, %c0 : index
+  cond_br %18, ^bb5, ^bb6
+^bb5:  // pred: ^bb4
+  %19 = fir.convert %16 : (index) -> i32
+  fir.store %19 to %0 : !fir.ref<i32>
+  %20 = fir.load %0 : !fir.ref<i32>
+  %21 = fir.convert %20 : (i32) -> i64
+  %22 = arith.subi %21, %c1_i64 : i64
+  %23 = fir.coordinate_of %1, %22 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
+  %24 = fir.load %23 : !fir.ref<i32>
+  %25 = fir.coordinate_of %arg1, %22 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
+  %26 = fir.load %25 : !fir.ref<i32>
+  %27 = arith.muli %24, %26 : i32
+  %28 = fir.coordinate_of %arg2, %22 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
+  fir.store %27 to %28 : !fir.ref<i32>
+  %29 = arith.addi %16, %c1 : index
+  %30 = arith.subi %17, %c1 : index
+  br ^bb4(%29, %30 : index, index)
+^bb6:  // pred: ^bb4
+  %31 = fir.convert %16 : (index) -> i32
+  fir.store %31 to %0 : !fir.ref<i32>
+  return
+}
+
+// CHECK-LABEL: func @load_store_chain_removal
+// CHECK-LABEL: ^bb1
+// CHECK-LABEL: ^bb2:
+// Make sure the previous fir.store/fir.load pair has been eliminated and we
+// preserve the last pair of fir.load/fir.store.
+// CHECK-COUNT-1: %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
+// CHECK-COUNT-1: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
+// CHECK-LABEL: ^bb3:
+// Make sure the fir.store has been removed.
+// CHECK-NOT: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
+// CHECK-LABEL: ^bb5:
+// CHECK: %{{.*}} = fir.convert %{{.*}} : (index) -> i32
+// Check that the fir.store/fir.load pair has been removed between the convert.
+// CHECK-NOT: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
+// CHECK-NOT: %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
+// CHECK: %{{.*}} = fir.convert %{{.*}} : (i32) -> i64
+// CHECK: %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
+// CHECK: %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
+// CHECK: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
+// CHECK-LABEL: ^bb6:
+// CHECK-NOT: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
+
+// -----
+
+// Test that a store is not forwarded to a load in another block when a
+// conditional store to the same address may execute in between.
+
+func @no_forward_across_conditional_store(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
+  %0 = fir.alloca i32
+  fir.store %arg1 to %0 : !fir.ref<i32>
+  cond_br %arg0, ^bb1, ^bb2
+^bb1:
+  fir.store %arg2 to %0 : !fir.ref<i32>
+  br ^bb2
+^bb2:
+  %1 = fir.load %0 : !fir.ref<i32>
+  return %1 : i32
+}
+
+// CHECK-LABEL: func @no_forward_across_conditional_store
+// CHECK: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
+// CHECK: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
+// CHECK: %[[VAL:.*]] = fir.load %{{.*}} : !fir.ref<i32>
+// CHECK: return %[[VAL]] : i32