diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
@@ -32,6 +32,10 @@
   LLVMTypeConverter &typeConverter;
 };
 
+/// Registers conversions for SPIR-V array, pointer, runtime array and struct
+/// types with the given LLVM type converter.
+void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter);
+
 /// Populates the given list with patterns that convert from SPIR-V to LLVM.
 void populateSPIRVToLLVMConversionPatterns(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter,
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -63,6 +63,71 @@
   return builder.getIntegerAttr(integerType, -1);
 }
 
+//===----------------------------------------------------------------------===//
+// Type conversion
+//===----------------------------------------------------------------------===//
+
+/// Converts `type` with `converter` and returns the result as an LLVM dialect
+/// type. Returns a null type when the conversion fails or produces a non-LLVM
+/// type, so callers can signal failure instead of asserting inside `cast`.
+static LLVM::LLVMType convertToLLVMType(Type type, TypeConverter &converter) {
+  return converter.convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
+}
+
+/// Converts a SPIR-V array type to an LLVM array type with the same element
+/// count. Returns a null type if the element type cannot be converted.
+static Type convertArrayType(spirv::ArrayType type, TypeConverter &converter) {
+  LLVM::LLVMType elementType =
+      convertToLLVMType(type.getElementType(), converter);
+  if (!elementType)
+    return Type();
+  unsigned numElements = type.getNumElements();
+  return LLVM::LLVMType::getArrayTy(elementType, numElements);
+}
+
+/// Converts a SPIR-V pointer type to an LLVM pointer to the converted pointee
+/// type. Returns a null type if the pointee type cannot be converted.
+static Type convertPointerType(spirv::PointerType type,
+                               TypeConverter &converter) {
+  LLVM::LLVMType pointeeType =
+      convertToLLVMType(type.getPointeeType(), converter);
+  if (!pointeeType)
+    return Type();
+  return pointeeType.getPointerTo();
+}
+
+/// Converts a SPIR-V runtime array type to a zero-length LLVM array of the
+/// converted element type. Returns a null type if the element type cannot be
+/// converted.
+static Type convertRuntimeArrayType(spirv::RuntimeArrayType type,
+                                    TypeConverter &converter) {
+  LLVM::LLVMType elementType =
+      convertToLLVMType(type.getElementType(), converter);
+  if (!elementType)
+    return Type();
+  return LLVM::LLVMType::getArrayTy(elementType, 0);
+}
+
+/// Converts a SPIR-V struct type to an LLVM struct type whose members are the
+/// converted element types, in order. Returns a null type if any element type
+/// cannot be converted.
+/// NOTE(review): only element types are read here; member offsets and
+/// decorations are not reflected in the result - confirm this is intended.
+static Type convertStructType(spirv::StructType type,
+                              LLVMTypeConverter &converter) {
+  unsigned numElements = type.getNumElements();
+  SmallVector<LLVM::LLVMType, 4> elementTypes;
+  elementTypes.reserve(numElements);
+  for (unsigned i = 0; i < numElements; ++i) {
+    LLVM::LLVMType convertedType =
+        convertToLLVMType(type.getElementType(i), converter);
+    if (!convertedType)
+      return Type();
+    elementTypes.push_back(convertedType);
+  }
+  return LLVM::LLVMType::getStructTy(converter.getDialect(), elementTypes);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -361,6 +426,23 @@
 // Pattern population
 //===----------------------------------------------------------------------===//
 
+void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
+  // Each callback captures `typeConverter` by reference and hands it to the
+  // helper, which uses it to convert the nested element/pointee types.
+  typeConverter.addConversion([&](spirv::ArrayType type) {
+    return convertArrayType(type, typeConverter);
+  });
+  typeConverter.addConversion([&](spirv::PointerType type) {
+    return convertPointerType(type, typeConverter);
+  });
+  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
+    return convertRuntimeArrayType(type, typeConverter);
+  });
+  typeConverter.addConversion([&](spirv::StructType type) {
+    return convertStructType(type, typeConverter);
+  });
+}
+
 void mlir::populateSPIRVToLLVMConversionPatterns(
     MLIRContext *context, LLVMTypeConverter &typeConverter,
     OwningRewritePatternList &patterns) {
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
@@ -34,6 +34,9 @@
   LLVMTypeConverter converter(&getContext());
 
   OwningRewritePatternList patterns;
+
+  populateSPIRVToLLVMTypeConversion(converter);
+
   populateSPIRVToLLVMModuleConversionPatterns(context, converter, patterns);
   populateSPIRVToLLVMConversionPatterns(context, converter, patterns);
   populateSPIRVToLLVMFunctionConversionPatterns(context, converter, patterns);
diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Array type
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @array(!llvm<"[16 x float]">, !llvm<"[32 x <4 x float>]">)
+func @array(!spv.array<16xf32>, !spv.array< 32 x vector<4xf32> >) -> ()
+
+//===----------------------------------------------------------------------===//
+// Pointer type
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @pointer_scalar(!llvm<"i1*">, !llvm<"float*">)
+func @pointer_scalar(!spv.ptr<i1, Function>, !spv.ptr<f32, Function>) -> ()
+
+// CHECK-LABEL: @pointer_vector(!llvm<"<4 x i32>*">)
+func @pointer_vector(!spv.ptr<vector<4xi32>, Function>) -> ()
+
+//===----------------------------------------------------------------------===//
+// Runtime array type
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @runtime_array_vector(!llvm<"[0 x <4 x float>]">)
+func @runtime_array_vector(!spv.rtarray< vector<4xf32> >) -> ()
+
+// CHECK-LABEL: @runtime_array_scalar(!llvm<"[0 x float]">)
+func @runtime_array_scalar(!spv.rtarray<f32>) -> ()
+
+//===----------------------------------------------------------------------===//
+// Struct type
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @struct(!llvm<"{ double }">)
+func @struct(!spv.struct<f64>) -> ()
+
+// CHECK-LABEL: @struct_nested(!llvm<"{ i32, { i64, i32 } }">)
+func @struct_nested(!spv.struct<i32, !spv.struct<i64, i32>>)
+
+// CHECK-LABEL: @struct_with_offset(!llvm<"{ float, i32 }">)
+func @struct_with_offset(!spv.struct<f32 [0], i32 [4]>) -> ()
+
+// CHECK-LABEL: @struct_with_decoration(!llvm<"{ float }">)
+func @struct_with_decoration(!spv.struct<f32 [RelaxedPrecision]>)