[mlir][spirv] Return Type instead of Optional<Type> from type conversion helpers

convertScalarType, convertVectorType, convertTensorType and convertMemrefType
now return a null Type (nullptr) on failure instead of llvm::None. Call sites
test the returned Type directly instead of dereferencing an Optional.

SPIRVTypeConverter no longer registers the catch-all conversion
`addConversion([](Type type) { return llvm::None; })`. The IntegerType and
FloatType conversions now return `Type()` instead of `llvm::None` for types
that are not spirv::ScalarType.

NOTE(review): per the mlir::TypeConverter::addConversion contract, a null
Type reports conversion failure without consulting other registered
conversions, whereas llvm::None lets the converter try the next one. Confirm
that no later-tried conversion was being relied on for these cases.

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -235,9 +235,9 @@
 }
 
 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
-static Optional<Type>
-convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
-                  Optional<spirv::StorageClass> storageClass = {}) {
+static Type convertScalarType(const spirv::TargetEnv &targetEnv,
+                              spirv::ScalarType type,
+                              Optional<spirv::StorageClass> storageClass = {}) {
   // Get extension and capability requirements for the given type.
   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
@@ -271,9 +271,9 @@
 }
 
 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
-static Optional<Type>
-convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
-                  Optional<spirv::StorageClass> storageClass = {}) {
+static Type convertVectorType(const spirv::TargetEnv &targetEnv,
+                              VectorType type,
+                              Optional<spirv::StorageClass> storageClass = {}) {
   if (type.getRank() == 1 && type.getNumElements() == 1)
     return type.getElementType();
 
@@ -281,7 +281,7 @@
     // TODO: Vector types with more than four elements can be translated into
     // array types.
     LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
-    return llvm::None;
+    return nullptr;
   }
 
   // Get extension and capability requirements for the given type.
@@ -298,8 +298,8 @@
   auto elementType = convertScalarType(
       targetEnv, type.getElementType().cast<spirv::ScalarType>(), storageClass);
   if (elementType)
-    return VectorType::get(type.getShape(), *elementType);
-  return llvm::None;
+    return VectorType::get(type.getShape(), elementType);
+  return nullptr;
 }
 
 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
@@ -308,20 +308,20 @@
 /// create composite constants with OpConstantComposite to embed relative large
 /// constant values and use OpCompositeExtract and OpCompositeInsert to
 /// manipulate, like what we do for vectors.
-static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
-                                        TensorType type) {
+static Type convertTensorType(const spirv::TargetEnv &targetEnv,
+                              TensorType type) {
   // TODO: Handle dynamic shapes.
   if (!type.hasStaticShape()) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: dynamic shape unimplemented\n");
-    return llvm::None;
+    return nullptr;
   }
 
   auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
   if (!scalarType) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: cannot convert non-scalar element type\n");
-    return llvm::None;
+    return nullptr;
   }
 
   Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
@@ -329,35 +329,35 @@
   if (!scalarSize || !tensorSize) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: cannot deduce element count\n");
-    return llvm::None;
+    return nullptr;
   }
 
   auto arrayElemCount = *tensorSize / *scalarSize;
   auto arrayElemType = convertScalarType(targetEnv, scalarType);
   if (!arrayElemType)
-    return llvm::None;
-  Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
+    return nullptr;
+  Optional<int64_t> arrayElemSize = getTypeNumBytes(arrayElemType);
   if (!arrayElemSize) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: cannot deduce converted element size\n");
-    return llvm::None;
+    return nullptr;
   }
 
-  return spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
+  return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
 }
 
-static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
-                                        MemRefType type) {
+static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
+                              MemRefType type) {
   Optional<spirv::StorageClass> storageClass =
       SPIRVTypeConverter::getStorageClassForMemorySpace(
           type.getMemorySpaceAsInt());
   if (!storageClass) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: cannot convert memory space\n");
-    return llvm::None;
+    return nullptr;
   }
 
-  Optional<Type> arrayElemType;
+  Type arrayElemType;
   Type elementType = type.getElementType();
   if (auto vecType = elementType.dyn_cast<VectorType>()) {
     arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
@@ -368,20 +368,20 @@
         llvm::dbgs()
         << type
         << " unhandled: can only convert scalar or vector element type\n");
-    return llvm::None;
+    return nullptr;
   }
   if (!arrayElemType)
-    return llvm::None;
+    return nullptr;
 
   Optional<int64_t> elementSize = getTypeNumBytes(elementType);
   if (!elementSize) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: cannot deduce element size\n");
-    return llvm::None;
+    return nullptr;
   }
 
   if (!type.hasStaticShape()) {
-    auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
+    auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, *elementSize);
     // Wrap in a struct to satisfy Vulkan interface requirements.
     auto structType = spirv::StructType::get(arrayType, 0);
     return spirv::PointerType::get(structType, *storageClass);
@@ -391,20 +391,20 @@
   if (!memrefSize) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: cannot deduce element count\n");
-    return llvm::None;
+    return nullptr;
   }
 
   auto arrayElemCount = *memrefSize / *elementSize;
 
-  Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
+  Optional<int64_t> arrayElemSize = getTypeNumBytes(arrayElemType);
   if (!arrayElemSize) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: cannot deduce converted element size\n");
-    return llvm::None;
+    return nullptr;
   }
 
   auto arrayType =
-      spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
+      spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
 
   // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
   // workgroup storage class do not need the struct to be laid out explicitly.
@@ -418,9 +418,6 @@
     : targetEnv(targetAttr) {
   // Add conversions. The order matters here: later ones will be tried earlier.
 
-  // All other cases failed. Then we cannot convert this type.
-  addConversion([](Type type) { return llvm::None; });
-
   // Allow all SPIR-V dialect specific types. This assumes all builtin types
   // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
   // were tried before.
@@ -438,13 +435,13 @@
   addConversion([this](IntegerType intType) -> Optional<Type> {
     if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
       return convertScalarType(targetEnv, scalarType);
-    return llvm::None;
+    return Type();
   });
 
   addConversion([this](FloatType floatType) -> Optional<Type> {
     if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
       return convertScalarType(targetEnv, scalarType);
-    return llvm::None;
+    return Type();
   });
 
   addConversion([this](VectorType vectorType) {