diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1105,7 +1105,9 @@
       if (!argTy.isa<LLVM::LLVMPointerType>())
         return func.emitError(
             "llvm.sret attribute attached to LLVM non-pointer argument");
-      llvmArg.addAttr(llvm::Attribute::AttrKind::StructRet);
+      // Emit the typed form of sret, using the pointee type of the argument.
+      llvmArg.addAttrs(llvm::AttrBuilder().addStructRetAttr(
+          llvmArg.getType()->getPointerElementType()));
     }
 
     if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.byval")) {
@@ -1113,7 +1115,9 @@
       if (!argTy.isa<LLVM::LLVMPointerType>())
         return func.emitError(
             "llvm.byval attribute attached to LLVM non-pointer argument");
-      llvmArg.addAttr(llvm::Attribute::AttrKind::ByVal);
+      // Emit the typed form of byval, using the pointee type of the argument.
+      llvmArg.addAttrs(llvm::AttrBuilder().addByValAttr(
+          llvmArg.getType()->getPointerElementType()));
     }
 
     valueMapping[mlirArg] = &llvmArg;
diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -939,6 +939,18 @@
   llvm.return
 }
 
+// The byval attribute is printed together with the pointee type.
+// CHECK-LABEL: define void @byvalattr(i32* byval(i32) %
+llvm.func @byvalattr(%arg0: !llvm.ptr<i32> {llvm.byval}) {
+  llvm.return
+}
+
+// The sret attribute is printed together with the pointee type.
+// CHECK-LABEL: define void @sretattr(i32* sret(i32) %
+llvm.func @sretattr(%arg0: !llvm.ptr<i32> {llvm.sret}) {
+  llvm.return
+}
+
 // CHECK-LABEL: define void @llvm_align(float* align 4 {{%*.}})
 llvm.func @llvm_align(%arg0: !llvm.ptr<float> {llvm.align = 4}) {
   llvm.return