diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -100,6 +100,10 @@
       } else if (auto dispatch = dyn_cast<DispatchOp>(op)) {
         if (!hasPortableSignature(dispatch.getFunctionType()))
           convertCallOp(dispatch);
+      } else if (auto addr = dyn_cast<AddrOfOp>(op)) {
+        if (addr.getType().isa<mlir::FunctionType>() &&
+            !hasPortableSignature(addr.getType()))
+          convertAddrOp(addr);
       }
     });
 
@@ -319,6 +323,59 @@
       newInTys.push_back(std::get<mlir::Type>(tup));
   }
 
+  /// Taking the address of a function. Modify the signature as needed.
+  /// The fir.address_of op is replaced by a new one whose function type has
+  /// complex results/arguments rewritten by lowerComplexSignatureRes/Arg and,
+  /// unless noCharacterConversion is set, boxchar arguments expanded per
+  /// specifics->boxcharArgumentType(); parts flagged isAppend() go last.
+  void convertAddrOp(AddrOfOp addrOp) {
+    rewriter->setInsertionPoint(addrOp);
+    auto addrTy = addrOp.getType().cast<mlir::FunctionType>();
+    llvm::SmallVector<mlir::Type> newResTys;
+    llvm::SmallVector<mlir::Type> newInTys;
+    for (mlir::Type ty : addrTy.getResults()) {
+      llvm::TypeSwitch<mlir::Type>(ty)
+          .Case<fir::ComplexType>([&](fir::ComplexType ty) {
+            lowerComplexSignatureRes(ty, newResTys, newInTys);
+          })
+          .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
+            lowerComplexSignatureRes(ty, newResTys, newInTys);
+          })
+          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
+    }
+    llvm::SmallVector<mlir::Type> trailingInTys;
+    for (mlir::Type ty : addrTy.getInputs()) {
+      llvm::TypeSwitch<mlir::Type>(ty)
+          .Case<BoxCharType>([&](BoxCharType box) {
+            if (noCharacterConversion) {
+              newInTys.push_back(box);
+            } else {
+              for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) {
+                auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
+                auto argTy = std::get<mlir::Type>(tup);
+                llvm::SmallVector<mlir::Type> &vec =
+                    attr.isAppend() ? trailingInTys : newInTys;
+                vec.push_back(argTy);
+              }
+            }
+          })
+          .Case<fir::ComplexType>([&](fir::ComplexType ty) {
+            lowerComplexSignatureArg(ty, newInTys);
+          })
+          .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
+            lowerComplexSignatureArg(ty, newInTys);
+          })
+          .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
+    }
+    // append trailing input types
+    newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
+    // replace this op with a new one with the updated signature
+    auto newTy = rewriter->getFunctionType(newInTys, newResTys);
+    auto newOp =
+        rewriter->create<AddrOfOp>(addrOp.getLoc(), newTy, addrOp.symbol());
+    replaceOp(addrOp, newOp.getResult());
+  }
+
   /// Convert the type signatures on all the functions present in the module.
   /// As the type signature is being changed, this must also update the
   /// function itself to use any new arguments, etc.
diff --git a/flang/test/Fir/target-rewrite-boxchar.fir b/flang/test/Fir/target-rewrite-boxchar.fir
--- a/flang/test/Fir/target-rewrite-boxchar.fir
+++ b/flang/test/Fir/target-rewrite-boxchar.fir
@@ -93,3 +93,13 @@
   //constant 1
   fir.has_value %str : !fir.char<1,9>
 }
+
+// Test that we rewrite the fir.address_of operator
+// INT32-LABEL: @addrof
+// INT64-LABEL: @addrof
+func @addrof() {
+  // INT32: {{.*}} = fir.address_of(@boxcharcallee) : (!fir.ref<!fir.char<1,?>>, i32) -> ()
+  // INT64: {{.*}} = fir.address_of(@boxcharcallee) : (!fir.ref<!fir.char<1,?>>, i64) -> ()
+  %f = fir.address_of(@boxcharcallee) : (!fir.boxchar<1>) -> ()
+  return
+}
diff --git a/flang/test/Fir/target-rewrite-complex.fir b/flang/test/Fir/target-rewrite-complex.fir
--- a/flang/test/Fir/target-rewrite-complex.fir
+++ b/flang/test/Fir/target-rewrite-complex.fir
@@ -452,3 +452,23 @@
   // PPC: return [[RES]] : tuple<f32, f32>
   return %0 : complex<f32>
 }
+
+// Test that we rewrite the fir.address_of operator.
+// I32-LABEL: func @addrof()
+// X64-LABEL: func @addrof()
+// AARCH64-LABEL: func @addrof()
+// PPC-LABEL: func @addrof()
+func @addrof() {
+  // I32: {{%.*}} = fir.address_of(@returncomplex4) : () -> i64
+  // X64: {{%.*}} = fir.address_of(@returncomplex4) : () -> !fir.vector<2:!fir.real<4>>
+  // AARCH64: {{%.*}} = fir.address_of(@returncomplex4) : () -> tuple<!fir.real<4>, !fir.real<4>>
+  // PPC: {{%.*}} = fir.address_of(@returncomplex4) : () -> tuple<!fir.real<4>, !fir.real<4>>
+  %r = fir.address_of(@returncomplex4) : () -> !fir.complex<4>
+
+  // I32: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.ref<tuple<!fir.real<4>, !fir.real<4>>>) -> ()
+  // X64: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.vector<2:!fir.real<4>>) -> ()
+  // AARCH64: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.array<2x!fir.real<4>>) -> ()
+  // PPC: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.real<4>, !fir.real<4>) -> ()
+  %p = fir.address_of(@paramcomplex4) : (!fir.complex<4>) -> ()
+  return
+}