diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -475,9 +475,20 @@ return; llvm::SmallVector newResTys; llvm::SmallVector newInTys; + llvm::SmallVector> savedAttrs; llvm::SmallVector> extraAttrs; llvm::SmallVector fixups; + // Save argument attributes in case there is a shift so we can replace them + // correctly. + for (auto e : llvm::enumerate(funcTy.getInputs())) { + unsigned index = e.index(); + llvm::ArrayRef attrs = func.getArgAttrs(index); + for (mlir::NamedAttribute attr : attrs) { + savedAttrs.push_back({index, attr}); + } + } + // Convert return value(s) for (auto ty : funcTy.getResults()) llvm::TypeSwitch(ty) @@ -495,6 +506,10 @@ }) .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); + // Saved potential shift in argument. Handling of result can add arguments + // at the beginning of the function signature. + unsigned argumentShift = newInTys.size(); + // Convert arguments llvm::SmallVector trailingTys; for (auto e : llvm::enumerate(funcTy.getInputs())) { @@ -724,6 +739,17 @@ func.setArgAttr(extraAttr.first, extraAttr.second.getName(), extraAttr.second.getValue()); + // Replace attributes to the correct argument if there was an argument shift + // to the right. + if (argumentShift > 0) { + for (std::pair savedAttr : savedAttrs) { + func.removeArgAttr(savedAttr.first, savedAttr.second.getName()); + func.setArgAttr(savedAttr.first + argumentShift, + savedAttr.second.getName(), + savedAttr.second.getValue()); + } + } + for (auto &fixup : fixups) if (fixup.finalizer) (*fixup.finalizer)(func); diff --git a/flang/test/Fir/target-rewrite-arg-position.fir b/flang/test/Fir/target-rewrite-arg-position.fir --- a/flang/test/Fir/target-rewrite-arg-position.fir +++ b/flang/test/Fir/target-rewrite-arg-position.fir @@ -1,5 +1,7 @@ // RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s +// Test with an argument shift. + func.func @_QFPf(%arg0: !fir.ref>> {fir.host_assoc}) -> !fir.complex<16> { %0 = fir.alloca !fir.complex<16> {bindc_name = "f", uniq_name = "_QFfEf"} %c2_i32 = arith.constant 2 : i32 @@ -14,4 +16,14 @@ } // CHECK-LABEL: func.func @_QFPf -// CHECK-SAME: %{{.*}}: !fir.ref, !fir.real<16>>> {fir.host_assoc, llvm.align = 16 : i32, llvm.sret}, %arg1: !fir.ref>> {llvm.nest}) { +// CHECK-SAME: %{{.*}}: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.sret}, %arg1: !fir.ref>> {fir.host_assoc, llvm.nest}) { + +// ----- + +// Test with no shift. + +func.func @_QFPs(%arg0: !fir.ref {fir.host_assoc}) { + return +} + +// CHECK: func.func @_QFPs(%arg0: !fir.ref {fir.host_assoc, llvm.nest})