diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1114,6 +1114,24 @@
           llvm::AttrBuilder().addAlignmentAttr(llvm::Align(attr.getInt())));
     }
 
+    // Both attributes are only valid on pointer arguments; a `false` value is
+    // accepted and attaches nothing.
+    if (auto attr = func.getArgAttrOfType<BoolAttr>(argIdx, "llvm.sret")) {
+      if (!mlirArg.getType().isa<LLVM::LLVMPointerType>())
+        return func.emitError(
+            "llvm.sret attribute attached to LLVM non-pointer argument");
+      if (attr.getValue())
+        llvmArg.addAttr(llvm::Attribute::AttrKind::StructRet);
+    }
+
+    if (auto attr = func.getArgAttrOfType<BoolAttr>(argIdx, "llvm.byval")) {
+      if (!mlirArg.getType().isa<LLVM::LLVMPointerType>())
+        return func.emitError(
+            "llvm.byval attribute attached to LLVM non-pointer argument");
+      if (attr.getValue())
+        llvmArg.addAttr(llvm::Attribute::AttrKind::ByVal);
+    }
+
     valueMapping[mlirArg] = &llvmArg;
     argIdx++;
   }
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -87,6 +87,16 @@
     llvm.return
   }
 
+  // CHECK: llvm.func @byvalattr(%{{.*}}: !llvm.ptr {llvm.byval = true})
+  llvm.func @byvalattr(%arg0: !llvm.ptr {llvm.byval = true}) {
+    llvm.return
+  }
+
+  // CHECK: llvm.func @sretattr(%{{.*}}: !llvm.ptr {llvm.sret = true})
+  llvm.func @sretattr(%arg0: !llvm.ptr {llvm.sret = true}) {
+    llvm.return
+  }
+
   // CHECK: llvm.func @variadic(...)
   llvm.func @variadic(...)