diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h
--- a/mlir/include/mlir-c/Dialect/LLVM.h
+++ b/mlir/include/mlir-c/Dialect/LLVM.h
@@ -23,6 +23,27 @@
 MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirType pointee,
                                                    unsigned addressSpace);
 
+/// Creates an llvm.void type in the given context.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMVoidTypeGet(MlirContext ctx);
+
+/// Creates an llvm.array type of `numElements` elements of type `elementType`.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMArrayTypeGet(MlirType elementType,
+                                                 unsigned numElements);
+
+/// Creates an llvm.func type with result type `resultType` and the
+/// `nArgumentTypes` argument types stored at `argumentTypes`. `isVarArg` is
+/// forwarded to the function type as its variadic flag.
+MLIR_CAPI_EXPORTED MlirType
+mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
+                        MlirType const *argumentTypes, bool isVarArg);
+
+/// Creates an LLVM literal (unnamed) struct type in `ctx` from the
+/// `nFieldTypes` field types stored at `fieldTypes`. `isPacked` is forwarded
+/// to the struct type as its packed flag.
+MLIR_CAPI_EXPORTED MlirType
+mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes,
+                             MlirType const *fieldTypes, bool isPacked);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -19,3 +19,28 @@
 MlirType mlirLLVMPointerTypeGet(MlirType pointee, unsigned addressSpace) {
   return wrap(LLVMPointerType::get(unwrap(pointee), addressSpace));
 }
+
+MlirType mlirLLVMVoidTypeGet(MlirContext ctx) {
+  return wrap(LLVMVoidType::get(unwrap(ctx)));
+}
+
+MlirType mlirLLVMArrayTypeGet(MlirType elementType, unsigned numElements) {
+  return wrap(LLVMArrayType::get(unwrap(elementType), numElements));
+}
+
+MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
+                                 MlirType const *argumentTypes, bool isVarArg) {
+  SmallVector<Type, 2> argumentStorage;
+  return wrap(LLVMFunctionType::get(
+      unwrap(resultType),
+      unwrapList(nArgumentTypes, argumentTypes, argumentStorage), isVarArg));
+}
+
+MlirType mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes,
+                                      MlirType const *fieldTypes,
+                                      bool isPacked) {
+  SmallVector<Type, 2> fieldStorage;
+  return wrap(LLVMStructType::getLiteral(
+      unwrap(ctx), unwrapList(nFieldTypes, fieldTypes, fieldStorage),
+      isPacked));
+}
diff --git a/mlir/test/CAPI/llvm.c b/mlir/test/CAPI/llvm.c
--- a/mlir/test/CAPI/llvm.c
+++ b/mlir/test/CAPI/llvm.c
@@ -22,7 +22,9 @@
 // CHECK-LABEL: testTypeCreation()
 static void testTypeCreation(MlirContext ctx) {
   fprintf(stderr, "testTypeCreation()\n");
+  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
   MlirType i32 = mlirIntegerTypeGet(ctx, 32);
+  MlirType i64 = mlirIntegerTypeGet(ctx, 64);
 
   const char *i32p_text = "!llvm.ptr";
   MlirType i32p = mlirLLVMPointerTypeGet(i32, 0);
@@ -35,6 +37,37 @@
   MlirType i32p4_ref = mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(i32p4_text));
   // CHECK: !llvm.ptr: 1
   fprintf(stderr, "%s: %d\n", i32p4_text, mlirTypeEqual(i32p4, i32p4_ref));
+
+  const char *voidt_text = "!llvm.void";
+  MlirType voidt = mlirLLVMVoidTypeGet(ctx);
+  MlirType voidt_ref =
+      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(voidt_text));
+  // CHECK: !llvm.void: 1
+  fprintf(stderr, "%s: %d\n", voidt_text, mlirTypeEqual(voidt, voidt_ref));
+
+  const char *i32_4_text = "!llvm.array<4xi32>";
+  MlirType i32_4 = mlirLLVMArrayTypeGet(i32, 4);
+  MlirType i32_4_ref =
+      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(i32_4_text));
+  // CHECK: !llvm.array<4xi32>: 1
+  fprintf(stderr, "%s: %d\n", i32_4_text, mlirTypeEqual(i32_4, i32_4_ref));
+
+  const char *i8_i32_i64_text = "!llvm.func<i8 (i32, i64)>";
+  const MlirType i32_i64_arr[] = {i32, i64};
+  MlirType i8_i32_i64 = mlirLLVMFunctionTypeGet(i8, 2, i32_i64_arr, false);
+  MlirType i8_i32_i64_ref =
+      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(i8_i32_i64_text));
+  // CHECK: !llvm.func<i8 (i32, i64)>: 1
+  fprintf(stderr, "%s: %d\n", i8_i32_i64_text,
+          mlirTypeEqual(i8_i32_i64, i8_i32_i64_ref));
+
+  const char *i32_i64_s_text = "!llvm.struct<(i32, i64)>";
+  MlirType i32_i64_s = mlirLLVMStructTypeLiteralGet(ctx, 2, i32_i64_arr, false);
+  MlirType i32_i64_s_ref =
+      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(i32_i64_s_text));
+  // CHECK: !llvm.struct<(i32, i64)>: 1
+  fprintf(stderr, "%s: %d\n", i32_i64_s_text,
+          mlirTypeEqual(i32_i64_s, i32_i64_s_ref));
 }
 
 int main() {