diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp --- a/flang/lib/Optimizer/Transforms/StackArrays.cpp +++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp @@ -26,7 +26,7 @@ #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" @@ -167,25 +167,22 @@ StackArraysAnalysisWrapper(mlir::Operation *op) {} - bool hasErrors() const; - - const AllocMemMap &getCandidateOps(mlir::Operation *func); + // returns nullptr if analysis failed + const AllocMemMap *getCandidateOps(mlir::Operation *func); private: llvm::DenseMap funcMaps; - bool gotError = false; - void analyseFunction(mlir::Operation *func); + mlir::LogicalResult analyseFunction(mlir::Operation *func); }; /// Converts a fir.allocmem to a fir.alloca class AllocMemConversion : public mlir::OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - - AllocMemConversion( + explicit AllocMemConversion( mlir::MLIRContext *ctx, - const llvm::DenseMap &candidateOps); + const StackArraysAnalysisWrapper::AllocMemMap &candidateOps) + : OpRewritePattern(ctx), candidateOps{candidateOps} {} mlir::LogicalResult matchAndRewrite(fir::AllocMemOp allocmem, @@ -196,9 +193,8 @@ static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc); private: - /// allocmem operations that DFA has determined are safe to move to the stack - /// mapping to where to insert replacement freemem operations - const llvm::DenseMap &candidateOps; + /// Candidate allocmem ops found by the DFA (already run) + const StackArraysAnalysisWrapper::AllocMemMap &candidateOps; /// If we failed to find an insertion point not inside a loop, see if it would /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop @@ -412,7 +408,8 @@ 
visitOperationImpl(op, *before, after); } -void StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) { +mlir::LogicalResult +StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) { assert(mlir::isa(func)); mlir::DataFlowSolver solver; // constant propagation is required for dead code analysis, dead code analysis @@ -426,8 +423,7 @@ solver.load(); if (failed(solver.initializeAndRun(func))) { llvm::errs() << "DataFlowSolver failed!"; - gotError = true; - return; + return mlir::failure(); } LatticePoint point{func}; @@ -458,22 +454,17 @@ : candidateOps) { llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n'; }); + return mlir::success(); } -bool StackArraysAnalysisWrapper::hasErrors() const { return gotError; } - -const StackArraysAnalysisWrapper::AllocMemMap & +const StackArraysAnalysisWrapper::AllocMemMap * StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) { - if (!funcMaps.count(func)) - analyseFunction(func); - return funcMaps[func]; + if (!funcMaps.contains(func)) + if (mlir::failed(analyseFunction(func))) + return nullptr; + return &funcMaps[func]; } -AllocMemConversion::AllocMemConversion( - mlir::MLIRContext *ctx, - const llvm::DenseMap &candidateOps) - : OpRewritePattern(ctx), candidateOps(candidateOps) {} - mlir::LogicalResult AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem, mlir::PatternRewriter &rewriter) const { @@ -485,9 +476,13 @@ return mlir::failure(); // remove freemem operations + llvm::SmallVector erases; for (mlir::Operation *user : allocmem.getOperation()->getUsers()) if (mlir::isa(user)) - rewriter.eraseOp(user); + erases.push_back(user); + // now we are done iterating the users, it is safe to mutate them + for (mlir::Operation *erase : erases) + rewriter.eraseOp(erase); // replace references to heap allocation with references to stack allocation rewriter.replaceAllUsesWith(allocmem.getResult(), alloca->getResult()); @@ -709,29 +704,31 @@ assert(mlir::isa(func)); auto 
&analysis = getAnalysis(); - const auto &candidateOps = analysis.getCandidateOps(func); - if (analysis.hasErrors()) { + const StackArraysAnalysisWrapper::AllocMemMap *candidateOps = + analysis.getCandidateOps(func); + if (!candidateOps) { signalPassFailure(); return; } - if (candidateOps.empty()) + if (candidateOps->empty()) return; - runCount += candidateOps.size(); + runCount += candidateOps->size(); + + llvm::SmallVector opsToConvert; + opsToConvert.reserve(candidateOps->size()); + for (auto [op, _] : *candidateOps) + opsToConvert.push_back(op); mlir::MLIRContext &context = getContext(); mlir::RewritePatternSet patterns(&context); - mlir::ConversionTarget target(context); - - target.addLegalDialect(); - target.addDynamicallyLegalOp([&](fir::AllocMemOp alloc) { - return !candidateOps.count(alloc.getOperation()); - }); + mlir::GreedyRewriteConfig config; + // prevent the pattern driver from merging blocks + config.enableRegionSimplification = false; - patterns.insert(&context, candidateOps); - if (mlir::failed( - mlir::applyPartialConversion(func, target, std::move(patterns)))) { + patterns.insert(&context, *candidateOps); + if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert, + std::move(patterns), config))) { mlir::emitError(func->getLoc(), "error in stack arrays optimization\n"); signalPassFailure(); } diff --git a/flang/test/Transforms/stack-arrays.fir b/flang/test/Transforms/stack-arrays.fir --- a/flang/test/Transforms/stack-arrays.fir +++ b/flang/test/Transforms/stack-arrays.fir @@ -84,6 +84,33 @@ // CHECK-NEXT: return // CHECK-NEXT: } +func.func private @dfa3a_foo(!fir.ref>) -> () +func.func private @dfa3a_bar(!fir.ref>) -> () + +// Check freemem in both regions, with other uses +func.func @dfa3a(%arg0: i1) { + %a = fir.allocmem !fir.array<1xi8> + fir.if %arg0 { + %ref = fir.convert %a : (!fir.heap>) -> !fir.ref> + func.call @dfa3a_foo(%ref) : (!fir.ref>) -> () + fir.freemem %a : !fir.heap> + } else { + %ref = fir.convert %a : (!fir.heap>) -> 
!fir.ref> + func.call @dfa3a_bar(%ref) : (!fir.ref>) -> () + fir.freemem %a : !fir.heap> + } + return +} +// CHECK: func.func @dfa3a(%arg0: i1) { +// CHECK-NEXT: %[[MEM:.*]] = fir.alloca !fir.array<1xi8> +// CHECK-NEXT: fir.if %arg0 { +// CHECK-NEXT: func.call @dfa3a_foo(%[[MEM]]) +// CHECK-NEXT: } else { +// CHECK-NEXT: func.call @dfa3a_bar(%[[MEM]]) +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } + // check the alloca is placed after all operands become available func.func @placement1() { // do some stuff with other ssa values