Index: mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp =================================================================== --- mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -40,6 +40,7 @@ #include "llvm/IR/Type.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" +#include #include using namespace mlir; @@ -50,19 +51,71 @@ /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument /// attributes. static void filterFuncAttributes(ArrayRef attrs, - bool filterArgAttrs, + bool filterArgAndResAttrs, SmallVectorImpl &result) { for (const auto &attr : attrs) { if (attr.getName() == SymbolTable::getSymbolAttrName() || attr.getName() == FunctionOpInterface::getTypeAttrName() || attr.getName() == "std.varargs" || - (filterArgAttrs && - attr.getName() == FunctionOpInterface::getArgDictAttrName())) + (filterArgAndResAttrs && + (attr.getName() == FunctionOpInterface::getArgDictAttrName() || + attr.getName() == FunctionOpInterface::getResultDictAttrName()))) continue; result.push_back(attr); } } +/// Helper function for wrapping all attributes into a single DictionaryAttr +static constexpr StringRef kStructAttrs = "llvm.struct_attrs"; +static auto wrapAsStructAttrs(OpBuilder &b, Attribute attrs) { + return DictionaryAttr::get(b.getContext(), + b.getNamedAttr(kStructAttrs, attrs)); +} + +/// Combines all result attributes into a single DictionaryAttr +/// and prepends to argument attrs. +/// This is intended to be used to format the attributes for a C wrapper +/// function when the result(s) is converted to the first function argument +/// (in the multiple return case, all returns get wrapped into a single +/// argument). The total number of argument attributes should be equal to +/// (number of function arguments) + 1. +static void +prependResAttrsToArgAttrs(OpBuilder &builder, + SmallVectorImpl &attributes, + size_t numArguments) { + auto allAttrs = SmallVector( + numArguments + 1, DictionaryAttr::get(builder.getContext())); + NamedAttribute *argAttrs = nullptr; + for (auto it = attributes.begin(); it != attributes.end();) { + if (it->getName() == FunctionOpInterface::getArgDictAttrName()) { + auto arrayAttrs = it->getValue().cast(); + assert(arrayAttrs.size() == numArguments && + "Number of arg attrs and args should match"); + std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1); + argAttrs = it; + } else if (it->getName() == FunctionOpInterface::getResultDictAttrName()) { + auto arrayAttrs = it->getValue().cast(); + assert(!arrayAttrs.empty() && "expected array to be non-empty"); + allAttrs[0] = (arrayAttrs.size() == 1) + ? arrayAttrs[0] + : wrapAsStructAttrs(builder, arrayAttrs); + it = attributes.erase(it); + continue; + } + it++; + } + + auto newArgAttrs = + builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), + builder.getArrayAttr(allAttrs)); + if (!argAttrs) { + attributes.emplace_back(newArgAttrs); + return; + } + *argAttrs = newArgAttrs; + return; +} + /// Creates an auxiliary function with pointer-to-memref-descriptor-struct /// arguments instead of unpacked arguments. This function can be called from C /// by passing a pointer to a C struct corresponding to a memref descriptor. @@ -76,12 +129,14 @@ FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { auto type = funcOp.getType(); SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false, + filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false, attributes); Type wrapperFuncType; bool resultIsNowArg; std::tie(wrapperFuncType, resultIsNowArg) = typeConverter.convertFunctionTypeCWrapper(type); + if (resultIsNowArg) + prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments()); auto wrapperFuncOp = rewriter.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes); @@ -142,9 +197,11 @@ assert(wrapperType && "unexpected type conversion failure"); SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false, + filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false, attributes); + if (resultIsNowArg) + prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments()); // Create the auxiliary function. auto wrapperFunc = builder.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), @@ -235,11 +292,21 @@ if (!llvmType) return nullptr; - // Propagate argument attributes to all converted arguments obtained after - // converting a given original argument. + // Propagate argument/result attributes to all converted arguments/result + // obtained after converting a given original argument/result. SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true, + filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true, attributes); + if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) { + assert(!resAttrDicts.empty() && "expected array to be non-empty"); + auto newResAttrDicts = + (funcOp.getNumResults() == 1) + ? resAttrDicts + : rewriter.getArrayAttr( + {wrapAsStructAttrs(rewriter, resAttrDicts)}); + attributes.push_back(rewriter.getNamedAttr( + FunctionOpInterface::getResultDictAttrName(), newResAttrDicts)); + } if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { SmallVector newArgAttrs( llvmType.cast().getNumParams()); Index: mlir/test/Conversion/StandardToLLVM/emit-c-wrappers-for-external-callers.mlir =================================================================== --- /dev/null +++ mlir/test/Conversion/StandardToLLVM/emit-c-wrappers-for-external-callers.mlir @@ -0,0 +1,70 @@ +// RUN: mlir-opt -convert-std-to-llvm='emit-c-wrappers=1' %s | FileCheck %s + +// CHECK: llvm.func @res_attrs_with_memref_return() -> (!llvm.struct{{.*}} {test.returnOne}) +// CHECK-LABEL: llvm.func @_mlir_ciface_res_attrs_with_memref_return +// CHECK: %{{.*}}: !llvm.ptr{{.*}} {test.returnOne} +func @res_attrs_with_memref_return() -> (memref {test.returnOne}) { + %0 = memref.alloc() : memref + return %0 : memref +} + +// CHECK: llvm.func @res_attrs_with_value_return() -> (f32 {test.returnOne = 1 : i64}) +// CHECK-LABEL: llvm.func @_mlir_ciface_res_attrs_with_value_return +// CHECK: -> (f32 {test.returnOne = 1 : i64}) +func @res_attrs_with_value_return() -> (f32 {test.returnOne = 1}) { + %0 = arith.constant 1.00 : f32 + return %0 : f32 +} + +// CHECK: llvm.func @multiple_return() -> (!llvm.struct<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]}) +// CHECK-LABEL: llvm.func @_mlir_ciface_multiple_return +// CHECK: (%{{.*}}: !llvm.ptr<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]}) +func @multiple_return() -> (memref {test.returnOne = 1}, f32 {test.returnTwo = 2, test.returnThree = 3}) { + %0 = memref.alloc() : memref + %1 = arith.constant 1.00 : f32 + return %0, %1 : memref, f32 +} + +// CHECK: llvm.func @multiple_return_missing_res_attr() -> (!llvm.struct<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]}) +// CHECK-LABEL: llvm.func @_mlir_ciface_multiple_return_missing_res_attr +// CHECK: (%{{.*}}: !llvm.ptr<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]}) +func @multiple_return_missing_res_attr() -> (memref {test.returnOne = 1}, i64, f32 {test.returnTwo = 2, test.returnThree = 3}) { + %0 = memref.alloc() : memref + %1 = arith.constant 2 : i64 + %2 = arith.constant 1.00 : f32 + return %0, %1, %2 : memref, i64, f32 +} + +// CHECK: llvm.func @one_arg_attr_no_res_attrs_with_memref_return({{.*}}) -> !llvm.struct{{.*}} +// CHECK-LABEL: llvm.func @_mlir_ciface_one_arg_attr_no_res_attrs_with_memref_return +// CHECK: %{{.*}}: !llvm.ptr<{{.*}}>, %{{.*}}: !llvm.ptr<{{.*}}> {test.argOne = 1 : i64} +func @one_arg_attr_no_res_attrs_with_memref_return(%arg0: memref {test.argOne = 1}) -> memref { + %0 = memref.alloc() : memref + return %0 : memref +} + +// CHECK: llvm.func @one_arg_attr_one_res_attr_with_memref_return({{.*}}) -> (!llvm.struct<{{.*}}> {test.returnOne = 1 : i64}) +// CHECK-LABEL: llvm.func @_mlir_ciface_one_arg_attr_one_res_attr_with_memref_return +// CHECK: (%{{.*}}: !llvm.ptr<{{.*}}> {test.returnOne = 1 : i64}, %{{.*}}: !llvm.ptr<{{.*}}> {test.argOne = 1 : i64} +func @one_arg_attr_one_res_attr_with_memref_return(%arg0: memref {test.argOne = 1}) -> (memref {test.returnOne = 1}) { + %0 = memref.alloc() : memref + return %0 : memref +} + +// CHECK: llvm.func @one_arg_attr_one_res_attr_with_value_return({{.*}}) -> (f32 {test.returnOne = 1 : i64}) +// CHECK-LABEL: llvm.func @_mlir_ciface_one_arg_attr_one_res_attr_with_value_return +// CHECK: (%{{.*}}: !llvm.ptr<{{.*}}> {test.argOne = 1 : i64}) -> (f32 {test.returnOne = 1 : i64}) +func @one_arg_attr_one_res_attr_with_value_return(%arg0: memref {test.argOne = 1}) -> (f32 {test.returnOne = 1}) { + %0 = arith.constant 1.00 : f32 + return %0 : f32 +} + +// CHECK: llvm.func @multiple_arg_attr_multiple_res_attr({{.*}}) -> (!llvm.struct<{{.*}}> {llvm.struct_attrs = [{}, {test.returnOne = 1 : i64}, {test.returnTwo = 2 : i64}]}) +// CHECK-LABEL: llvm.func @_mlir_ciface_multiple_arg_attr_multiple_res_attr +// CHECK: (%{{.*}}: !llvm.ptr<{{.*}}> {llvm.struct_attrs = [{}, {test.returnOne = 1 : i64}, {test.returnTwo = 2 : i64}]}, %{{.*}}: !llvm.ptr<{{.*}}> {test.argZero = 0 : i64}, %{{.*}}: f32, %{{.*}}: i32 {test.argTwo = 2 : i64} +func @multiple_arg_attr_multiple_res_attr(%arg0: memref {test.argZero = 0}, %arg1: f32, %arg2: i32 {test.argTwo = 2}) -> (f32, memref {test.returnOne = 1}, i32 {test.returnTwo = 2}) { + %0 = arith.constant 1.00 : f32 + %1 = memref.alloc() : memref + %2 = arith.constant 2 : i32 + return %0, %1, %2 : f32, memref, i32 +} Index: mlir/test/Conversion/StandardToLLVM/emit-c-wrappers-for-external-functions.mlir =================================================================== --- /dev/null +++ mlir/test/Conversion/StandardToLLVM/emit-c-wrappers-for-external-functions.mlir @@ -0,0 +1,41 @@ +// RUN: mlir-opt -convert-std-to-llvm='emit-c-wrappers=1' %s | FileCheck %s + +// CHECK: llvm.func @res_attrs_with_memref_return() -> (!llvm.struct{{.*}} {test.returnOne}) +// CHECK-LABEL: llvm.func @_mlir_ciface_res_attrs_with_memref_return +// CHECK: !llvm.ptr{{.*}} {test.returnOne} +func private @res_attrs_with_memref_return() -> (memref {test.returnOne}) + +// CHECK: llvm.func @res_attrs_with_value_return() -> (f32 {test.returnOne = 1 : i64}) +// CHECK-LABEL: llvm.func @_mlir_ciface_res_attrs_with_value_return +// CHECK: -> (f32 {test.returnOne = 1 : i64}) +func private @res_attrs_with_value_return() -> (f32 {test.returnOne = 1}) + +// CHECK: llvm.func @multiple_return() -> (!llvm.struct<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]}) +// CHECK-LABEL: llvm.func @_mlir_ciface_multiple_return +// CHECK: (!llvm.ptr<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]}) +func private @multiple_return() -> (memref {test.returnOne = 1}, f32 {test.returnTwo = 2, test.returnThree = 3}) + +// CHECK: llvm.func @multiple_return_missing_res_attr() -> (!llvm.struct<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]}) +// CHECK-LABEL: llvm.func @_mlir_ciface_multiple_return_missing_res_attr +// CHECK: (!llvm.ptr<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]}) +func private @multiple_return_missing_res_attr() -> (memref {test.returnOne = 1}, i64, f32 {test.returnTwo = 2, test.returnThree = 3}) + +// CHECK: llvm.func @one_arg_attr_no_res_attrs_with_memref_return({{.*}}) -> !llvm.struct{{.*}} +// CHECK-LABEL: llvm.func @_mlir_ciface_one_arg_attr_no_res_attrs_with_memref_return +// CHECK: !llvm.ptr<{{.*}}>, !llvm.ptr<{{.*}}> {test.argOne = 1 : i64} +func private @one_arg_attr_no_res_attrs_with_memref_return(%arg0: memref {test.argOne = 1}) -> memref + +// CHECK: llvm.func @one_arg_attr_one_res_attr_with_memref_return({{.*}}) -> (!llvm.struct<{{.*}}> {test.returnOne = 1 : i64}) +// CHECK-LABEL: llvm.func @_mlir_ciface_one_arg_attr_one_res_attr_with_memref_return +// CHECK: (!llvm.ptr<{{.*}}> {test.returnOne = 1 : i64}, !llvm.ptr<{{.*}}> {test.argOne = 1 : i64} +func private @one_arg_attr_one_res_attr_with_memref_return(%arg0: memref {test.argOne = 1}) -> (memref {test.returnOne = 1}) + +// CHECK: llvm.func @one_arg_attr_one_res_attr_with_value_return({{.*}}) -> (f32 {test.returnOne = 1 : i64}) +// CHECK-LABEL: llvm.func @_mlir_ciface_one_arg_attr_one_res_attr_with_value_return +// CHECK: (!llvm.ptr<{{.*}}> {test.argOne = 1 : i64}) -> (f32 {test.returnOne = 1 : i64}) +func private @one_arg_attr_one_res_attr_with_value_return(%arg0: memref {test.argOne = 1}) -> (f32 {test.returnOne = 1}) + +// CHECK: llvm.func @multiple_arg_attr_multiple_res_attr({{.*}}) -> (!llvm.struct<{{.*}}> {llvm.struct_attrs = [{}, {test.returnOne = 1 : i64}, {test.returnTwo = 2 : i64}]}) +// CHECK-LABEL: llvm.func @_mlir_ciface_multiple_arg_attr_multiple_res_attr +// CHECK: (!llvm.ptr<{{.*}}> {llvm.struct_attrs = [{}, {test.returnOne = 1 : i64}, {test.returnTwo = 2 : i64}]}, !llvm.ptr<{{.*}}> {test.argZero = 0 : i64}, f32, i32 {test.argTwo = 2 : i64} +func private @multiple_arg_attr_multiple_res_attr(%arg0: memref {test.argZero = 0}, %arg1: f32, %arg2: i32 {test.argTwo = 2}) -> (f32, memref {test.returnOne = 1}, i32 {test.returnTwo = 2})