// Patch: add ComplexType support to the SPIR-V type converter.
//
// - getTypeNumBytes: a complex type occupies twice the size of its element
//   type; returns std::nullopt when the element size is unknown.
// - convertComplexType (new, static): accepts only scalar element types,
//   converts the element via convertScalarType (forwarding the optional
//   storage class), and returns a 2-element VectorType of that element.
//   It returns nullptr (type is illegal) for a non-scalar element, or when the
//   converted element differs from the original element, because emulation
//   (e.g. widening) is not implemented for complex types.
// - The memref element-type dispatch (next to the vector/scalar branches) and
//   a new addConversion callback route ComplexType through convertComplexType.
// - The @@ -372 and @@ -443 hunks only delete a blank line before
//   `if (!type.hasStaticShape())`; they are unrelated to complex support.
// - types-to-spirv.mlir: under the first module's target env, complex types
//   convert to 2-element vectors. Under the second module's target env, the
//   memref<...complex...> arguments are expected to stay unconverted.
//
// NOTE(review): template arguments and #include targets appear stripped in
// this copy (e.g. `dyn_cast()`, a bare `#include `, `complex,` parameters,
// `#spirv.vce` without parameters), so it cannot be applied as-is; restore
// them from the upstream patch first.
// NOTE(review): a ComplexType's element type is presumably never null, so
// `dyn_cast_or_null` in convertComplexType could likely be a plain
// `dyn_cast` -- confirm.
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -16,11 +16,14 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" +#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Debug.h" #include +#include #define DEBUG_TYPE "mlir-spirv-conversion" @@ -149,6 +152,13 @@ return bitWidth / 8; } + if (auto complexType = type.dyn_cast()) { + auto elementSize = getTypeNumBytes(options, complexType.getElementType()); + if (!elementSize) + return std::nullopt; + return 2 * *elementSize; + } + if (auto vecType = type.dyn_cast()) { auto elementSize = getTypeNumBytes(options, vecType.getElementType()); if (!elementSize) @@ -299,6 +309,30 @@ return nullptr; } +static Type +convertComplexType(const spirv::TargetEnv &targetEnv, + const SPIRVConversionOptions &options, ComplexType type, + std::optional storageClass = {}) { + auto scalarType = dyn_cast_or_null(type.getElementType()); + if (!scalarType) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot convert non-scalar element type\n"); + return nullptr; + } + + auto elementType = + convertScalarType(targetEnv, options, scalarType, storageClass); + if (!elementType) + return nullptr; + if (elementType != type.getElementType()) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: complex type emulation unsupported\n"); + return nullptr; + } + + return VectorType::get(2, elementType); +} + /// Converts a tensor `type` to a suitable type under the given `targetEnv`. /// /// Note that this is mainly for lowering constant tensors. 
In SPIR-V one can @@ -372,7 +406,6 @@ return nullptr; } - if (!type.hasStaticShape()) { // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing // to the element. @@ -419,6 +452,9 @@ if (auto vecType = elementType.dyn_cast()) { arrayElemType = convertVectorType(targetEnv, options, vecType, storageClass); + } else if (auto complexType = elementType.dyn_cast()) { + arrayElemType = + convertComplexType(targetEnv, options, complexType, storageClass); } else if (auto scalarType = elementType.dyn_cast()) { arrayElemType = convertScalarType(targetEnv, options, scalarType, storageClass); @@ -443,7 +479,6 @@ return nullptr; } - if (!type.hasStaticShape()) { // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing // to the element. @@ -500,6 +535,10 @@ return Type(); }); + addConversion([this](ComplexType complexType) { + return convertComplexType(this->targetEnv, this->options, complexType); + }); + addConversion([this](VectorType vectorType) { return convertVectorType(this->targetEnv, this->options, vectorType); }); diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir +++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -196,6 +196,65 @@ // ----- +//===----------------------------------------------------------------------===// +// Complex types +//===----------------------------------------------------------------------===// + +// Check that capabilities for scalar types affect complex types too: having +// special capabilities means complex element types are kept untouched.
+module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + +// CHECK-LABEL: func @complex_types +// CHECK-SAME: vector<2xf32> +// CHECK-SAME: vector<2xi32> +// CHECK-SAME: vector<2xf64> +// CHECK-SAME: vector<2xi16> +func.func @complex_types( + %arg0: complex, + %arg1: complex, + %arg2: complex, + %arg3: complex +) { return } + +// CHECK-LABEL: func @memref_complex_types_with_cap +// CHECK-SAME: !spirv.ptr, stride=4> [0])>, StorageBuffer> +// CHECK-SAME: !spirv.ptr, stride=2> [0])>, Uniform> +func.func @memref_complex_types_with_cap( + %arg0: memref<4xcomplex, #spirv.storage_class>, + %arg1: memref<2x8xcomplex, #spirv.storage_class> +) { return } + +} // end module + +// ----- + +// Check that capabilities for scalar types affect complex types too: no special +// capabilities would need emulated element types, so conversion is left undone. + +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + +// Emulation is unimplemented right now. +// CHECK-LABEL: func @memref_complex_types_no_cap +// CHECK-SAME: memref<4xcomplex, #spirv.storage_class> +// CHECK-SAME: memref<2x8xcomplex, #spirv.storage_class> +// NOEMU-LABEL: func @memref_complex_types_no_cap +// NOEMU-SAME: memref<4xcomplex, #spirv.storage_class> +// NOEMU-SAME: memref<2x8xcomplex, #spirv.storage_class> +func.func @memref_complex_types_no_cap( + %arg0: memref<4xcomplex, #spirv.storage_class>, + %arg1: memref<2x8xcomplex, #spirv.storage_class> +) { return } + +} // end module + +// ----- + //===----------------------------------------------------------------------===// // Vector types //===----------------------------------------------------------------------===//