diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -859,6 +859,15 @@
               .addByValAttr(convertType(argTy.getElementType())));
     }
 
+    // `nest` is only accepted on pointer-typed arguments; anything else is
+    // reported as an error instead of being translated.
+    if (func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.nest")) {
+      if (!mlirArg.getType().isa<LLVM::LLVMPointerType>())
+        return func.emitError(
+            "llvm.nest attribute attached to LLVM non-pointer argument");
+      llvmArg.addAttr(llvm::Attribute::Nest);
+    }
+
     mapValue(mlirArg, &llvmArg);
     argIdx++;
   }
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -97,6 +97,11 @@
     llvm.return
   }
 
+  // CHECK: llvm.func @nestattr(%{{.*}}: !llvm.ptr<i32> {llvm.nest})
+  llvm.func @nestattr(%arg0: !llvm.ptr<i32> {llvm.nest}) {
+    llvm.return
+  }
+
   // CHECK: llvm.func @variadic(...)
   llvm.func @variadic(...)
 
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -15,13 +15,20 @@
 // -----
 
 // expected-error @+1 {{llvm.sret attribute attached to LLVM non-pointer argument}}
-llvm.func @invalid_noalias(%arg0 : f32 {llvm.sret}) -> f32 {
+llvm.func @invalid_sret(%arg0 : f32 {llvm.sret}) -> f32 {
+  llvm.return %arg0 : f32
+}
+
+// -----
+
+// expected-error @+1 {{llvm.nest attribute attached to LLVM non-pointer argument}}
+llvm.func @invalid_nest(%arg0 : f32 {llvm.nest}) -> f32 {
   llvm.return %arg0 : f32
 }
 
 // -----
 
 // expected-error @+1 {{llvm.byval attribute attached to LLVM non-pointer argument}}
-llvm.func @invalid_noalias(%arg0 : f32 {llvm.byval}) -> f32 {
+llvm.func @invalid_byval(%arg0 : f32 {llvm.byval}) -> f32 {
   llvm.return %arg0 : f32
 }
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1057,6 +1057,11 @@
   llvm.return
 }
 
+// CHECK-LABEL: define void @nestattr(i32* nest %
+llvm.func @nestattr(%arg0: !llvm.ptr<i32> {llvm.nest}) {
+  llvm.return
+}
+
-// CHECK-LABEL: define void @llvm_align(float* align 4 {{%*.}})
+// CHECK-LABEL: define void @llvm_align(float* align 4 {{%.*}})
 llvm.func @llvm_align(%arg0: !llvm.ptr<f32> {llvm.align = 4}) {
   llvm.return