diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp --- a/flang/lib/Optimizer/Transforms/StackArrays.cpp +++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp @@ -26,7 +26,7 @@ #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" @@ -181,11 +181,9 @@ /// Converts a fir.allocmem to a fir.alloca class AllocMemConversion : public mlir::OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - - AllocMemConversion( - mlir::MLIRContext *ctx, - const llvm::DenseMap &candidateOps); + explicit AllocMemConversion(mlir::MLIRContext *ctx, + StackArraysAnalysisWrapper &analysis) + : OpRewritePattern(ctx), analysis{analysis} {} mlir::LogicalResult matchAndRewrite(fir::AllocMemOp allocmem, @@ -196,9 +194,8 @@ static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc); private: - /// allocmem operations that DFA has determined are safe to move to the stack - /// mapping to where to insert replacement freemem operations - const llvm::DenseMap &candidateOps; + /// Handle to the DFA (already run) + StackArraysAnalysisWrapper &analysis; /// If we failed to find an insertion point not inside a loop, see if it would /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop @@ -469,11 +466,6 @@ return funcMaps[func]; } -AllocMemConversion::AllocMemConversion( - mlir::MLIRContext *ctx, - const llvm::DenseMap &candidateOps) - : OpRewritePattern(ctx), candidateOps(candidateOps) {} - mlir::LogicalResult AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem, mlir::PatternRewriter &rewriter) const { @@ -623,6 +615,8 @@ std::optional AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) 
const { + mlir::func::FuncOp func = oldAlloc->getParentOfType(); + const auto &candidateOps = analysis.getCandidateOps(func); auto it = candidateOps.find(oldAlloc.getOperation()); if (it == candidateOps.end()) return {}; @@ -709,29 +703,25 @@ assert(mlir::isa(func)); auto &analysis = getAnalysis(); - const auto &candidateOps = analysis.getCandidateOps(func); if (analysis.hasErrors()) { signalPassFailure(); return; } + const auto &candidateOps = analysis.getCandidateOps(func); if (candidateOps.empty()) return; runCount += candidateOps.size(); mlir::MLIRContext &context = getContext(); mlir::RewritePatternSet patterns(&context); - mlir::ConversionTarget target(context); - - target.addLegalDialect(); - target.addDynamicallyLegalOp([&](fir::AllocMemOp alloc) { - return !candidateOps.count(alloc.getOperation()); - }); + mlir::GreedyRewriteConfig config; + // prevent the pattern driver from merging blocks + config.enableRegionSimplification = false; - patterns.insert(&context, candidateOps); - if (mlir::failed( - mlir::applyPartialConversion(func, target, std::move(patterns)))) { + patterns.insert(&context, analysis); + if (mlir::failed(mlir::applyPatternsAndFoldGreedily(func, std::move(patterns), + config))) { mlir::emitError(func->getLoc(), "error in stack arrays optimization\n"); signalPassFailure(); } diff --git a/flang/test/Transforms/stack-arrays.f90 b/flang/test/Transforms/stack-arrays.f90 --- a/flang/test/Transforms/stack-arrays.f90 +++ b/flang/test/Transforms/stack-arrays.f90 @@ -147,7 +147,7 @@ end subroutine ! CHECK: func.func ! CHECK-SAME: cfgloop -! CHECK-NEXT: %[[MEM:.*]] = fir.alloca !fir.array<100000000xi32> +! CHECK: %[[MEM:.*]] = fir.alloca !fir.array<100000000xi32> ! CHECK-NOT: fir.allocmem ! CHECK-NOT: fir.freemem !
CHECK: return diff --git a/flang/test/Transforms/stack-arrays.fir b/flang/test/Transforms/stack-arrays.fir --- a/flang/test/Transforms/stack-arrays.fir +++ b/flang/test/Transforms/stack-arrays.fir @@ -78,13 +78,11 @@ } // CHECK: func.func @dfa3(%arg0: i1) { // CHECK-NEXT: %[[MEM:.*]] = fir.alloca !fir.array<1xi8> -// CHECK-NEXT: fir.if %arg0 { -// CHECK-NEXT: } else { -// CHECK-NEXT: } // CHECK-NEXT: return -// CHECK-NEXT: } +// CHECK-NEXT:} // check the alloca is placed after all operands become available +// note: the greedy rewrite driver folds the add to get 3 func.func @placement1() { // do some stuff with other ssa values %1 = arith.constant 1 : index @@ -97,19 +95,14 @@ return } // CHECK: func.func @placement1() { -// CHECK-NEXT: %[[ONE:.*]] = arith.constant 1 : index -// CHECK-NEXT: %[[TWO:.*]] = arith.constant 2 : index -// CHECK-NEXT: %[[ARG:.*]] = arith.addi %[[ONE]], %[[TWO]] : index -// CHECK-NEXT: %[[MEM:.*]] = fir.alloca !fir.array, %[[ARG]] +// CHECK-NEXT: %[[THREE:.*]] = arith.constant 3 : index +// CHECK-NEXT: %[[MEM:.*]] = fir.alloca !fir.array, %[[THREE]] // CHECK-NEXT: return // CHECK-NEXT: } // check that if there are no operands, then the alloca is placed early func.func @placement2() { - // do some stuff with other ssa values - %1 = arith.constant 1 : index - %2 = arith.constant 2 : index - %3 = arith.addi %1, %2 : index + call @placement1() : () -> () %4 = fir.allocmem !fir.array<42xi32> // ... 
fir.freemem %4 : !fir.heap> @@ -117,20 +110,17 @@ } // CHECK: func.func @placement2() { // CHECK-NEXT: %[[MEM:.*]] = fir.alloca !fir.array<42xi32> -// CHECK-NEXT: %[[ONE:.*]] = arith.constant 1 : index -// CHECK-NEXT: %[[TWO:.*]] = arith.constant 2 : index -// CHECK-NEXT: %[[SUM:.*]] = arith.addi %[[ONE]], %[[TWO]] : index +// CHECK-NEXT: call @placement1() : () -> () // CHECK-NEXT: return // CHECK-NEXT: } // check that stack allocations which must be placed in loops use stacksave func.func @placement3() { %c1 = arith.constant 1 : index - %c1_i32 = fir.convert %c1 : (index) -> i32 - %c2 = arith.constant 2 : index + %c1_i32 = arith.constant 1 : i32 %c10 = arith.constant 10 : index %0:2 = fir.do_loop %arg0 = %c1 to %c10 step %c1 iter_args(%arg1 = %c1_i32) -> (index, i32) { - %3 = arith.addi %c1, %c2 : index + %3 = fir.call @foo() : () -> index // operand is now available %4 = fir.allocmem !fir.array, %3 // ... @@ -140,14 +130,10 @@ return } // CHECK: func.func @placement3() { -// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index -// CHECK-NEXT: %[[C1_I32:.*]] = fir.convert %[[C1]] : (index) -> i32 -// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : index -// CHECK-NEXT: %[[C10:.*]] = arith.constant 10 : index -// CHECK-NEXT: fir.do_loop -// CHECK-NEXT: %[[SUM:.*]] = arith.addi %[[C1]], %[[C2]] : index +// CHECK: fir.do_loop +// CHECK-NEXT: %[[SIZE:.*]] = fir.call @foo() : () -> index // CHECK-NEXT: %[[SP:.*]] = fir.call @llvm.stacksave() : () -> !fir.ref -// CHECK-NEXT: %[[MEM:.*]] = fir.alloca !fir.array, %[[SUM]] +// CHECK-NEXT: %[[MEM:.*]] = fir.alloca !fir.array, %[[SIZE]] // CHECK-NEXT: fir.call @llvm.stackrestore(%[[SP]]) // CHECK-NEXT: fir.result // CHECK-NEXT: } @@ -156,13 +142,9 @@ // check that stack save/restore are used in CFG loops func.func @placement4(%arg0 : i1) { - %c1 = arith.constant 1 : index - %c1_i32 = fir.convert %c1 : (index) -> i32 - %c2 = arith.constant 2 : index - %c10 = arith.constant 10 : index cf.br ^bb1 ^bb1: - %3 = arith.addi %c1, %c2 : 
index + %3 = fir.call @foo() : () -> index // operand is now available %4 = fir.allocmem !fir.array, %3 // ... @@ -172,15 +154,11 @@ return } // CHECK: func.func @placement4(%arg0: i1) { -// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index -// CHECK-NEXT: %[[C1_I32:.*]] = fir.convert %[[C1]] : (index) -> i32 -// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : index -// CHECK-NEXT: %[[C10:.*]] = arith.constant 10 : index // CHECK-NEXT: cf.br ^bb1 // CHECK-NEXT: ^bb1: -// CHECK-NEXT: %[[SUM:.*]] = arith.addi %[[C1]], %[[C2]] : index +// CHECK-NEXT: %[[SIZE:.*]] = fir.call @foo() : () -> index // CHECK-NEXT: %[[SP:.*]] = fir.call @llvm.stacksave() : () -> !fir.ref -// CHECK-NEXT: %[[MEM:.*]] = fir.alloca !fir.array, %[[SUM]] +// CHECK-NEXT: %[[MEM:.*]] = fir.alloca !fir.array, %[[SIZE]] // CHECK-NEXT: fir.call @llvm.stackrestore(%[[SP]]) : (!fir.ref) -> () // CHECK-NEXT: cf.cond_br %arg0, ^bb1, ^bb2 // CHECK-NEXT: ^bb2: @@ -336,9 +314,9 @@ fir.unreachable } // CHECK: func.func @stop_terminator() { -// CHECK-NEXT: fir.alloca !fir.array<42xi32> -// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : i32 -// CHECK-NEXT: %[[FALSE:.*]] = arith.constant false +// CHECK-DAG: fir.alloca !fir.array<42xi32> +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[FALSE:.*]] = arith.constant false // CHECK-NEXT: %[[NONE:.*]] = fir.call @_FortranAStopStatement(%[[ZERO]], %[[FALSE]], %[[FALSE]]) : (i32, i1, i1) -> none // CHECK-NEXT: fir.unreachable // CHECK-NEXT: }