Add serialization and deserialization support for OpTypeImage.

Registers the OpTypeImage opcode (25), makes the serializer emit
spirv::ImageType as OpTypeImage, and makes the deserializer rebuild the type
from the eight operand words, returning an error for any word that does not
name a known enumerant instead of unwrapping an empty Optional.

NOTE(review): line breaks and indentation in this patch were reconstructed
from the hunk line counts, and C++ template arguments that had been stripped
were restored. Column alignment of context lines may differ from the tree, so
apply with whitespace tolerance. The type parameters in
mlir/test/Target/SPIRV/image.mlir (on `!spv.ptr` and `#spv.vce`) are missing
from this copy and must be restored before the test can run.

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3157,6 +3157,7 @@
 def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
 def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>;
 def SPV_OC_OpTypeMatrix : I32EnumAttrCase<"OpTypeMatrix", 24>;
+def SPV_OC_OpTypeImage : I32EnumAttrCase<"OpTypeImage", 25>;
 def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>;
 def SPV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
 def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>;
@@ -3315,7 +3316,7 @@
       SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,
       SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
       SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
-      SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix,
+      SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix, SPV_OC_OpTypeImage,
       SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct,
       SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpTypeForwardPointer,
       SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant,
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -157,6 +157,7 @@
   case spirv::Opcode::OpTypeMatrix:
   case spirv::Opcode::OpTypeArray:
   case spirv::Opcode::OpTypeFunction:
+  case spirv::Opcode::OpTypeImage:
   case spirv::Opcode::OpTypeRuntimeArray:
   case spirv::Opcode::OpTypeStruct:
   case spirv::Opcode::OpTypePointer:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -273,6 +273,8 @@
 
   LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
 
+  LogicalResult processImageType(ArrayRef<uint32_t> operands);
+
   LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
 
   LogicalResult processStructType(ArrayRef<uint32_t> operands);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -713,6 +713,8 @@
     return processCooperativeMatrixType(operands);
   case spirv::Opcode::OpTypeFunction:
     return processFunctionType(operands);
+  case spirv::Opcode::OpTypeImage:
+    return processImageType(operands);
   case spirv::Opcode::OpTypeRuntimeArray:
     return processRuntimeArrayType(operands);
   case spirv::Opcode::OpTypeStruct:
@@ -1004,6 +1006,64 @@
   return success();
 }
 
+LogicalResult
+spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
+  // TODO: Add support for Access Qualifier.
+  // Operand layout: result <id>, sampled type <id>, Dim, Depth, Arrayed, MS,
+  // Sampled, Image Format.
+  if (operands.size() != 8)
+    return emitError(unknownLoc, "OpTypeImage must have eight operands");
+
+  Type elementTy = getType(operands[1]);
+  if (!elementTy)
+    return emitError(unknownLoc, "OpTypeImage references undefined <id> ")
+           << operands[1];
+
+  // Each symbolize* call yields an Optional; report a word that does not map
+  // to a known enumerant instead of unwrapping an empty value.
+  auto dim = spirv::symbolizeDim(operands[2]);
+  if (!dim)
+    return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
+           << operands[2];
+
+  auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
+  if (!depthInfo)
+    return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
+           << operands[3];
+
+  auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
+  if (!arrayedInfo)
+    return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
+           << operands[4];
+
+  auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
+  if (!samplingInfo)
+    return emitError(unknownLoc, "unknown MS for OpTypeImage: ")
+           << operands[5];
+
+  auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
+  if (!samplerUseInfo)
+    return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
+           << operands[6];
+
+  auto format = spirv::symbolizeImageFormat(operands[7]);
+  if (!format)
+    return emitError(unknownLoc, "unknown Image Format for OpTypeImage: ")
+           << operands[7];
+
+  if (dim.getValue() == spirv::Dim::SubpassData &&
+      (samplerUseInfo.getValue() != spirv::ImageSamplerUseInfo::NoSampler ||
+       format.getValue() != spirv::ImageFormat::Unknown))
+    return emitError(unknownLoc,
+                     "OpTypeImage with Dim: SubpassData must have "
+                     "Sampled: NoSampler and ImageFormat: Unknown");
+
+  typeMap[operands[0]] = spirv::ImageType::get(
+      elementTy, dim.getValue(), depthInfo.getValue(), arrayedInfo.getValue(),
+      samplingInfo.getValue(), samplerUseInfo.getValue(), format.getValue());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Constant
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
--- a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
@@ -1192,6 +1192,22 @@
     return success();
   }
 
+  if (auto imageType = type.dyn_cast<spirv::ImageType>()) {
+    typeEnum = spirv::Opcode::OpTypeImage;
+    uint32_t sampledTypeID = 0;
+    if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
+      return failure();
+
+    operands.push_back(sampledTypeID);
+    operands.push_back(static_cast<uint32_t>(imageType.getDim()));
+    operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo()));
+    operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo()));
+    operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo()));
+    operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo()));
+    operands.push_back(static_cast<uint32_t>(imageType.getImageFormat()));
+    return success();
+  }
+
   if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
     typeEnum = spirv::Opcode::OpTypeArray;
     uint32_t elementTypeID = 0;
diff --git a/mlir/test/Target/SPIRV/image.mlir b/mlir/test/Target/SPIRV/image.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/SPIRV/image.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
+
+spv.module Logical GLSL450 requires #spv.vce {
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var0 bind(0, 1) : !spv.ptr, UniformConstant>
+
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var1 bind(0, 2) : !spv.ptr, UniformConstant>
+
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var2 : !spv.ptr, UniformConstant>
+
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var3 : !spv.ptr, UniformConstant>
+
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var4 : !spv.ptr, UniformConstant>
+
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var5 : !spv.ptr, UniformConstant>
+
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var6 : !spv.ptr, UniformConstant>
+
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var7 : !spv.ptr, UniformConstant>
+
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var8 : !spv.ptr, UniformConstant>
+
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var9 : !spv.ptr, UniformConstant>
+
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var10 : !spv.ptr, UniformConstant>
+
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var11 : !spv.ptr, UniformConstant>
+
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var12 : !spv.ptr, UniformConstant>
+
+  // CHECK: !spv.ptr, UniformConstant>
+  spv.globalVariable @var13 : !spv.ptr, UniformConstant>
+}