diff --git a/mlir/docs/LLVMDialectMemRefConvention.md b/mlir/docs/LLVMDialectMemRefConvention.md --- a/mlir/docs/LLVMDialectMemRefConvention.md +++ b/mlir/docs/LLVMDialectMemRefConvention.md @@ -232,29 +232,40 @@ }; ``` +Furthermore, we also rewrite function results to reference parameters if the +rewritten function result has a struct type. The special result parameter is +added as the first parameter and is of pointer-to-struct type. + If enabled, the option will do the following. For _external_ functions declared in the MLIR module. 1. Declare a new function `_mlir_ciface_` where memref arguments are converted to pointer-to-struct and the remaining arguments are converted - as usual. -1. Add a body to the original function (making it non-external) that - 1. allocates a memref descriptor, - 1. populates it, and - 1. passes the pointer to it into the newly declared interface function, + as usual. Results are converted to a special argument if they are of struct + type. +2. Add a body to the original function (making it non-external) that + 1. allocates memref descriptors, + 2. populates them, + 3. potentially allocates space for the result struct, and + 4. passes the pointers to these into the newly declared interface function, then - 1. collects the result of the call and returns it to the caller. + 5. collects the result of the call (potentially from the result struct), + and + 6. returns it to the caller. For (non-external) functions defined in the MLIR module. 1. Define a new function `_mlir_ciface_` where memref arguments are converted to pointer-to-struct and the remaining arguments are converted - as usual. -1. Populate the body of the newly defined function with IR that + as usual. Results are converted to a special argument if they are of struct + type. +2. Populate the body of the newly defined function with IR that 1. loads descriptors from pointers; - 1. unpacks descriptor into individual non-aggregate values; - 1. 
passes these values into the original function; - 1. collects the result of the call and returns it to the caller. + 2. unpacks descriptor into individual non-aggregate values; + 3. passes these values into the original function; + 4. collects the results of the call and + 5. either copies the results into the result struct or returns them to the + caller. Examples: @@ -342,6 +353,57 @@ } ``` +```mlir +func @foo(%arg0: memref) -> memref { + return %arg0 : memref +} + +// Gets converted into the following +// (using type alias for brevity): +!llvm.memref_2d = type !llvm.struct<(ptr, ptr, i64, + array<2xi64>, array<2xi64>)> +!llvm.memref_2d_ptr = type !llvm.ptr, ptr, i64, + array<2xi64>, array<2xi64>)>> + +// Function with unpacked arguments. +llvm.func @foo(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, + %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) + -> !llvm.memref_2d { + %0 = llvm.mlir.undef : !llvm.memref_2d + %1 = llvm.insertvalue %arg0, %0[0] : !llvm.memref_2d + %2 = llvm.insertvalue %arg1, %1[1] : !llvm.memref_2d + %3 = llvm.insertvalue %arg2, %2[2] : !llvm.memref_2d + %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.memref_2d + %5 = llvm.insertvalue %arg5, %4[4, 0] : !llvm.memref_2d + %6 = llvm.insertvalue %arg4, %5[3, 1] : !llvm.memref_2d + %7 = llvm.insertvalue %arg6, %6[4, 1] : !llvm.memref_2d + llvm.return %7 : !llvm.memref_2d +} + +// Interface function callable from C. 
+llvm.func @_mlir_ciface_foo(%arg0: !llvm.memref_2d_ptr, %arg1: !llvm.memref_2d_ptr) { + %0 = llvm.load %arg1 : !llvm.memref_2d_ptr + %1 = llvm.extractvalue %0[0] : !llvm.memref_2d + %2 = llvm.extractvalue %0[1] : !llvm.memref_2d + %3 = llvm.extractvalue %0[2] : !llvm.memref_2d + %4 = llvm.extractvalue %0[3, 0] : !llvm.memref_2d + %5 = llvm.extractvalue %0[3, 1] : !llvm.memref_2d + %6 = llvm.extractvalue %0[4, 0] : !llvm.memref_2d + %7 = llvm.extractvalue %0[4, 1] : !llvm.memref_2d + %8 = llvm.call @foo(%1, %2, %3, %4, %5, %6, %7) + : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> !llvm.memref_2d + llvm.store %8, %arg0 : !llvm.memref_2d_ptr + llvm.return +} + +// Function with unpacked arguments. +llvm.func @foo(%arg0: !llvm.ptr, %arg1: !llvm.ptr, + %arg2: i64, %arg3: i64, %arg4: i64, + %arg5: i64, %arg6: i64) { + llvm.return +} +``` + Rationale: Introducing auxiliary functions for C-compatible interfaces is preferred to modifying the calling convention since it will minimize the effect of C compatibility on intra-module calls or calls between MLIR-generated diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -116,8 +116,10 @@ OpBuilder &builder); /// Converts the function type to a C-compatible format, in particular using - /// pointers to memref descriptors for arguments. - Type convertFunctionTypeCWrapper(FunctionType type); + /// pointers to memref descriptors for arguments. Also converts the return + /// type to a pointer argument if it is a struct. Returns the converted type + /// and a flag that is true if the result was turned into an argument. + std::pair convertFunctionTypeCWrapper(FunctionType type); /// Returns the data layout to use during and after conversion. 
const llvm::DataLayout &getDataLayout() { return options.dataLayout; } diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -252,8 +252,24 @@ /// Converts the function type to a C-compatible format, in particular using /// pointers to memref descriptors for arguments. -Type LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) { +std::pair +LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) { SmallVector inputs; + bool resultIsNowArg = false; + + Type resultType = type.getNumResults() == 0 + ? LLVM::LLVMVoidType::get(&getContext()) + : unwrap(packFunctionResults(type.getResults())); + if (!resultType) + return {}; + + if (auto structType = resultType.dyn_cast()) { + // Struct types cannot be safely returned via C interface. Make this a + // pointer argument, instead. + inputs.push_back(LLVM::LLVMPointerType::get(structType)); + resultType = LLVM::LLVMVoidType::get(&getContext()); + resultIsNowArg = true; + } for (Type t : type.getInputs()) { auto converted = convertType(t); @@ -264,13 +280,7 @@ inputs.push_back(converted); } - Type resultType = type.getNumResults() == 0 - ? LLVM::LLVMVoidType::get(&getContext()) - : unwrap(packFunctionResults(type.getResults())); - if (!resultType) - return {}; - - return LLVM::LLVMFunctionType::get(resultType, inputs); + return {LLVM::LLVMFunctionType::get(resultType, inputs), resultIsNowArg}; } static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0; @@ -1211,8 +1221,11 @@ /// Creates an auxiliary function with pointer-to-memref-descriptor-struct /// arguments instead of unpacked arguments. This function can be called from C /// by passing a pointer to a C struct corresponding to a memref descriptor. 
+/// Similarly, returned memrefs are passed via pointers to a C struct that is +/// passed as an additional argument. /// Internally, the auxiliary function unpacks the descriptor into individual -/// components and forwards them to `newFuncOp`. +/// components and forwards them to `newFuncOp` and forwards the results to +/// the extra argument. static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, LLVMTypeConverter &typeConverter, FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { @@ -1220,17 +1233,21 @@ SmallVector attributes; filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false, attributes); + Type wrapperFuncType; + bool resultIsNowArg; + std::tie(wrapperFuncType, resultIsNowArg) = + typeConverter.convertFunctionTypeCWrapper(type); auto wrapperFuncOp = rewriter.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), - typeConverter.convertFunctionTypeCWrapper(type), LLVM::Linkage::External, - attributes); + wrapperFuncType, LLVM::Linkage::External, attributes); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock()); SmallVector args; + size_t argOffset = resultIsNowArg ? 
1 : 0; for (auto &en : llvm::enumerate(type.getInputs())) { - Value arg = wrapperFuncOp.getArgument(en.index()); + Value arg = wrapperFuncOp.getArgument(en.index() + argOffset); if (auto memrefType = en.value().dyn_cast()) { Value loaded = rewriter.create(loc, arg); MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); @@ -1242,28 +1259,40 @@ continue; } - args.push_back(wrapperFuncOp.getArgument(en.index())); + args.push_back(arg); } + auto call = rewriter.create(loc, newFuncOp, args); - rewriter.create(loc, call.getResults()); + + if (resultIsNowArg) { + rewriter.create(loc, call.getResult(0), + wrapperFuncOp.getArgument(0)); + rewriter.create(loc, ValueRange{}); + } else { + rewriter.create(loc, call.getResults()); + } } /// Creates an auxiliary function with pointer-to-memref-descriptor-struct /// arguments instead of unpacked arguments. Creates a body for the (external) /// `newFuncOp` that allocates a memref descriptor on stack, packs the /// individual arguments into this descriptor and passes a pointer to it into -/// the auxiliary function. This auxiliary external function is now compatible -/// with functions defined in C using pointers to C structs corresponding to a -/// memref descriptor. +/// the auxiliary function. If the result of the function cannot be directly +/// returned, we write it to a special first argument that provides a pointer +/// to a corresponding struct. This auxiliary external function is now +/// compatible with functions defined in C using pointers to C structs +/// corresponding to a memref descriptor. 
static void wrapExternalFunction(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { OpBuilder::InsertionGuard guard(builder); - Type wrapperType = + Type wrapperType; + bool resultIsNowArg; + std::tie(wrapperType, resultIsNowArg) = typeConverter.convertFunctionTypeCWrapper(funcOp.getType()); // This conversion can only fail if it could not convert one of the argument - // types. But since it has been applies to a non-wrapper function before, it + // types. But since it has been applied to a non-wrapper function before, it // should have failed earlier and not reach this point at all. assert(wrapperType && "unexpected type conversion failure"); @@ -1284,6 +1313,17 @@ args.reserve(type.getNumInputs()); ValueRange wrapperArgsRange(newFuncOp.getArguments()); + if (resultIsNowArg) { + // Allocate the struct on the stack and pass the pointer. + Type resultType = + wrapperType.dyn_cast().getParamType(0); + Value one = builder.create( + loc, typeConverter.convertType(builder.getIndexType()), + builder.getIntegerAttr(builder.getIndexType(), 1)); + Value result = builder.create(loc, resultType, one); + args.push_back(result); + } + // Iterate over the inputs of the original function and pack values into // memref descriptors if the original type is a memref. 
for (auto &en : llvm::enumerate(type.getInputs())) { @@ -1321,7 +1361,13 @@ assert(wrapperArgsRange.empty() && "did not map some of the arguments"); auto call = builder.create(loc, wrapperFunc, args); - builder.create(loc, call.getResults()); + + if (resultIsNowArg) { + Value result = builder.create(loc, args.front()); + builder.create(loc, ValueRange{result}); + } else { + builder.create(loc, call.getResults()); + } } namespace { diff --git a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir --- a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir +++ b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir @@ -144,7 +144,7 @@ } // CHECK-LABEL: llvm.func @return_var_memref -func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> { +func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes { llvm.emit_c_interface } { // Match the construction of the unranked descriptor. // CHECK: %[[ALLOCA:.*]] = llvm.alloca // CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]] @@ -177,6 +177,10 @@ return %0 : memref<*xf32> } +// Check that the result memref is passed as a parameter +// CHECK-LABEL: @_mlir_ciface_return_var_memref +// CHECK-SAME: (%{{.*}}: !llvm.ptr)>>, %{{.*}}: !llvm.ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>>) + // CHECK-LABEL: llvm.func @return_two_var_memref_caller func @return_two_var_memref_caller(%arg0: memref<4x3xf32>) { // Only check that we create two different descriptors using different @@ -206,7 +210,7 @@ } // CHECK-LABEL: llvm.func @return_two_var_memref -func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) { +func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) attributes { llvm.emit_c_interface } { // Match the construction of the unranked descriptor. 
// CHECK: %[[ALLOCA:.*]] = llvm.alloca // CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]] @@ -240,3 +244,8 @@ return %0, %0 : memref<*xf32>, memref<*xf32> } +// Check that the result memrefs are passed as a parameter +// CHECK-LABEL: @_mlir_ciface_return_two_var_memref +// CHECK-SAME: (%{{.*}}: !llvm.ptr)>, struct<(i64, ptr)>)>>, +// CHECK-SAME: %{{.*}}: !llvm.ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>>) +