diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp
--- a/flang/lib/Optimizer/Transforms/StackArrays.cpp
+++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp
@@ -465,6 +465,37 @@
   return &funcMaps[func];
 }
 
+/// Restore the old allocation type expected by existing code.
+///
+/// \p heap is the result of the fir.allocmem being replaced and \p stack is
+/// the result of the fir.alloca replacing it. If their types already match,
+/// \p stack is returned unchanged. Otherwise a fir.convert from the
+/// !fir.ref<T> stack type to the original !fir.heap<T> type is inserted
+/// immediately after the alloca, so existing users keep seeing a heap type.
+static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter,
+                                         mlir::Location loc, mlir::Value heap,
+                                         mlir::Value stack) {
+  mlir::Type heapTy = heap.getType();
+  mlir::Type stackTy = stack.getType();
+
+  if (heapTy == stackTy)
+    return stack;
+
+  auto firHeapTy = mlir::cast<fir::HeapType>(heapTy);
+  // Cast inside the assert so no variable is left unused when NDEBUG
+  // compiles the assert away (-Wunused-variable).
+  assert(firHeapTy.getElementType() ==
+             mlir::cast<fir::ReferenceType>(stackTy).getElementType() &&
+         "Allocations must have the same type");
+
+  // Build the conversion directly after the alloca, i.e. at the point where
+  // the stack value becomes available; the guard restores the caller's
+  // insertion point on return.
+  mlir::OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointAfter(stack.getDefiningOp());
+  return rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult();
+}
+
 mlir::LogicalResult
 AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
                                     mlir::PatternRewriter &rewriter) const {
@@ -485,7 +516,9 @@
     rewriter.eraseOp(erase);
 
   // replace references to heap allocation with references to stack allocation
-  rewriter.replaceAllUsesWith(allocmem.getResult(), alloca->getResult());
+  mlir::Value newValue = convertAllocationType(
+      rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult());
+  rewriter.replaceAllUsesWith(allocmem.getResult(), newValue);
 
   // remove allocmem operation
   rewriter.eraseOp(allocmem.getOperation());
diff --git a/flang/test/Transforms/stack-arrays.fir b/flang/test/Transforms/stack-arrays.fir
--- a/flang/test/Transforms/stack-arrays.fir
+++ b/flang/test/Transforms/stack-arrays.fir
@@ -103,10 +103,13 @@
 }
 // CHECK: func.func @dfa3a(%arg0: i1) {
 // CHECK-NEXT: %[[MEM:.*]] = fir.alloca !fir.array<1xi8>
+// CHECK-NEXT: %[[HEAP:.*]] = fir.convert %[[MEM]] : (!fir.ref<!fir.array<1xi8>>) -> !fir.heap<!fir.array<1xi8>>
 // CHECK-NEXT: fir.if %arg0 {
-// CHECK-NEXT: func.call @dfa3a_foo(%[[MEM]])
+// CHECK-NEXT: %[[REF:.*]] = fir.convert %[[HEAP]] : (!fir.heap<!fir.array<1xi8>>) -> !fir.ref<!fir.array<1xi8>>
+// CHECK-NEXT: func.call @dfa3a_foo(%[[REF]])
 // CHECK-NEXT: } else {
-// CHECK-NEXT: func.call @dfa3a_bar(%[[MEM]])
+// CHECK-NEXT: %[[REF:.*]] = fir.convert %[[HEAP]] : (!fir.heap<!fir.array<1xi8>>) -> !fir.ref<!fir.array<1xi8>>
+// CHECK-NEXT: func.call @dfa3a_bar(%[[REF]])
 // CHECK-NEXT: }
 // CHECK-NEXT: return
 // CHECK-NEXT: }