diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1102,6 +1102,22 @@ llvm::AttrBuilder().addAlignmentAttr(llvm::Align(attr.getInt()))); } + /* llvm.sret handling: when the argument at argIdx carries this attribute, the argument type is checked; if the check fails the function emits the "non-pointer argument" error and returns, otherwise StructRet is added to the LLVM argument. NOTE(review): the dyn_cast result is used without a null check, the bound `attr` value is never read, and getArgAttrOfType, dyn_cast and isa appear here without template arguments (possibly lost when this diff was extracted) - confirm against the real file. */ if (auto attr = func.getArgAttrOfType(argIdx, "llvm.sret")) { + auto argTy = mlirArg.getType().dyn_cast(); + if (!argTy.isa()) + return func.emitError( + "llvm.sret attribute attached to LLVM non-pointer argument"); + llvmArg.addAttr(llvm::Attribute::AttrKind::StructRet); + } + + /* llvm.byval handling: when the argument at argIdx carries this attribute, the argument type is checked; if the check fails the function emits the "non-pointer argument" error and returns, otherwise ByVal is added to the LLVM argument. NOTE(review): the dyn_cast result is used without a null check, the bound `attr` value is never read, and the template arguments are missing here as well - confirm against the real file. */ if (auto attr = func.getArgAttrOfType(argIdx, "llvm.byval")) { + auto argTy = mlirArg.getType().dyn_cast(); + if (!argTy.isa()) + return func.emitError( + "llvm.byval attribute attached to LLVM non-pointer argument"); + llvmArg.addAttr(llvm::Attribute::AttrKind::ByVal); + } + valueMapping[mlirArg] = &llvmArg; argIdx++; } diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir --- a/mlir/test/Dialect/LLVMIR/func.mlir +++ b/mlir/test/Dialect/LLVMIR/func.mlir @@ -87,6 +87,16 @@ llvm.return } + // CHECK: llvm.func @byvalattr(%{{.*}}: !llvm.ptr {llvm.byval}) + llvm.func @byvalattr(%arg0: !llvm.ptr {llvm.byval}) { + llvm.return + } + + // CHECK: llvm.func @sretattr(%{{.*}}: !llvm.ptr {llvm.sret}) + llvm.func @sretattr(%arg0: !llvm.ptr {llvm.sret}) { + llvm.return + } + // CHECK: llvm.func @variadic(...) llvm.func @variadic(...) 
diff --git a/mlir/test/Target/llvmir-invalid.mlir b/mlir/test/Target/llvmir-invalid.mlir --- a/mlir/test/Target/llvmir-invalid.mlir +++ b/mlir/test/Target/llvmir-invalid.mlir @@ -14,6 +14,19 @@ // ----- +// expected-error @+1 {{llvm.sret attribute attached to LLVM non-pointer argument}} +llvm.func @invalid_sret(%arg0 : !llvm.float {llvm.sret}) -> !llvm.float { + llvm.return %arg0 : !llvm.float +} +// ----- + +// expected-error @+1 {{llvm.byval attribute attached to LLVM non-pointer argument}} +llvm.func @invalid_byval(%arg0 : !llvm.float {llvm.byval}) -> !llvm.float { + llvm.return %arg0 : !llvm.float +} + +// ----- + // expected-error @+1 {{llvm.align attribute attached to LLVM non-pointer argument}} llvm.func @invalid_align(%arg0 : !llvm.float {llvm.align = 4}) -> !llvm.float { llvm.return %arg0 : !llvm.float