diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3157,6 +3157,7 @@
 def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
 def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>;
 def SPV_OC_OpTypeMatrix : I32EnumAttrCase<"OpTypeMatrix", 24>;
+def SPV_OC_OpTypeImage : I32EnumAttrCase<"OpTypeImage", 25>;
 def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>;
 def SPV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
 def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>;
@@ -3315,7 +3316,7 @@
       SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,
       SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
       SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
-      SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix,
+      SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix, SPV_OC_OpTypeImage,
       SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct,
       SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpTypeForwardPointer,
       SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant,
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -157,6 +157,7 @@
   case spirv::Opcode::OpTypeMatrix:
   case spirv::Opcode::OpTypeArray:
   case spirv::Opcode::OpTypeFunction:
+  case spirv::Opcode::OpTypeImage:
   case spirv::Opcode::OpTypeRuntimeArray:
   case spirv::Opcode::OpTypeStruct:
   case spirv::Opcode::OpTypePointer:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -273,6 +273,8 @@
 
   LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
 
+  LogicalResult processImageType(ArrayRef<uint32_t> operands);
+
   LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
 
   LogicalResult processStructType(ArrayRef<uint32_t> operands);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -713,6 +713,8 @@
     return processCooperativeMatrixType(operands);
   case spirv::Opcode::OpTypeFunction:
     return processFunctionType(operands);
+  case spirv::Opcode::OpTypeImage:
+    return processImageType(operands);
   case spirv::Opcode::OpTypeRuntimeArray:
     return processRuntimeArrayType(operands);
   case spirv::Opcode::OpTypeStruct:
@@ -1004,6 +1006,57 @@
   return success();
 }
 
+// Deserializes OpTypeImage into a spirv::ImageType and records it in typeMap
+// under the result <id>. Only the eight-operand form (result <id>, Sampled
+// Type, Dim, Depth, Arrayed, MS, Sampled, Image Format) is handled.
+LogicalResult
+spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
+  // TODO: Add support for Access Qualifier.
+  if (operands.size() != 8)
+    return emitError(
+        unknownLoc,
+        "OpTypeImage with non-eight operands are not supported yet");
+
+  Type elementTy = getType(operands[1]);
+  if (!elementTy)
+    return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
+           << operands[1];
+
+  auto dim = spirv::symbolizeDim(operands[2]);
+  if (!dim)
+    return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
+           << operands[2];
+
+  auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
+  if (!depthInfo)
+    return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
+           << operands[3];
+
+  auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
+  if (!arrayedInfo)
+    return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
+           << operands[4];
+
+  auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
+  if (!samplingInfo)
+    return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
+
+  auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
+  if (!samplerUseInfo)
+    return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
+           << operands[6];
+
+  auto format = spirv::symbolizeImageFormat(operands[7]);
+  if (!format)
+    return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
+           << operands[7];
+
+  typeMap[operands[0]] = spirv::ImageType::get(
+      elementTy, dim.getValue(), depthInfo.getValue(), arrayedInfo.getValue(),
+      samplingInfo.getValue(), samplerUseInfo.getValue(), format.getValue());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Constant
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
--- a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
@@ -1192,6 +1192,24 @@
     return success();
   }
 
+  if (auto imageType = type.dyn_cast<spirv::ImageType>()) {
+    typeEnum = spirv::Opcode::OpTypeImage;
+    uint32_t sampledTypeID = 0;
+    if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
+      return failure();
+
+    // Emit the operands in the order processImageType() reads them back:
+    // Sampled Type, Dim, Depth, Arrayed, MS, Sampled, Image Format.
+    operands.push_back(sampledTypeID);
+    operands.push_back(static_cast<uint32_t>(imageType.getDim()));
+    operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo()));
+    operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo()));
+    operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo()));
+    operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo()));
+    operands.push_back(static_cast<uint32_t>(imageType.getImageFormat()));
+    return success();
+  }
+
   if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
     typeEnum = spirv::Opcode::OpTypeArray;
     uint32_t elementTypeID = 0;
diff --git a/mlir/test/Target/SPIRV/image.mlir b/mlir/test/Target/SPIRV/image.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/SPIRV/image.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
+
+spv.module Logical GLSL450 requires #spv.vce {
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var0 bind(0, 1) : !spv.ptr, UniformConstant>
+
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var1 : !spv.ptr, UniformConstant>
+
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var2 : !spv.ptr, UniformConstant>
+}