diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -756,7 +756,7 @@ } inline bool functionArgIsSRet(unsigned index, mlir::func::FuncOp func) { - if (auto attr = func.getArgAttrOfType(index, "llvm.sret")) + if (auto attr = func.getArgAttrOfType(index, "llvm.sret")) return true; return false; } @@ -782,16 +782,22 @@ if (auto align = attr.getAlignment()) fixups.emplace_back( FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) { - func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr()); + auto elemType = fir::dyn_cast_ptrOrBoxEleTy( + func.getFunctionType().getInput(argNo)); + func.setArgAttr(argNo, "llvm.sret", + mlir::TypeAttr::get(elemType)); func.setArgAttr(argNo, "llvm.align", rewriter->getIntegerAttr( rewriter->getIntegerType(32), align)); }); else - fixups.emplace_back( - FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) { - func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr()); - }); + fixups.emplace_back(FixupTy::Codes::ReturnAsStore, argNo, + [=](mlir::func::FuncOp func) { + auto elemType = fir::dyn_cast_ptrOrBoxEleTy( + func.getFunctionType().getInput(argNo)); + func.setArgAttr(argNo, "llvm.sret", + mlir::TypeAttr::get(elemType)); + }); newInTys.push_back(argTy); return; } else { @@ -833,7 +839,10 @@ fixups.emplace_back( FixupTy::Codes::ArgumentAsLoad, argNo, [=](mlir::func::FuncOp func) { - func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr()); + auto elemType = fir::dyn_cast_ptrOrBoxEleTy( + func.getFunctionType().getInput(argNo)); + func.setArgAttr(argNo, "llvm.byval", + mlir::TypeAttr::get(elemType)); func.setArgAttr(argNo, "llvm.align", rewriter->getIntegerAttr( rewriter->getIntegerType(32), align)); @@ -841,8 +850,10 @@ else fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(), [=](mlir::func::FuncOp func) { + auto elemType = 
fir::dyn_cast_ptrOrBoxEleTy( + func.getFunctionType().getInput(argNo)); func.setArgAttr(argNo, "llvm.byval", - rewriter->getUnitAttr()); + mlir::TypeAttr::get(elemType)); }); } else { if (auto align = attr.getAlignment()) diff --git a/flang/test/Fir/target-rewrite-arg-position.fir b/flang/test/Fir/target-rewrite-arg-position.fir --- a/flang/test/Fir/target-rewrite-arg-position.fir +++ b/flang/test/Fir/target-rewrite-arg-position.fir @@ -16,7 +16,7 @@ } // CHECK-LABEL: func.func @_QFPf -// CHECK-SAME: %{{.*}}: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.sret}, %arg1: !fir.ref>> {fir.host_assoc, llvm.nest}) { +// CHECK-SAME: %{{.*}}: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.sret = tuple, !fir.real<16>>}, %arg1: !fir.ref>> {fir.host_assoc, llvm.nest}) { // ----- diff --git a/flang/test/Fir/target-rewrite-boxchar.fir b/flang/test/Fir/target-rewrite-boxchar.fir --- a/flang/test/Fir/target-rewrite-boxchar.fir +++ b/flang/test/Fir/target-rewrite-boxchar.fir @@ -27,10 +27,10 @@ // Test that we rewrite the signatures and bodies of functions that return a // boxchar. 
// INT32-LABEL: @boxcharsret -// INT32-SAME: ([[ARG0:%[0-9A-Za-z]+]]: !fir.ref> {llvm.sret}, [[ARG1:%[0-9A-Za-z]+]]: i32, [[ARG2:%[0-9A-Za-z]+]]: !fir.ref>, [[ARG3:%[0-9A-Za-z]+]]: i32) +// INT32-SAME: ([[ARG0:%[0-9A-Za-z]+]]: !fir.ref> {llvm.sret = !fir.char<1,?>}, [[ARG1:%[0-9A-Za-z]+]]: i32, [[ARG2:%[0-9A-Za-z]+]]: !fir.ref>, [[ARG3:%[0-9A-Za-z]+]]: i32) // INT64-LABEL: @boxcharsret -// INT64-SAME: ([[ARG0:%[0-9A-Za-z]+]]: !fir.ref> {llvm.sret}, [[ARG1:%[0-9A-Za-z]+]]: i64, [[ARG2:%[0-9A-Za-z]+]]: !fir.ref>, [[ARG3:%[0-9A-Za-z]+]]: i64) -func.func @boxcharsret(%arg0 : !fir.boxchar<1> {llvm.sret}, %arg1 : !fir.boxchar<1>) { +// INT64-SAME: ([[ARG0:%[0-9A-Za-z]+]]: !fir.ref> {llvm.sret = !fir.char<1,?>}, [[ARG1:%[0-9A-Za-z]+]]: i64, [[ARG2:%[0-9A-Za-z]+]]: !fir.ref>, [[ARG3:%[0-9A-Za-z]+]]: i64) +func.func @boxcharsret(%arg0 : !fir.boxchar<1> {llvm.sret = !fir.char<1,?>}, %arg1 : !fir.boxchar<1>) { // INT32-DAG: [[B0:%[0-9]+]] = fir.emboxchar [[ARG0]], [[ARG1]] : (!fir.ref>, i32) -> !fir.boxchar<1> // INT32-DAG: [[B1:%[0-9]+]] = fir.emboxchar [[ARG2]], [[ARG3]] : (!fir.ref>, i32) -> !fir.boxchar<1> // INT32-DAG: fir.unboxchar [[B0]] : (!fir.boxchar<1>) -> (!fir.ref>>, i64) @@ -57,10 +57,10 @@ // Test that we rewrite the signatures of functions with a sret parameter and // several other parameters. 
// INT32-LABEL: @boxcharmultiple -// INT32-SAME: ({{%[0-9A-Za-z]+}}: !fir.ref> {llvm.sret}, {{%[0-9A-Za-z]+}}: i32, {{%[0-9A-Za-z]+}}: !fir.ref>, {{%[0-9A-Za-z]+}}: !fir.ref>, {{%[0-9A-Za-z]+}}: i32, {{%[0-9A-Za-z]+}}: i32) +// INT32-SAME: ({{%[0-9A-Za-z]+}}: !fir.ref> {llvm.sret = !fir.char<1,?>}, {{%[0-9A-Za-z]+}}: i32, {{%[0-9A-Za-z]+}}: !fir.ref>, {{%[0-9A-Za-z]+}}: !fir.ref>, {{%[0-9A-Za-z]+}}: i32, {{%[0-9A-Za-z]+}}: i32) // INT64-LABEL: @boxcharmultiple -// INT64-SAME: ({{%[0-9A-Za-z]+}}: !fir.ref> {llvm.sret}, {{%[0-9A-Za-z]+}}: i64, {{%[0-9A-Za-z]+}}: !fir.ref>, {{%[0-9A-Za-z]+}}: !fir.ref>, {{%[0-9A-Za-z]+}}: i64, {{%[0-9A-Za-z]+}}: i64) -func.func @boxcharmultiple(%arg0 : !fir.boxchar<1> {llvm.sret}, %arg1 : !fir.boxchar<1>, %arg2 : !fir.boxchar<1>) { +// INT64-SAME: ({{%[0-9A-Za-z]+}}: !fir.ref> {llvm.sret = !fir.char<1,?>}, {{%[0-9A-Za-z]+}}: i64, {{%[0-9A-Za-z]+}}: !fir.ref>, {{%[0-9A-Za-z]+}}: !fir.ref>, {{%[0-9A-Za-z]+}}: i64, {{%[0-9A-Za-z]+}}: i64) +func.func @boxcharmultiple(%arg0 : !fir.boxchar<1> {llvm.sret = !fir.char<1,?>}, %arg1 : !fir.boxchar<1>, %arg2 : !fir.boxchar<1>) { return } diff --git a/flang/test/Fir/target-rewrite-complex.fir b/flang/test/Fir/target-rewrite-complex.fir --- a/flang/test/Fir/target-rewrite-complex.fir +++ b/flang/test/Fir/target-rewrite-complex.fir @@ -53,7 +53,7 @@ // Test that we rewrite the signature and body of a function that returns a // complex<8>. // I32-LABEL:func @returncomplex8 -// I32-SAME: ([[ARG0:%[0-9A-Za-z]+]]: !fir.ref, !fir.real<8>>> {llvm.align = 4 : i32, llvm.sret}) +// I32-SAME: ([[ARG0:%[0-9A-Za-z]+]]: !fir.ref, !fir.real<8>>> {llvm.align = 4 : i32, llvm.sret = tuple, !fir.real<8>>}) // X64-LABEL: func @returncomplex8() -> tuple, !fir.real<8>> // AARCH64-LABEL: func @returncomplex8() -> tuple, !fir.real<8>> // PPC-LABEL: func @returncomplex8() -> tuple, !fir.real<8>> @@ -96,7 +96,7 @@ } // Test that we rewrite the signature of a function that accepts a complex<4>. 
-// I32-LABEL: func private @paramcomplex4(!fir.ref, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval}) +// I32-LABEL: func private @paramcomplex4(!fir.ref, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval = tuple, !fir.real<4>>}) // X64-LABEL: func private @paramcomplex4(!fir.vector<2:!fir.real<4>>) // AARCH64-LABEL: func private @paramcomplex4(!fir.array<2x!fir.real<4>>) // PPC-LABEL: func private @paramcomplex4(!fir.real<4>, !fir.real<4>) @@ -156,7 +156,7 @@ } // Test that we rewrite the signature of a function that accepts a complex<8>. -// I32-LABEL: func private @paramcomplex8(!fir.ref, !fir.real<8>>> {llvm.align = 4 : i32, llvm.byval}) +// I32-LABEL: func private @paramcomplex8(!fir.ref, !fir.real<8>>> {llvm.align = 4 : i32, llvm.byval = tuple, !fir.real<8>>}) // X64-LABEL: func private @paramcomplex8(!fir.real<8>, !fir.real<8>) // AARCH64-LABEL: func private @paramcomplex8(!fir.array<2x!fir.real<8>>) // PPC-LABEL: func private @paramcomplex8(!fir.real<8>, !fir.real<8>) @@ -212,14 +212,14 @@ } // Test multiple complex<4> parameters and arguments -// I32-LABEL: func private @calleemultipleparamscomplex4(!fir.ref, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval}, !fir.ref, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval}, !fir.ref, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval}) +// I32-LABEL: func private @calleemultipleparamscomplex4(!fir.ref, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval = tuple, !fir.real<4>>}, !fir.ref, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval = tuple, !fir.real<4>>}, !fir.ref, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval = tuple, !fir.real<4>>}) // X64-LABEL: func private @calleemultipleparamscomplex4(!fir.vector<2:!fir.real<4>>, !fir.vector<2:!fir.real<4>>, !fir.vector<2:!fir.real<4>>) // AARCH64-LABEL: func private @calleemultipleparamscomplex4(!fir.array<2x!fir.real<4>>, !fir.array<2x!fir.real<4>>, !fir.array<2x!fir.real<4>>) // PPC-LABEL: func private @calleemultipleparamscomplex4(!fir.real<4>, !fir.real<4>, 
!fir.real<4>, !fir.real<4>, !fir.real<4>, !fir.real<4>) func.func private @calleemultipleparamscomplex4(!fir.complex<4>, !fir.complex<4>, !fir.complex<4>) -> () // I32-LABEL: func @multipleparamscomplex4 -// I32-SAME: ([[Z1:%[0-9A-Za-z]+]]: !fir.ref, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval}, [[Z2:%[0-9A-Za-z]+]]: !fir.ref, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval}, [[Z3:%[0-9A-Za-z]+]]: !fir.ref, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval}) +// I32-SAME: ([[Z1:%[0-9A-Za-z]+]]: !fir.ref, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval = tuple, !fir.real<4>>}, [[Z2:%[0-9A-Za-z]+]]: !fir.ref, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval = tuple, !fir.real<4>>}, [[Z3:%[0-9A-Za-z]+]]: !fir.ref, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval = tuple, !fir.real<4>>}) // X64-LABEL: func @multipleparamscomplex4 // X64-SAME: ([[Z1:%[0-9A-Za-z]+]]: !fir.vector<2:!fir.real<4>>, [[Z2:%[0-9A-Za-z]+]]: !fir.vector<2:!fir.real<4>>, [[Z3:%[0-9A-Za-z]+]]: !fir.vector<2:!fir.real<4>>) // AARCH64-LABEL: func @multipleparamscomplex4 @@ -329,7 +329,7 @@ // and returns MLIR complex. 
// I32-LABEL: func private @mlircomplexf32 -// I32-SAME: ([[Z1:%[0-9A-Za-z]+]]: !fir.ref> {llvm.align = 4 : i32, llvm.byval}, [[Z2:%[0-9A-Za-z]+]]: !fir.ref> {llvm.align = 4 : i32, llvm.byval}) +// I32-SAME: ([[Z1:%[0-9A-Za-z]+]]: !fir.ref> {llvm.align = 4 : i32, llvm.byval = tuple}, [[Z2:%[0-9A-Za-z]+]]: !fir.ref> {llvm.align = 4 : i32, llvm.byval = tuple}) // I32-SAME: -> i64 // X64-LABEL: func private @mlircomplexf32 // X64-SAME: ([[Z1:%[0-9A-Za-z]+]]: !fir.vector<2:f32>, [[Z2:%[0-9A-Za-z]+]]: !fir.vector<2:f32>) diff --git a/flang/test/Fir/target-rewrite-complex16.fir b/flang/test/Fir/target-rewrite-complex16.fir --- a/flang/test/Fir/target-rewrite-complex16.fir +++ b/flang/test/Fir/target-rewrite-complex16.fir @@ -48,7 +48,7 @@ } // CHECK-LABEL: func.func @returncomplex16( -// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.sret}) { +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.sret = tuple, !fir.real<16>>}) { // CHECK: %[[VAL_1:.*]] = fir.undefined !fir.complex<16> // CHECK: %[[VAL_2:.*]] = arith.constant 2.000000e+00 : f128 // CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (f128) -> !fir.real<16> @@ -61,7 +61,7 @@ // CHECK: fir.store %[[VAL_8]] to %[[VAL_9]] : !fir.ref> // CHECK: return // CHECK: } -// CHECK: func.func private @paramcomplex16(!fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval}) +// CHECK: func.func private @paramcomplex16(!fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}) // CHECK-LABEL: func.func @callcomplex16() { // CHECK: %[[VAL_0:.*]] = fir.alloca tuple, !fir.real<16>> @@ -74,10 +74,10 @@ // CHECK: fir.call @paramcomplex16(%[[VAL_4]]) : (!fir.ref, !fir.real<16>>>) -> () // CHECK: return // CHECK: } -// CHECK: func.func private @calleemultipleparamscomplex16(!fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval}, !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval}, !fir.ref, !fir.real<16>>> {llvm.align 
= 16 : i32, llvm.byval}) +// CHECK: func.func private @calleemultipleparamscomplex16(!fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}, !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}, !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}) // CHECK-LABEL: func.func @multipleparamscomplex16( -// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval}, %[[VAL_1:.*]]: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval}, %[[VAL_2:.*]]: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval}) { +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}, %[[VAL_1:.*]]: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}, %[[VAL_2:.*]]: !fir.ref, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple, !fir.real<16>>}) { // CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref, !fir.real<16>>>) -> !fir.ref> // CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref> // CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_1]] : (!fir.ref, !fir.real<16>>>) -> !fir.ref> @@ -98,7 +98,7 @@ // CHECK: } // CHECK-LABEL: func.func private @mlircomplexf128( -// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref> {llvm.align = 16 : i32, llvm.sret}, %[[VAL_1:.*]]: !fir.ref> {llvm.align = 16 : i32, llvm.byval}, %[[VAL_2:.*]]: !fir.ref> {llvm.align = 16 : i32, llvm.byval}) { +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref> {llvm.align = 16 : i32, llvm.sret = tuple}, %[[VAL_1:.*]]: !fir.ref> {llvm.align = 16 : i32, llvm.byval = tuple}, %[[VAL_2:.*]]: !fir.ref> {llvm.align = 16 : i32, llvm.byval = tuple}) { // CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref>) -> !fir.ref> // CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref> // CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_1]] : (!fir.ref>) -> !fir.ref> diff --git a/flang/test/Fir/target.fir 
b/flang/test/Fir/target.fir --- a/flang/test/Fir/target.fir +++ b/flang/test/Fir/target.fir @@ -115,7 +115,7 @@ // I32-LABEL: define void @char1copy(ptr sret(i8) %0, i32 %1, ptr %2, i32 %3) // I64-LABEL: define void @char1copy(ptr sret(i8) %0, i64 %1, ptr %2, i64 %3) // PPC-LABEL: define void @char1copy(ptr sret(i8) %0, i64 %1, ptr %2, i64 %3) -func.func @char1copy(%arg0 : !fir.boxchar<1> {llvm.sret}, %arg1 : !fir.boxchar<1>) { +func.func @char1copy(%arg0 : !fir.boxchar<1> {llvm.sret = !fir.char<1, ?>}, %arg1 : !fir.boxchar<1>) { // I32-DAG: %[[p0:.*]] = insertvalue { ptr, i32 } undef, ptr %2, 0 // I32-DAG: = insertvalue { ptr, i32 } %[[p0]], i32 %3, 1 // I32-DAG: %[[p1:.*]] = insertvalue { ptr, i32 } undef, ptr %0, 0 diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -23,10 +23,13 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" @@ -34,6 +37,7 @@ #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" @@ -311,10 +315,38 @@ SmallVector newArgAttrs( llvmType.cast().getNumParams()); for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { + // Some LLVM IR attributes have a type attached to them. 
During FuncOp -> + // LLVMFuncOp conversion these types may have changed. Account for that + // change by converting attributes' types as well. + SmallVector convertedAttrs; + auto attrsDict = argAttrDicts[i].cast(); + convertedAttrs.reserve(attrsDict.size()); + for (auto &attr : attrsDict) { + const auto convert = [&](const NamedAttribute &attr) { + return TypeAttr::get(getTypeConverter()->convertType( + attr.getValue().cast().getValue())); + }; + if (attr.getName().getValue() == "llvm.byval") { + convertedAttrs.push_back( + rewriter.getNamedAttr("llvm.byval", convert(attr))); + } else if (attr.getName().getValue() == "llvm.byref") { + convertedAttrs.push_back( + rewriter.getNamedAttr("llvm.byref", convert(attr))); + } else if (attr.getName().getValue() == "llvm.sret") { + convertedAttrs.push_back( + rewriter.getNamedAttr("llvm.sret", convert(attr))); + } else if (attr.getName().getValue() == "llvm.inalloca") { + convertedAttrs.push_back( + rewriter.getNamedAttr("llvm.inalloca", convert(attr))); + } else { + convertedAttrs.push_back(attr); + } + } auto mapping = result.getInputMapping(i); assert(mapping && "unexpected deletion of function argument"); for (size_t j = 0; j < mapping->size; ++j) - newArgAttrs[mapping->inputNo + j] = argAttrDicts[i]; + newArgAttrs[mapping->inputNo + j] = + DictionaryAttr::get(rewriter.getContext(), convertedAttrs); } attributes.push_back( rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -1182,6 +1182,34 @@ UnknownLoc::get(context), f->getName(), functionType, convertLinkageFromLLVM(f->getLinkage()), dsoLocal, cconv); + // Mirror LLVM's typed parameter attributes (byval, byref, sret, inalloca) + // as TypeAttr argument attributes on the imported function. + for (const auto &arg : llvm::enumerate(functionType.getParams())) { + llvm::SmallVector argAttrs; + if (auto *type = 
f->getParamByValType(arg.index())) { + auto mlirType = processType(type); + argAttrs.push_back(NamedAttribute(b.getStringAttr("llvm.byval"), + TypeAttr::get(mlirType))); + } + if (auto *type = f->getParamByRefType(arg.index())) { + auto mlirType = processType(type); + argAttrs.push_back(NamedAttribute(b.getStringAttr("llvm.byref"), + TypeAttr::get(mlirType))); + } + if (auto *type = f->getParamStructRetType(arg.index())) { + auto mlirType = processType(type); + argAttrs.push_back(NamedAttribute(b.getStringAttr("llvm.sret"), + TypeAttr::get(mlirType))); + } + if (auto *type = f->getParamInAllocaType(arg.index())) { + auto mlirType = processType(type); + argAttrs.push_back(NamedAttribute(b.getStringAttr("llvm.inalloca"), + TypeAttr::get(mlirType))); + } + + fop.setArgAttrs(arg.index(), argAttrs); + } + if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f)) fop->setAttr(b.getStringAttr("personality"), personality); else if (f->hasPersonalityFn()) diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -840,23 +840,28 @@ .addAlignmentAttr(llvm::Align(attr.getInt()))); } - if (auto attr = func.getArgAttrOfType(argIdx, "llvm.sret")) { + if (auto attr = func.getArgAttrOfType(argIdx, "llvm.sret")) { auto argTy = mlirArg.getType().dyn_cast(); if (!argTy) return func.emitError( "llvm.sret attribute attached to LLVM non-pointer argument"); - llvmArg.addAttrs( - llvm::AttrBuilder(llvmArg.getContext()) - .addStructRetAttr(convertType(argTy.getElementType()))); + if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue()) + return func.emitError("llvm.sret attribute attached to LLVM pointer " + "argument of a different type"); + llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext()) + .addStructRetAttr(convertType(attr.getValue()))); } - if (auto attr = 
func.getArgAttrOfType(argIdx, "llvm.byval")) { + if (auto attr = func.getArgAttrOfType(argIdx, "llvm.byval")) { auto argTy = mlirArg.getType().dyn_cast(); if (!argTy) return func.emitError( "llvm.byval attribute attached to LLVM non-pointer argument"); + if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue()) + return func.emitError("llvm.byval attribute attached to LLVM pointer " + "argument of a different type"); llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext()) - .addByValAttr(convertType(argTy.getElementType()))); + .addByValAttr(convertType(attr.getValue()))); } if (auto attr = func.getArgAttrOfType(argIdx, "llvm.nest")) { diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir --- a/mlir/test/Dialect/LLVMIR/func.mlir +++ b/mlir/test/Dialect/LLVMIR/func.mlir @@ -93,9 +93,9 @@ llvm.return } - // CHECK: llvm.func @sretattr(%{{.*}}: !llvm.ptr {llvm.sret}) - // LOCINFO: llvm.func @sretattr(%{{.*}}: !llvm.ptr {llvm.sret} loc("some_source_loc")) - llvm.func @sretattr(%arg0: !llvm.ptr {llvm.sret} loc("some_source_loc")) { + // CHECK: llvm.func @sretattr(%{{.*}}: !llvm.ptr {llvm.sret = i32}) + // LOCINFO: llvm.func @sretattr(%{{.*}}: !llvm.ptr {llvm.sret = i32} loc("some_source_loc")) + llvm.func @sretattr(%arg0: !llvm.ptr {llvm.sret = i32} loc("some_source_loc")) { llvm.return } diff --git a/mlir/test/Target/LLVMIR/Import/func-attrs.ll b/mlir/test/Target/LLVMIR/Import/func-attrs.ll new file mode 100644 --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/func-attrs.ll @@ -0,0 +1,6 @@ +; RUN: mlir-translate -import-llvm %s | FileCheck %s + +; CHECK: llvm.func @foo(%arg0: !llvm.ptr {llvm.byval = i64}, %arg1: !llvm.ptr {llvm.byref = i64}, %arg2: !llvm.ptr {llvm.sret = i64}, %arg3: !llvm.ptr {llvm.inalloca = i64}) +define void @foo(ptr byval(i64) %arg0, ptr byref(i64) %arg1, ptr sret(i64) %arg2, ptr inalloca(i64) %arg3) { + ret void +} diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir 
b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir --- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir @@ -15,7 +15,7 @@ // ----- // expected-error @+1 {{llvm.sret attribute attached to LLVM non-pointer argument}} -llvm.func @invalid_sret(%arg0 : f32 {llvm.sret}) -> f32 { +llvm.func @invalid_sret(%arg0 : f32 {llvm.sret = f32}) -> f32 { llvm.return %arg0 : f32 } @@ -28,7 +28,7 @@ // ----- // expected-error @+1 {{llvm.byval attribute attached to LLVM non-pointer argument}} -llvm.func @invalid_byval(%arg0 : f32 {llvm.byval}) -> f32 { +llvm.func @invalid_byval(%arg0 : f32 {llvm.byval = f32}) -> f32 { llvm.return %arg0 : f32 } diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1050,12 +1050,12 @@ } // CHECK-LABEL: define void @byvalattr(ptr byval(i32) % -llvm.func @byvalattr(%arg0: !llvm.ptr {llvm.byval}) { +llvm.func @byvalattr(%arg0: !llvm.ptr {llvm.byval = i32}) { llvm.return } // CHECK-LABEL: define void @sretattr(ptr sret(i32) % -llvm.func @sretattr(%arg0: !llvm.ptr {llvm.sret}) { +llvm.func @sretattr(%arg0: !llvm.ptr {llvm.sret = i32}) { llvm.return }