diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -38,12 +38,55 @@
       [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
   addConversion([&](VectorType type) { return convertVectorType(type); });
 
-  // LLVM-compatible types are legal, so add a pass-through conversion.
+  // LLVM-compatible types are legal, so add a pass-through conversion. Do this
+  // before the conversions below since conversions are attempted in reverse
+  // order and those should take priority.
   addConversion([](Type type) {
     return LLVM::isCompatibleType(type) ? llvm::Optional<Type>(type)
                                         : llvm::None;
   });
 
+  // LLVM container types may (recursively) contain other types that must be
+  // converted even when the outer type is compatible.
+  addConversion([&](LLVM::LLVMPointerType type) -> llvm::Optional<Type> {
+    if (auto pointee = convertType(type.getElementType()))
+      return LLVM::LLVMPointerType::get(pointee, type.getAddressSpace());
+    return llvm::None;
+  });
+  addConversion([&](LLVM::LLVMStructType type) -> llvm::Optional<Type> {
+    // TODO: handle conversion of identified structs, which may be recursive.
+    if (type.isIdentified())
+      return type;
+
+    // Convert every element type; fail if any of them cannot be converted.
+    SmallVector<Type> convertedSubtypes;
+    convertedSubtypes.reserve(type.getBody().size());
+    if (failed(convertTypes(type.getBody(), convertedSubtypes)))
+      return llvm::None;
+
+    return LLVM::LLVMStructType::getLiteral(type.getContext(),
+                                            convertedSubtypes, type.isPacked());
+  });
+  addConversion([&](LLVM::LLVMArrayType type) -> llvm::Optional<Type> {
+    if (auto element = convertType(type.getElementType()))
+      return LLVM::LLVMArrayType::get(element, type.getNumElements());
+    return llvm::None;
+  });
+  addConversion([&](LLVM::LLVMFunctionType type) -> llvm::Optional<Type> {
+    // The result type and every parameter type must convert successfully.
+    Type convertedResType = convertType(type.getReturnType());
+    if (!convertedResType)
+      return llvm::None;
+
+    SmallVector<Type> convertedArgTypes;
+    convertedArgTypes.reserve(type.getNumParams());
+    if (failed(convertTypes(type.getParams(), convertedArgTypes)))
+      return llvm::None;
+
+    return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
+                                       type.isVarArg());
+  });
+
   // Materialization for memrefs creates descriptor structs from individual
   // values constituting them, when descriptors are used, i.e. more than one
   // value represents a memref.
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-types.mlir b/mlir/test/Conversion/StandardToLLVM/convert-types.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/StandardToLLVM/convert-types.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt -test-convert-call-op %s | FileCheck %s
+
+// The test pass converts `!test.smpla` to `i42`. LLVM container types are
+// converted recursively; identified structs are currently left untouched.
+
+// CHECK-LABEL: @ptr
+// CHECK: !llvm.ptr<i42>
+func private @ptr() -> !llvm.ptr<!test.smpla>
+
+// CHECK-LABEL: @ptr_ptr()
+// CHECK: !llvm.ptr<ptr<i42>>
+func private @ptr_ptr() -> !llvm.ptr<!llvm.ptr<!test.smpla>>
+
+// CHECK-LABEL: @struct_ptr()
+// CHECK: !llvm.struct<(ptr<i42>)>
+func private @struct_ptr() -> !llvm.struct<(ptr<!test.smpla>)>
+
+// CHECK-LABEL: @named_struct_ptr()
+// CHECK: !llvm.struct<"named", (ptr<!test.smpla>)>
+func private @named_struct_ptr() -> !llvm.struct<"named", (ptr<!test.smpla>)>
+
+// CHECK-LABEL: @array_ptr()
+// CHECK: !llvm.array<10 x ptr<i42>>
+func private @array_ptr() -> !llvm.array<10 x ptr<!test.smpla>>
+
+// CHECK-LABEL: @func()
+// CHECK: !llvm.ptr<func<i42 (i42)>>
+func private @func() -> !llvm.ptr<!llvm.func<!test.smpla (!test.smpla)>>
+
+// TODO: support conversion of recursive types in the conversion infra.
+// CHECK-LABEL: @named_recursive()
+// CHECK: !llvm.struct<"recursive", (ptr<!test.smpla>, ptr<struct<"recursive">>)>
+func private @named_recursive() -> !llvm.struct<"recursive", (ptr<!test.smpla>, ptr<struct<"recursive">>)>
+
diff --git a/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp b/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp
--- a/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp
+++ b/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp
@@ -52,6 +52,9 @@
     typeConverter.addConversion([&](test::TestType type) {
       return LLVM::LLVMPointerType::get(IntegerType::get(m.getContext(), 8));
     });
+    typeConverter.addConversion([&](test::SimpleAType type) {
+      return IntegerType::get(type.getContext(), 42);
+    });
 
     // Populate patterns.
     RewritePatternSet patterns(m.getContext());