diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -1182,6 +1182,29 @@
       UnknownLoc::get(context), f->getName(), functionType,
       convertLinkageFromLLVM(f->getLinkage()), dsoLocal, cconv);
 
+  // Import the parameter attributes that carry a pointee type (byval, byref,
+  // sret, inalloca) as TypeAttrs holding the converted MLIR type.
+  for (const auto &arg : llvm::enumerate(functionType.getParams())) {
+    unsigned argIdx = arg.index();
+    const std::pair<StringRef, llvm::Type *> typedParamAttrs[] = {
+        {"llvm.byval", f->getParamByValType(argIdx)},
+        {"llvm.byref", f->getParamByRefType(argIdx)},
+        {"llvm.sret", f->getParamStructRetType(argIdx)},
+        {"llvm.inalloca", f->getParamInAllocaType(argIdx)}};
+    llvm::SmallVector<NamedAttribute> argAttrs;
+    for (const auto &typedAttr : typedParamAttrs) {
+      if (!typedAttr.second)
+        continue;
+      // Do not wrap a null type in a TypeAttr if the conversion failed.
+      Type mlirType = processType(typedAttr.second);
+      if (!mlirType)
+        return failure();
+      argAttrs.push_back(NamedAttribute(b.getStringAttr(typedAttr.first),
+                                        TypeAttr::get(mlirType)));
+    }
+    fop.setArgAttrs(argIdx, argAttrs);
+  }
+
   if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
     fop->setAttr(b.getStringAttr("personality"), personality);
   else if (f->hasPersonalityFn())
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -840,23 +840,30 @@
                            .addAlignmentAttr(llvm::Align(attr.getInt())));
     }
 
-    if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.sret")) {
+    if (auto attr = func.getArgAttrOfType<TypeAttr>(argIdx, "llvm.sret")) {
       auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMPointerType>();
       if (!argTy)
         return func.emitError(
             "llvm.sret attribute attached to LLVM non-pointer argument");
-      llvmArg.addAttrs(
-          llvm::AttrBuilder(llvmArg.getContext())
-              .addStructRetAttr(convertType(argTy.getElementType())));
+      // A typed pointer must point to the type carried by the attribute.
+      if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue())
+        return func.emitError("llvm.sret attribute attached to LLVM pointer "
+                              "argument of a different type");
+      llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
+                           .addStructRetAttr(convertType(attr.getValue())));
     }
 
-    if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.byval")) {
+    if (auto attr = func.getArgAttrOfType<TypeAttr>(argIdx, "llvm.byval")) {
       auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMPointerType>();
       if (!argTy)
         return func.emitError(
             "llvm.byval attribute attached to LLVM non-pointer argument");
+      // A typed pointer must point to the type carried by the attribute.
+      if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue())
+        return func.emitError("llvm.byval attribute attached to LLVM pointer "
+                              "argument of a different type");
       llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
-                           .addByValAttr(convertType(argTy.getElementType())));
+                           .addByValAttr(convertType(attr.getValue())));
     }
 
     if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.nest")) {
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -93,9 +93,9 @@
     llvm.return
   }
 
-  // CHECK: llvm.func @sretattr(%{{.*}}: !llvm.ptr<i32> {llvm.sret})
-  // LOCINFO: llvm.func @sretattr(%{{.*}}: !llvm.ptr<i32> {llvm.sret} loc("some_source_loc"))
-  llvm.func @sretattr(%arg0: !llvm.ptr<i32> {llvm.sret} loc("some_source_loc")) {
+  // CHECK: llvm.func @sretattr(%{{.*}}: !llvm.ptr<i32> {llvm.sret = i32})
+  // LOCINFO: llvm.func @sretattr(%{{.*}}: !llvm.ptr<i32> {llvm.sret = i32} loc("some_source_loc"))
+  llvm.func @sretattr(%arg0: !llvm.ptr<i32> {llvm.sret = i32} loc("some_source_loc")) {
     llvm.return
   }
 
diff --git a/mlir/test/Target/LLVMIR/Import/func-attrs.ll b/mlir/test/Target/LLVMIR/Import/func-attrs.ll
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/func-attrs.ll
@@ -0,0 +1,8 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; The LLVM verifier requires sret on the first or second parameter and
+; inalloca on the last one, so the parameter order below is significant.
+; CHECK: llvm.func @foo(%arg0: !llvm.ptr {llvm.sret = i64}, %arg1: !llvm.ptr {llvm.byval = i64}, %arg2: !llvm.ptr {llvm.byref = i64}, %arg3: !llvm.ptr {llvm.inalloca = i64})
+define void @foo(ptr sret(i64) %arg0, ptr byval(i64) %arg1, ptr byref(i64) %arg2, ptr inalloca(i64) %arg3) {
+  ret void
+}
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -15,7 +15,14 @@
 // -----
 
 // expected-error @+1 {{llvm.sret attribute attached to LLVM non-pointer argument}}
-llvm.func @invalid_sret(%arg0 : f32 {llvm.sret}) -> f32 {
+llvm.func @invalid_sret(%arg0 : f32 {llvm.sret = f32}) -> f32 {
   llvm.return %arg0 : f32
 }
 
+// -----
+
+// expected-error @+1 {{llvm.sret attribute attached to LLVM pointer argument of a different type}}
+llvm.func @invalid_sret_type(%arg0 : !llvm.ptr<i32> {llvm.sret = f32}) {
+  llvm.return
+}
+
@@ -28,7 +35,14 @@
 // -----
 
 // expected-error @+1 {{llvm.byval attribute attached to LLVM non-pointer argument}}
-llvm.func @invalid_byval(%arg0 : f32 {llvm.byval}) -> f32 {
+llvm.func @invalid_byval(%arg0 : f32 {llvm.byval = f32}) -> f32 {
   llvm.return %arg0 : f32
 }
 
+// -----
+
+// expected-error @+1 {{llvm.byval attribute attached to LLVM pointer argument of a different type}}
+llvm.func @invalid_byval_type(%arg0 : !llvm.ptr<i32> {llvm.byval = f32}) {
+  llvm.return
+}
+
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1050,12 +1050,12 @@
 }
 
 // CHECK-LABEL: define void @byvalattr(ptr byval(i32) %
-llvm.func @byvalattr(%arg0: !llvm.ptr<i32> {llvm.byval}) {
+llvm.func @byvalattr(%arg0: !llvm.ptr<i32> {llvm.byval = i32}) {
   llvm.return
 }
 
 // CHECK-LABEL: define void @sretattr(ptr sret(i32) %
-llvm.func @sretattr(%arg0: !llvm.ptr<i32> {llvm.sret}) {
+llvm.func @sretattr(%arg0: !llvm.ptr<i32> {llvm.sret = i32}) {
   llvm.return
 }
 