diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
@@ -32,6 +32,9 @@
   LLVMTypeConverter &typeConverter;
 };
 
+/// Adds conversions for SPIR-V array, pointer and runtime array types.
+void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter);
+
 /// Populates the given list with patterns that convert from SPIR-V to LLVM.
 void populateSPIRVToLLVMConversionPatterns(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter,
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -165,6 +165,55 @@
   return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
 }
 
+//===----------------------------------------------------------------------===//
+// Type conversion
+//===----------------------------------------------------------------------===//
+
+/// Converts SPIR-V array type to LLVM array. There is no modelling of array
+/// stride at the moment, so strided arrays are rejected. Returns `llvm::None`
+/// if the array is strided or its element type cannot be converted.
+static Optional<Type> convertArrayType(spirv::ArrayType type,
+                                       TypeConverter &converter) {
+  if (type.getArrayStride() != 0)
+    return llvm::None;
+  // The element conversion may fail (e.g. for a nested strided array); use
+  // dyn_cast_or_null so that the failure propagates instead of asserting.
+  auto elementType = converter.convertType(type.getElementType())
+                         .dyn_cast_or_null<LLVM::LLVMType>();
+  if (!elementType)
+    return llvm::None;
+  unsigned numElements = type.getNumElements();
+  return LLVM::LLVMType::getArrayTy(elementType, numElements);
+}
+
+/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
+/// modelled at the moment. Returns `llvm::None` if the pointee type cannot be
+/// converted.
+static Optional<Type> convertPointerType(spirv::PointerType type,
+                                         TypeConverter &converter) {
+  auto pointeeType = converter.convertType(type.getPointeeType())
+                         .dyn_cast_or_null<LLVM::LLVMType>();
+  if (!pointeeType)
+    return llvm::None;
+  return pointeeType.getPointerTo();
+}
+
+/// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
+/// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
+/// no modelling of array stride at the moment, so strided runtime arrays are
+/// rejected. Returns `llvm::None` if the array is strided or its element type
+/// cannot be converted.
+static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
+                                              TypeConverter &converter) {
+  if (type.getArrayStride() != 0)
+    return llvm::None;
+  auto elementType = converter.convertType(type.getElementType())
+                         .dyn_cast_or_null<LLVM::LLVMType>();
+  if (!elementType)
+    return llvm::None;
+  return LLVM::LLVMType::getArrayTy(elementType, 0);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -581,6 +630,8 @@
         funcType.getNumInputs());
     auto llvmType = this->typeConverter.convertFunctionSignature(
         funcOp.getType(), /*isVariadic=*/false, signatureConverter);
+    if (!llvmType)
+      return failure();
 
     // Create a new `LLVMFuncOp`
     Location loc = funcOp.getLoc();
@@ -662,6 +713,18 @@
 // Pattern population
 //===----------------------------------------------------------------------===//
 
+void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
+  typeConverter.addConversion([&](spirv::ArrayType type) {
+    return convertArrayType(type, typeConverter);
+  });
+  typeConverter.addConversion([&](spirv::PointerType type) {
+    return convertPointerType(type, typeConverter);
+  });
+  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
+    return convertRuntimeArrayType(type, typeConverter);
+  });
+}
+
 void mlir::populateSPIRVToLLVMConversionPatterns(
     MLIRContext *context, LLVMTypeConverter &typeConverter,
     OwningRewritePatternList &patterns) {
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
@@ -34,6 +34,9 @@
   LLVMTypeConverter converter(&getContext());
 
   OwningRewritePatternList patterns;
+
+  populateSPIRVToLLVMTypeConversion(converter);
+
   populateSPIRVToLLVMModuleConversionPatterns(context, converter, patterns);
   populateSPIRVToLLVMConversionPatterns(context, converter, patterns);
   populateSPIRVToLLVMFunctionConversionPatterns(context, converter, patterns);
diff --git a/mlir/test/Conversion/SPIRVToLLVM/cast-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/cast-ops-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/cast-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/cast-ops-to-llvm.mlir
@@ -34,6 +34,12 @@
   return
 }
 
+func @bitcast_pointer(%arg0: !spv.ptr<f32, Function>) {
+  // CHECK: %{{.*}} = llvm.bitcast %{{.*}} : !llvm<"float*"> to !llvm<"i32*">
+  %0 = spv.Bitcast %arg0 : !spv.ptr<f32, Function> to !spv.ptr<i32, Function>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // spv.ConvertFToS
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.invalid.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.invalid.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -convert-spirv-to-llvm -verify-diagnostics -split-input-file
+
+// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}}
+spv.func @array_with_stride(%arg: !spv.array<4 x f32, stride=4>) -> () "None" {
+  spv.Return
+}
+
+// -----
+
+// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}}
+spv.func @runtime_array_with_stride(%arg: !spv.rtarray<f32, stride=4>) -> () "None" {
+  spv.Return
+}
+
+// -----
+
+// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}}
+spv.func @pointer_to_array_with_stride(%arg: !spv.ptr<!spv.array<4 x f32, stride=4>, Function>) -> () "None" {
+  spv.Return
+}
diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -split-input-file -convert-spirv-to-llvm -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Array type
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @array(!llvm<"[16 x float]">, !llvm<"[32 x <4 x float>]">)
+func @array(!spv.array<16xf32>, !spv.array< 32 x vector<4xf32> >) -> ()
+
+//===----------------------------------------------------------------------===//
+// Pointer type
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @pointer_scalar(!llvm<"i1*">, !llvm<"float*">)
+func @pointer_scalar(!spv.ptr<i1, Uniform>, !spv.ptr<f32, Private>) -> ()
+
+// CHECK-LABEL: @pointer_vector(!llvm<"<4 x i32>*">)
+func @pointer_vector(!spv.ptr<vector<4xi32>, Function>) -> ()
+
+//===----------------------------------------------------------------------===//
+// Runtime array type
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @runtime_array_vector(!llvm<"[0 x <4 x float>]">)
+func @runtime_array_vector(!spv.rtarray< vector<4xf32> >) -> ()
+
+// CHECK-LABEL: @runtime_array_scalar(!llvm<"[0 x float]">)
+func @runtime_array_scalar(!spv.rtarray<f32>) -> ()