diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
@@ -32,6 +32,10 @@
   LLVMTypeConverter &typeConverter;
 };
 
+/// Populates `typeConverter` with conversions for SPIR-V array, pointer,
+/// runtime array and struct types.
+void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter);
+
 /// Populates the given list with patterns that convert from SPIR-V to LLVM.
 void populateSPIRVToLLVMConversionPatterns(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter,
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -63,6 +63,168 @@
   return builder.getIntegerAttr(integerType, -1);
 }
 
+/// Converts `type` with `converter` and returns the result as an LLVM dialect
+/// type. Returns a null type if the conversion fails or does not produce an
+/// LLVM dialect type.
+static LLVM::LLVMType tryConvertToLLVMType(Type type,
+                                           TypeConverter &converter) {
+  return converter.convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
+}
+
+/// Converts a SPIR-V struct without member offsets to a packed LLVM struct.
+/// Returns a null type if any member type cannot be converted.
+static Type convertStructTypePacked(spirv::StructType type,
+                                    LLVMTypeConverter &converter) {
+  SmallVector<LLVM::LLVMType, 8> elementTypes;
+  for (unsigned i = 0, e = type.getNumElements(); i < e; ++i) {
+    LLVM::LLVMType convertedType =
+        tryConvertToLLVMType(type.getElementType(i), converter);
+    if (!convertedType)
+      return Type();
+    elementTypes.push_back(convertedType);
+  }
+  return LLVM::LLVMType::getStructTy(converter.getDialect(), elementTypes,
+                                     /*isPacked=*/true);
+}
+
+/// Returns the size in bytes of `type`, or of its innermost element if `type`
+/// is a (possibly nested) array. Returns llvm::None if getTypeNumBytes() does
+/// not provide a size.
+static Optional<int64_t> getElementOrSelfNumBytes(spirv::SPIRVType type) {
+  if (auto arrayType = type.dyn_cast<spirv::ArrayType>())
+    return getElementOrSelfNumBytes(
+        arrayType.getElementType().cast<spirv::SPIRVType>());
+  return type.getTypeNumBytes();
+}
+
+/// Converts a SPIR-V struct whose members have explicit offsets to a
+/// non-packed LLVM struct. Succeeds only if every member is at its "natural"
+/// offset:
+///   - the first member is at offset 0;
+///   - every other offset is a multiple of the member's size (of one element
+///     for arrays), which is used as the member's alignment;
+///   - the member does not overlap the previous one;
+///   - the padding before the member is smaller than that size.
+/// Returns llvm::None otherwise, or if a member's size is unknown or its type
+/// cannot be converted.
+/// NOTE(review): using the size as the alignment may not match LLVM's data
+/// layout for every type (e.g. 3-element vectors); confirm.
+static Optional<Type>
+convertStructTypeWithOffset(spirv::StructType type,
+                            LLVMTypeConverter &converter) {
+  unsigned numElements = type.getNumElements();
+  if (numElements == 0)
+    return LLVM::LLVMType::getStructTy(converter.getDialect(),
+                                       ArrayRef<LLVM::LLVMType>());
+
+  // The first member must start at offset 0. Record its offset and its full
+  // size in bytes, against which the next member is validated.
+  int64_t predOffset = type.getMemberOffset(0);
+  if (predOffset != 0)
+    return llvm::None;
+  auto predElementType = type.getElementType(0).cast<spirv::SPIRVType>();
+  Optional<int64_t> typeSizePred = predElementType.getTypeNumBytes();
+  if (!typeSizePred)
+    return llvm::None;
+  LLVM::LLVMType firstType = tryConvertToLLVMType(predElementType, converter);
+  if (!firstType)
+    return llvm::None;
+
+  SmallVector<LLVM::LLVMType, 8> elementTypes;
+  elementTypes.push_back(firstType);
+  for (unsigned i = 1; i < numElements; ++i) {
+    auto elementType = type.getElementType(i).cast<spirv::SPIRVType>();
+    // Size of the member (of one element for arrays), used as its alignment.
+    // Non-positive sizes are rejected so that the modulo below is defined.
+    Optional<int64_t> elementOrSelfSize = getElementOrSelfNumBytes(elementType);
+    if (!elementOrSelfSize || *elementOrSelfSize <= 0)
+      return llvm::None;
+
+    // Check that the offset of this member is natural.
+    int64_t offset = type.getMemberOffset(i);
+    int64_t padding = offset - predOffset - *typeSizePred;
+    if (offset % *elementOrSelfSize != 0 || padding < 0 ||
+        padding >= *elementOrSelfSize)
+      return llvm::None;
+
+    // The next member is validated against the full size of this one, which
+    // for an array is the size of the whole array rather than of one element.
+    if (auto arrayType = elementType.dyn_cast<spirv::ArrayType>()) {
+      typeSizePred = arrayType.getTypeNumBytes();
+      if (!typeSizePred)
+        return llvm::None;
+    } else {
+      typeSizePred = elementOrSelfSize;
+    }
+    predOffset = offset;
+
+    LLVM::LLVMType convertedType = tryConvertToLLVMType(elementType, converter);
+    if (!convertedType)
+      return llvm::None;
+    elementTypes.push_back(convertedType);
+  }
+  return LLVM::LLVMType::getStructTy(converter.getDialect(), elementTypes);
+}
+
+//===----------------------------------------------------------------------===//
+// Type conversion
+//===----------------------------------------------------------------------===//
+
+/// Converts a SPIR-V array type to an LLVM array type. Array stride is not
+/// modelled yet, so arrays with a non-zero stride are rejected, as are arrays
+/// whose element type cannot be converted.
+static Optional<Type> convertArrayType(spirv::ArrayType type,
+                                       TypeConverter &converter) {
+  if (type.getArrayStride() != 0)
+    return llvm::None;
+  LLVM::LLVMType elementType =
+      tryConvertToLLVMType(type.getElementType(), converter);
+  if (!elementType)
+    return llvm::None;
+  unsigned numElements = type.getNumElements();
+  return LLVM::LLVMType::getArrayTy(elementType, numElements);
+}
+
+/// Converts a SPIR-V pointer type to an LLVM pointer to the converted pointee
+/// type. The storage class is not modelled yet and is dropped. Returns a null
+/// type if the pointee type cannot be converted.
+static Type convertPointerType(spirv::PointerType type,
+                               TypeConverter &converter) {
+  LLVM::LLVMType pointeeType =
+      tryConvertToLLVMType(type.getPointeeType(), converter);
+  if (!pointeeType)
+    return Type();
+  return pointeeType.getPointerTo();
+}
+
+/// Converts a SPIR-V runtime array to a 0-sized LLVM array; since LLVM allows
+/// indexing past the bounds of an array, this can represent an array of
+/// unknown size. Array stride is not modelled yet, so runtime arrays with a
+/// non-zero stride are rejected, as are those whose element type cannot be
+/// converted.
+static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
+                                              TypeConverter &converter) {
+  if (type.getArrayStride() != 0)
+    return llvm::None;
+  LLVM::LLVMType elementType =
+      tryConvertToLLVMType(type.getElementType(), converter);
+  if (!elementType)
+    return llvm::None;
+  return LLVM::LLVMType::getArrayTy(elementType, 0);
+}
+
+/// Converts a SPIR-V struct type to an LLVM struct type. Member decorations
+/// are not modelled; only member offsets are taken into account. A struct
+/// without offsets becomes a packed LLVM struct; a struct with offsets is
+/// converted only if all its offsets are natural (see
+/// convertStructTypeWithOffset).
+static Optional<Type> convertStructType(spirv::StructType type,
+                                        LLVMTypeConverter &converter) {
+  if (type.hasOffset())
+    return convertStructTypeWithOffset(type, converter);
+  return convertStructTypePacked(type, converter);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -280,6 +442,8 @@
         funcType.getNumInputs());
     auto llvmType = this->typeConverter.convertFunctionSignature(
         funcOp.getType(), /*isVariadic=*/false, signatureConverter);
+    if (!llvmType)
+      return failure();
 
     // Create a new `LLVMFuncOp`
     Location loc = funcOp.getLoc();
@@ -361,6 +525,21 @@
 // Pattern population
 //===----------------------------------------------------------------------===//
 
+void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
+  typeConverter.addConversion([&](spirv::ArrayType type) {
+    return convertArrayType(type, typeConverter);
+  });
+  typeConverter.addConversion([&](spirv::PointerType type) {
+    return convertPointerType(type, typeConverter);
+  });
+  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
+    return convertRuntimeArrayType(type, typeConverter);
+  });
+  typeConverter.addConversion([&](spirv::StructType type) {
+    return convertStructType(type, typeConverter);
+  });
+}
+
 void mlir::populateSPIRVToLLVMConversionPatterns(
     MLIRContext *context, LLVMTypeConverter &typeConverter,
     OwningRewritePatternList &patterns) {
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
@@ -34,6 +34,9 @@
   LLVMTypeConverter converter(&getContext());
 
   OwningRewritePatternList patterns;
+
+  populateSPIRVToLLVMTypeConversion(converter);
+
   populateSPIRVToLLVMModuleConversionPatterns(context, converter, patterns);
populateSPIRVToLLVMConversionPatterns(context, converter, patterns); populateSPIRVToLLVMFunctionConversionPatterns(context, converter, patterns); diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-opt %s -convert-spirv-to-llvm -verify-diagnostics -split-input-file + +// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}} +spv.func @array_with_stride(%arg: !spv.array<4 x f32, stride=4>) -> () "None" { + spv.Return +} + +// ----- + +// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}} +spv.func @nested_struct_with_offset(%arg: !spv.struct[0], i8[16]>) -> () "None" { + spv.Return +} + +// ----- + +// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}} +spv.func @struct_with_unnatural_offset1(%arg: !spv.struct) -> () "None" { + spv.Return +} + +// ----- + +// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}} +spv.func @struct_with_unnatural_offset2(%arg: !spv.struct) -> () "None" { + spv.Return +} + +// ----- + +// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}} +spv.func @struct_with_unnatural_offset3(%arg: !spv.struct) -> () "None" { + spv.Return +} diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-opt -split-input-file -convert-spirv-to-llvm -verify-diagnostics %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// Array type 
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @array(!llvm<"[16 x float]">, !llvm<"[32 x <4 x float>]">)
+func @array(!spv.array<16xf32>, !spv.array< 32 x vector<4xf32> >) -> ()
+
+//===----------------------------------------------------------------------===//
+// Pointer type
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @pointer_scalar(!llvm<"i1*">, !llvm<"float*">)
+func @pointer_scalar(!spv.ptr<i1, Uniform>, !spv.ptr<f32, Private>) -> ()
+
+// CHECK-LABEL: @pointer_vector(!llvm<"<4 x i32>*">)
+func @pointer_vector(!spv.ptr<vector<4xi32>, Function>) -> ()
+
+//===----------------------------------------------------------------------===//
+// Runtime array type
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @runtime_array_vector(!llvm<"[0 x <4 x float>]">)
+func @runtime_array_vector(!spv.rtarray< vector<4xf32> >) -> ()
+
+// CHECK-LABEL: @runtime_array_scalar(!llvm<"[0 x float]">)
+func @runtime_array_scalar(!spv.rtarray<f32>) -> ()
+
+//===----------------------------------------------------------------------===//
+// Struct type
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @struct(!llvm<"<{ double }>">)
+func @struct(!spv.struct<f64>) -> ()
+
+// CHECK-LABEL: @struct_nested(!llvm<"<{ i32, <{ i64, i32 }> }>">)
+func @struct_nested(!spv.struct<i32, !spv.struct<i64, i32>>)
+
+// CHECK-LABEL: llvm.func @struct_with_decoration(!llvm<"<{ float }>">)
+func @struct_with_decoration(!spv.struct<f32 [NonWritable]>)
+
+// CHECK-LABEL: llvm.func @struct_with_offset1(!llvm<"{ i8, i64, i32 }">)
+func @struct_with_offset1(!spv.struct<i8[0], i64[8], i32[16]>)
+
+// CHECK-LABEL: llvm.func @struct_with_offset2(!llvm<"{ [5 x i8], i32, i8, i8 }">)
+func @struct_with_offset2(!spv.struct<!spv.array<5 x i8>[0], i32[8], i8[12], i8[13]>)