diff --git a/mlir/docs/ConversionToLLVMDialect.md b/mlir/docs/ConversionToLLVMDialect.md --- a/mlir/docs/ConversionToLLVMDialect.md +++ b/mlir/docs/ConversionToLLVMDialect.md @@ -374,12 +374,14 @@ ### C-compatible wrapper emission -In practical cases, it may be desirable to have externally-facing functions -with a single attribute corresponding to a MemRef argument. When interfacing -with LLVM IR produced from C, the code needs to respect the corresponding -calling convention. The conversion to the LLVM dialect provides an option to -generate wrapper functions that take memref descriptors as pointers-to-struct -compatible with data types produced by Clang when compiling C sources. +In practical cases, it may be desirable to have externally-facing functions with +a single argument corresponding to each MemRef argument. When interfacing with +LLVM IR produced from C, the code needs to respect the corresponding calling +convention. The conversion to the LLVM dialect provides an option to generate +wrapper functions that take memref descriptors as pointers-to-struct compatible +with data types produced by Clang when compiling C sources. The generation of +such wrapper functions can also be requested for individual functions, even if +the option is off, by setting the `llvm.emit_c_interface` unit attribute.
More specifically, a memref argument is converted into a pointer-to-struct argument of type `{T*, T*, i64, i64[N], i64[N]}*` in the wrapper function, where diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -11,8 +11,8 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/ADT/TypeSwitch.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -932,6 +932,7 @@ /// FuncOp legalization pattern that converts MemRef arguments to pointers to /// MemRef descriptors (LLVM struct data types) containing all the MemRef type /// information. 
+static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface"; struct FuncOpConversion : public FuncOpConversionBase { FuncOpConversion(LLVMTypeConverter &converter, bool emitCWrappers) : FuncOpConversionBase(converter), emitWrappers(emitCWrappers) {} @@ -942,7 +943,7 @@ auto funcOp = cast(op); auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); - if (emitWrappers) { + if (emitWrappers || funcOp.getAttrOfType(kEmitIfaceAttrName)) { if (newFuncOp.isExternal()) wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp, newFuncOp); @@ -1130,8 +1131,7 @@ ////////////// End Support for Lowering operations on n-D vectors ////////////// namespace { -template -struct OpCountValidator { +template struct OpCountValidator { static_assert( std::is_base_of< typename OpTrait::NOperands::template Impl, @@ -1139,14 +1139,12 @@ "wrong operand count"); }; -template -struct OpCountValidator { +template struct OpCountValidator { static_assert(std::is_base_of, SourceOp>::value, "expected a single operand"); }; -template -void ValidateOpCount() { +template void ValidateOpCount() { OpCountValidator(); } } // namespace @@ -2843,7 +2841,6 @@ if (failed(applyPartialConversion(m, target, patterns, &typeConverter))) signalPassFailure(); } - }; } // end namespace diff --git a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir --- a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir +++ b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt -convert-std-to-llvm='emit-c-wrappers=1' %s | FileCheck %s +// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s --check-prefix=EMIT_C_ATTRIBUTE // This tests the default memref calling convention and the emission of C // wrappers. 
We don't need to separate runs because the wrapper-emission @@ -72,6 +73,7 @@ } // CHECK-LABEL: @callee +// EMIT_C_ATTRIBUTE-LABEL: @callee func @callee(%arg0: memref, %arg1: index) { %0 = load %arg0[%arg1] : memref return @@ -93,3 +95,17 @@ // Forward the descriptor components to the call. // CHECK: llvm.call @callee(%[[ALLOC]], %[[ALIGN]], %[[OFFSET]], %[[SIZE]], %[[STRIDE]], %{{.*}}) : (!llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64) -> () +// EMIT_C_ATTRIBUTE-NOT: @_mlir_ciface_callee + +// CHECK-LABEL: @other_callee +// EMIT_C_ATTRIBUTE-LABEL: @other_callee +func @other_callee(%arg0: memref, %arg1: index) attributes { llvm.emit_c_interface } { + %0 = load %arg0[%arg1] : memref + return +} + +// CHECK: @_mlir_ciface_other_callee +// CHECK: llvm.call @other_callee + +// EMIT_C_ATTRIBUTE: @_mlir_ciface_other_callee +// EMIT_C_ATTRIBUTE: llvm.call @other_callee