diff --git a/flang/include/flang/Optimizer/CodeGen/CGPasses.td b/flang/include/flang/Optimizer/CodeGen/CGPasses.td --- a/flang/include/flang/Optimizer/CodeGen/CGPasses.td +++ b/flang/include/flang/Optimizer/CodeGen/CGPasses.td @@ -53,7 +53,7 @@ representations that may differ based on the target machine. }]; let constructor = "::fir::createFirTargetRewritePass()"; - let dependentDialects = [ "fir::FIROpsDialect" ]; + let dependentDialects = [ "fir::FIROpsDialect", "mlir::func::FuncDialect" ]; let options = [ Option<"forcedTargetTriple", "target", "std::string", /*default=*/"", "Override module's target triple.">, diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -91,6 +91,13 @@ mod.getContext(), fir::getTargetTriple(mod), fir::getKindMapping(mod)); setMembers(specifics.get(), &rewriter); + // We may need to call stacksave/stackrestore later, so + // create the FuncOps beforehand. + fir::FirOpBuilder builder(rewriter, fir::getKindMapping(mod)); + builder.setInsertionPointToStart(mod.getBody()); + stackSaveFn = fir::factory::getLlvmStackSave(builder); + stackRestoreFn = fir::factory::getLlvmStackRestore(builder); + // Perform type conversion on signatures and call sites. if (mlir::failed(convertTypes(mod))) { mlir::emitError(mlir::UnknownLoc::get(&context), @@ -121,7 +128,8 @@ template std::optional> rewriteCallComplexResultType(mlir::Location loc, A ty, B &newResTys, - B &newInTys, C &newOpers) { + B &newInTys, C &newOpers, + mlir::Value &savedStackPtr) { if (noComplexConversion) { newResTys.push_back(ty); return std::nullopt; @@ -134,6 +142,11 @@ auto attr = std::get(m[0]); if (attr.isSRet()) { assert(fir::isa_ref_type(resTy) && "must be a memory reference type"); + // Save the stack pointer, if it has not been saved for this call yet. + // We will need to restore it after the call, because the alloca + // needs to be deallocated. + if (!savedStackPtr) + savedStackPtr = genStackSave(loc); mlir::Value stack = rewriter->create(loc, fir::dyn_cast_ptrEleTy(resTy)); newInTys.push_back(resTy); @@ -145,7 +158,10 @@ }; } newResTys.push_back(resTy); - return [=](mlir::Operation *call) -> mlir::Value { + return [=, &savedStackPtr](mlir::Operation *call) -> mlir::Value { + // We are going to generate an alloca, so save the stack pointer. + if (!savedStackPtr) + savedStackPtr = genStackSave(loc); auto mem = rewriter->create(loc, resTy); rewriter->create(loc, call->getResult(0), mem); auto memTy = fir::ReferenceType::get(ty); @@ -156,7 +172,7 @@ template void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys, - C &newOpers) { + C &newOpers, mlir::Value &savedStackPtr) { if (noComplexConversion) { newInTys.push_back(ty); newOpers.push_back(oper); @@ -173,6 +189,9 @@ auto resTy = std::get(m[0]); auto attr = std::get(m[0]); auto oldRefTy = fir::ReferenceType::get(ty); + // We are going to generate an alloca, so save the stack pointer. + if (!savedStackPtr) + savedStackPtr = genStackSave(loc); if (attr.isByVal()) { auto mem = rewriter->create(loc, ty); rewriter->create(loc, oper, mem); @@ -210,6 +229,7 @@ llvm::SmallVector newResTys; llvm::SmallVector newInTys; llvm::SmallVector newOpers; + mlir::Value savedStackPtr = nullptr; // If the call is indirect, the first argument must still be the function // to call. @@ -231,11 +251,11 @@ llvm::TypeSwitch(ty) .template Case([&](fir::ComplexType cmplx) { wrap = rewriteCallComplexResultType(loc, cmplx, newResTys, newInTys, - newOpers); + newOpers, savedStackPtr); }) .template Case([&](mlir::ComplexType cmplx) { wrap = rewriteCallComplexResultType(loc, cmplx, newResTys, newInTys, - newOpers); + newOpers, savedStackPtr); }) .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); } else if (fnTy.getResults().size() > 1) { @@ -290,10 +310,12 @@ } }) .template Case([&](fir::ComplexType cmplx) { - rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); + rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers, + savedStackPtr); }) .template Case([&](mlir::ComplexType cmplx) { - rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); + rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers, + savedStackPtr); }) .template Case([&](mlir::TupleType tuple) { if (fir::isCharacterProcedureTuple(tuple)) { @@ -343,6 +365,8 @@ } newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); + + llvm::SmallVector newCallResults; if constexpr (std::is_same_v, fir::CallOp>) { fir::CallOp newCall; if (callOp.getCallee()) { @@ -357,18 +381,40 @@ } LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n'); if (wrap) - replaceOp(callOp, (*wrap)(newCall.getOperation())); + newCallResults.push_back((*wrap)(newCall.getOperation())); else - replaceOp(callOp, newCall.getResults()); + newCallResults.append(newCall.result_begin(), newCall.result_end()); } else { fir::DispatchOp dispatchOp = rewriter->create( loc, newResTys, rewriter->getStringAttr(callOp.getMethod()), callOp.getOperands()[0], newOpers, rewriter->getI32IntegerAttr(*callOp.getPassArgPos() + passArgShift)); if (wrap) - replaceOp(callOp, (*wrap)(dispatchOp.getOperation())); + newCallResults.push_back((*wrap)(dispatchOp.getOperation())); else - replaceOp(callOp, dispatchOp.getResults()); + newCallResults.append(dispatchOp.result_begin(), + dispatchOp.result_end()); + } + + if (newCallResults.size() <= 1) { + if (savedStackPtr) { + if (newCallResults.size() == 1) { + // We assume that all the allocas are inserted before + // the operation that defines the new call result. + rewriter->setInsertionPointAfterValue(newCallResults[0]); + } else { + // If the call does not have results, then insert + // stack restore after the original call operation. + rewriter->setInsertionPointAfter(callOp); + } + genStackRestore(loc, savedStackPtr); + } + replaceOp(callOp, newCallResults); + } else { + // The TODO is duplicated here to make sure this part + // handles the stackrestore insertion properly, if + // we add support for multiple call results. + TODO(loc, "multiple results not supported yet"); } } @@ -974,8 +1020,22 @@ inline void clearMembers() { setMembers(nullptr, nullptr); } + // Inserts a call to llvm.stacksave at the current insertion + // point and the given location. Returns the call's result Value. + inline mlir::Value genStackSave(mlir::Location loc) { + return rewriter->create(loc, stackSaveFn).getResult(0); + } + + // Inserts a call to llvm.stackrestore at the current insertion + // point and the given location and argument. + inline void genStackRestore(mlir::Location loc, mlir::Value sp) { + rewriter->create(loc, stackRestoreFn, mlir::ValueRange{sp}); + } + fir::CodeGenSpecifics *specifics = nullptr; mlir::OpBuilder *rewriter = nullptr; + mlir::func::FuncOp stackSaveFn = nullptr; + mlir::func::FuncOp stackRestoreFn = nullptr; }; // namespace } // namespace diff --git a/flang/test/Fir/target-rewrite-complex16.fir b/flang/test/Fir/target-rewrite-complex16.fir --- a/flang/test/Fir/target-rewrite-complex16.fir +++ b/flang/test/Fir/target-rewrite-complex16.fir @@ -63,57 +63,65 @@ // CHECK: func.func private @paramcomplex16(!fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}) // CHECK-LABEL: func.func @callcomplex16() { -// CHECK: %[[VAL_0:.*]] = fir.alloca tuple, !fir.real<16>> -// CHECK: fir.call @returncomplex16(%[[VAL_0]]) : (!fir.ref, !fir.real<16>>>) -> () -// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (!fir.ref, !fir.real<16>>>) -> !fir.ref> -// CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_1]] : !fir.ref> -// CHECK: %[[VAL_3:.*]] = fir.alloca !fir.complex<16> -// CHECK: fir.store %[[VAL_2]] to %[[VAL_3]] : !fir.ref> -// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.ref>) -> !fir.ref, !fir.real<16>>> -// CHECK: fir.call @paramcomplex16(%[[VAL_4]]) : (!fir.ref, !fir.real<16>>>) -> () +// CHECK: %[[VAL_0:.*]] = fir.call @llvm.stacksave() : () -> !fir.ref +// CHECK: %[[VAL_1:.*]] = fir.alloca tuple, !fir.real<16>> +// CHECK: fir.call @returncomplex16(%[[VAL_1]]) : (!fir.ref, !fir.real<16>>>) -> () +// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (!fir.ref, !fir.real<16>>>) -> !fir.ref> +// CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_2]] : !fir.ref> +// CHECK: fir.call @llvm.stackrestore(%[[VAL_0]]) : (!fir.ref) -> () +// CHECK: %[[VAL_4:.*]] = fir.call @llvm.stacksave() : () -> !fir.ref +// CHECK: %[[VAL_5:.*]] = fir.alloca !fir.complex<16> +// CHECK: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref> +// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_5]] : (!fir.ref>) -> !fir.ref, !fir.real<16>>> +// CHECK: fir.call @paramcomplex16(%[[VAL_6]]) : (!fir.ref, !fir.real<16>>>) -> () +// CHECK: fir.call @llvm.stackrestore(%[[VAL_4]]) : (!fir.ref) -> () // CHECK: return // CHECK: } // CHECK: func.func private @calleemultipleparamscomplex16(!fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}, !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}, !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}) // CHECK-LABEL: func.func @multipleparamscomplex16( -// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}, %[[VAL_1:.*]]: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}, %[[VAL_2:.*]]: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}) { +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}, %[[VAL_1:.*]]: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}, %[[VAL_2:.*]]: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}) { // CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref, !fir.real<16>>>) -> !fir.ref> // CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref> // CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_1]] : (!fir.ref, !fir.real<16>>>) -> !fir.ref> // CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref> // CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_0]] : (!fir.ref, !fir.real<16>>>) -> !fir.ref> // CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_7]] : !fir.ref> -// CHECK: %[[VAL_9:.*]] = fir.alloca !fir.complex<16> -// CHECK: fir.store %[[VAL_8]] to %[[VAL_9]] : !fir.ref> -// CHECK: %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (!fir.ref>) -> !fir.ref, !fir.real<16>>> -// CHECK: %[[VAL_11:.*]] = fir.alloca !fir.complex<16> -// CHECK: fir.store %[[VAL_6]] to %[[VAL_11]] : !fir.ref> -// CHECK: %[[VAL_12:.*]] = fir.convert %[[VAL_11]] : (!fir.ref>) -> !fir.ref, !fir.real<16>>> -// CHECK: %[[VAL_13:.*]] = fir.alloca !fir.complex<16> -// CHECK: fir.store %[[VAL_4]] to %[[VAL_13]] : !fir.ref> -// CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (!fir.ref>) -> !fir.ref, !fir.real<16>>> -// CHECK: fir.call @calleemultipleparamscomplex16(%[[VAL_10]], %[[VAL_12]], %[[VAL_14]]) : (!fir.ref, !fir.real<16>>>, !fir.ref, !fir.real<16>>>, !fir.ref, !fir.real<16>>>) -> () +// CHECK: %[[VAL_9:.*]] = fir.call @llvm.stacksave() : () -> !fir.ref +// CHECK: %[[VAL_10:.*]] = fir.alloca !fir.complex<16> +// CHECK: fir.store %[[VAL_8]] to %[[VAL_10]] : !fir.ref> +// CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (!fir.ref>) -> !fir.ref, !fir.real<16>>> +// CHECK: %[[VAL_12:.*]] = fir.alloca !fir.complex<16> +// CHECK: fir.store %[[VAL_6]] to %[[VAL_12]] : !fir.ref> +// CHECK: %[[VAL_13:.*]] = fir.convert %[[VAL_12]] : (!fir.ref>) -> !fir.ref, !fir.real<16>>> +// CHECK: %[[VAL_14:.*]] = fir.alloca !fir.complex<16> +// CHECK: fir.store %[[VAL_4]] to %[[VAL_14]] : !fir.ref> +// CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (!fir.ref>) -> !fir.ref, !fir.real<16>>> +// CHECK: fir.call @calleemultipleparamscomplex16(%[[VAL_11]], %[[VAL_13]], %[[VAL_15]]) : (!fir.ref, !fir.real<16>>>, !fir.ref, !fir.real<16>>>, !fir.ref, !fir.real<16>>>) -> () +// CHECK: fir.call @llvm.stackrestore(%[[VAL_9]]) : (!fir.ref) -> () // CHECK: return // CHECK: } // CHECK-LABEL: func.func private @mlircomplexf128( -// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref> {llvm.align = 16 : i32, llvm.sret = tuple}, %[[VAL_1:.*]]: !fir.ref> {llvm.align = 16 : i32, llvm.byval = tuple}, %[[VAL_2:.*]]: !fir.ref> {llvm.align = 16 : i32, llvm.byval = tuple}) { +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref> {llvm.align = 16 : i32, llvm.sret = tuple}, %[[VAL_1:.*]]: !fir.ref> {llvm.align = 16 : i32, llvm.byval = tuple}, %[[VAL_2:.*]]: !fir.ref> {llvm.align = 16 : i32, llvm.byval = tuple}) { // CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref>) -> !fir.ref> // CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref> // CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_1]] : (!fir.ref>) -> !fir.ref> // CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref> -// CHECK: %[[VAL_7:.*]] = fir.alloca tuple -// CHECK: %[[VAL_8:.*]] = fir.alloca complex -// CHECK: fir.store %[[VAL_6]] to %[[VAL_8]] : !fir.ref> -// CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (!fir.ref>) -> !fir.ref> -// CHECK: %[[VAL_10:.*]] = fir.alloca complex -// CHECK: fir.store %[[VAL_4]] to %[[VAL_10]] : !fir.ref> -// CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (!fir.ref>) -> !fir.ref> -// CHECK: fir.call @mlircomplexf128(%[[VAL_7]], %[[VAL_9]], %[[VAL_11]]) : (!fir.ref>, !fir.ref>, !fir.ref>) -> () -// CHECK: %[[VAL_12:.*]] = fir.convert %[[VAL_7]] : (!fir.ref>) -> !fir.ref> -// CHECK: %[[VAL_13:.*]] = fir.load %[[VAL_12]] : !fir.ref> -// CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_0]] : (!fir.ref>) -> !fir.ref> -// CHECK: fir.store %[[VAL_13]] to %[[VAL_14]] : !fir.ref> +// CHECK: %[[VAL_7:.*]] = fir.call @llvm.stacksave() : () -> !fir.ref +// CHECK: %[[VAL_8:.*]] = fir.alloca tuple +// CHECK: %[[VAL_9:.*]] = fir.alloca complex +// CHECK: fir.store %[[VAL_6]] to %[[VAL_9]] : !fir.ref> +// CHECK: %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (!fir.ref>) -> !fir.ref> +// CHECK: %[[VAL_11:.*]] = fir.alloca complex +// CHECK: fir.store %[[VAL_4]] to %[[VAL_11]] : !fir.ref> +// CHECK: %[[VAL_12:.*]] = fir.convert %[[VAL_11]] : (!fir.ref>) -> !fir.ref> +// CHECK: fir.call @mlircomplexf128(%[[VAL_8]], %[[VAL_10]], %[[VAL_12]]) : (!fir.ref>, !fir.ref>, !fir.ref>) -> () +// CHECK: %[[VAL_13:.*]] = fir.convert %[[VAL_8]] : (!fir.ref>) -> !fir.ref> +// CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref> +// CHECK: fir.call @llvm.stackrestore(%[[VAL_7]]) : (!fir.ref) -> () +// CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_0]] : (!fir.ref>) -> !fir.ref> +// CHECK: fir.store %[[VAL_14]] to %[[VAL_15]] : !fir.ref> // CHECK: return // CHECK: } @@ -122,3 +130,5 @@ // CHECK: %[[VAL_1:.*]] = fir.address_of(@paramcomplex16) : (!fir.ref, !fir.real<16>>>) -> () // CHECK: return // CHECK: } +// CHECK: func.func private @llvm.stacksave() -> !fir.ref +// CHECK: func.func private @llvm.stackrestore(!fir.ref) diff --git a/flang/test/Fir/target.fir b/flang/test/Fir/target.fir --- a/flang/test/Fir/target.fir +++ b/flang/test/Fir/target.fir @@ -87,7 +87,7 @@ // PPC: = call { double, double } @gen8() %1 = fir.call @gen8() : () -> !fir.complex<8> // I32: call void @sink8(ptr % - // X64: call void @sink8(double %4, double %5) + // X64: call void @sink8(double %{{[0-9]*}}, double %{{[0-9]*}}) // AARCH64: call void @sink8([2 x double] % // PPC: call void @sink8(double %{{.*}}, double %{{.*}}) fir.call @sink8(%1) : (!fir.complex<8>) -> ()