diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -274,6 +274,7 @@
              | image-type
              | pointer-type
              | runtime-array-type
+             | sampled-image-type
              | struct-type
 ```
 
@@ -363,6 +364,22 @@
 !spv.rtarray<i32>
 !spv.rtarray<vector<4 x f32>>
 ```
+### Sampled image type
+
+This corresponds to SPIR-V [sampled image type][SampledImageType]. Its syntax is
+
+```
+sampled-image-type ::= `!spv.sampled_image<!spv.image<` element-type `,` dim `,` depth-info `,`
+                                                        arrayed-info `,` sampling-info `,`
+                                                        sampler-use-info `,` format `>>`
+```
+
+For example,
+
+```mlir
+!spv.sampled_image<!spv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>
+!spv.sampled_image<!spv.image<i32, Rect, DepthUnknown, Arrayed, MultiSampled, NeedSampler, R8ui>>
+```
 
 ### Struct type
 
@@ -1382,6 +1399,7 @@
 [ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage
 [PointerType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypePointer
 [RuntimeArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeRuntimeArray
+[SampledImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeSampledImage
 [MlirDialectConversion]: ../DialectConversion.md
 [StructType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Structure
 [SpirvTools]: https://github.com/KhronosGroup/SPIRV-Tools
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3158,6 +3158,7 @@
 def SPV_OC_OpTypeVector                : I32EnumAttrCase<"OpTypeVector", 23>;
 def SPV_OC_OpTypeMatrix                : I32EnumAttrCase<"OpTypeMatrix", 24>;
 def SPV_OC_OpTypeImage                 : I32EnumAttrCase<"OpTypeImage", 25>;
+def SPV_OC_OpTypeSampledImage          : I32EnumAttrCase<"OpTypeSampledImage", 27>;
 def SPV_OC_OpTypeArray                 : I32EnumAttrCase<"OpTypeArray", 28>;
 def SPV_OC_OpTypeRuntimeArray          : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
 def SPV_OC_OpTypeStruct                : I32EnumAttrCase<"OpTypeStruct", 30>;
@@ -3317,18 +3318,19 @@
       SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,
       SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
       SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
-      SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix, SPV_OC_OpTypeImage,
-      SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct,
-      SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpTypeForwardPointer,
-      SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant,
-      SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue,
-      SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
-      SPV_OC_OpSpecConstantComposite, SPV_OC_OpSpecConstantOp, SPV_OC_OpFunction,
-      SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
-      SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory,
-      SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
-      SPV_OC_OpVectorExtractDynamic, SPV_OC_OpVectorInsertDynamic,
-      SPV_OC_OpVectorShuffle, SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
+      SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix,
+      SPV_OC_OpTypeImage, SPV_OC_OpTypeSampledImage, SPV_OC_OpTypeArray,
+      SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer,
+      SPV_OC_OpTypeFunction, SPV_OC_OpTypeForwardPointer, SPV_OC_OpConstantTrue,
+      SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
+      SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
+      SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpSpecConstantOp,
+      SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd,
+      SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
+      SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
+      SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic,
+      SPV_OC_OpVectorInsertDynamic, SPV_OC_OpVectorShuffle,
+      SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
       SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU,
       SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
       SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -32,6 +32,7 @@
 struct MatrixTypeStorage;
 struct PointerTypeStorage;
 struct RuntimeArrayTypeStorage;
+struct SampledImageTypeStorage;
 struct StructTypeStorage;
 } // namespace detail
 
@@ -233,6 +234,28 @@
                        Optional<StorageClass> storage = llvm::None);
 };
 
+/// SPIR-V sampled image type (OpTypeSampledImage) wrapping an ImageType.
+class SampledImageType
+    : public Type::TypeBase<SampledImageType, SPIRVType,
+                            detail::SampledImageTypeStorage> {
+public:
+  using Base::Base;
+
+  static SampledImageType get(Type imageType);
+
+  static SampledImageType getChecked(Type imageType, Location location);
+
+  static LogicalResult verifyConstructionInvariants(Location loc,
+                                                    Type imageType);
+
+  Type getImageType() const;
+
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     Optional<StorageClass> storage = llvm::None);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       Optional<StorageClass> storage = llvm::None);
+};
+
 /// SPIR-V struct type. Two kinds of struct types are supported:
 /// - Literal: a literal struct type is uniqued by its fields (types + offset
 /// info + decoration info).
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -116,7 +116,7 @@
 
 void SPIRVDialect::initialize() {
   addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
-           PointerType, RuntimeArrayType, StructType>();
+           PointerType, RuntimeArrayType, SampledImageType, StructType>();
 
   addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
 
@@ -232,6 +232,23 @@
   return type;
 }
 
+static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
+                                           DialectAsmParser &parser) {
+  Type type;
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  if (parser.parseType(type))
+    return Type();
+
+  if (!type.isa<ImageType>()) {
+    parser.emitError(typeLoc,
+                     "sampled image must be composed using image type, got ")
+        << type;
+    return Type();
+  }
+
+  return type;
+}
+
 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
 /// missing.
@@ -530,6 +547,21 @@
   return ImageType::get(value.getValue());
 }
 
+// sampled-image-type ::= `!spv.sampled_image<` image-type `>`
+static Type parseSampledImageType(SPIRVDialect const &dialect,
+                                  DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return Type();
+
+  Type parsedType = parseAndVerifySampledImageType(dialect, parser);
+  if (!parsedType)
+    return Type();
+
+  if (parser.parseGreater())
+    return Type();
+  return SampledImageType::get(parsedType);
+}
+
 // Parse decorations associated with a member.
 static ParseResult parseStructMemberDecorations(
     SPIRVDialect const &dialect, DialectAsmParser &parser,
@@ -707,6 +739,7 @@
 //              | image-type
 //              | pointer-type
 //              | runtime-array-type
+//              | sampled-image-type
 //              | struct-type
 Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
   StringRef keyword;
@@ -723,6 +756,8 @@
     return parsePointerType(*this, parser);
   if (keyword == "rtarray")
     return parseRuntimeArrayType(*this, parser);
+  if (keyword == "sampled_image")
+    return parseSampledImageType(*this, parser);
   if (keyword == "struct")
     return parseStructType(*this, parser);
   if (keyword == "matrix")
@@ -763,6 +798,10 @@
      << stringifyImageFormat(type.getImageFormat()) << ">";
 }
 
+static void print(SampledImageType type, DialectAsmPrinter &os) {
+  os << "sampled_image<" << type.getImageType() << ">";
+}
+
 static void print(StructType type, DialectAsmPrinter &os) {
   thread_local llvm::SetVector<StringRef> structContext;
 
@@ -825,7 +864,7 @@
 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
       .Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType,
-            ImageType, StructType, MatrixType>(
+            ImageType, SampledImageType, StructType, MatrixType>(
           [&](auto type) { print(type, os); })
       .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -668,6 +668,8 @@
     compositeType.getExtensions(extensions, storage);
   } else if (auto imageType = dyn_cast<ImageType>()) {
     imageType.getExtensions(extensions, storage);
+  } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
+    sampledImageType.getExtensions(extensions, storage);
   } else if (auto matrixType = dyn_cast<MatrixType>()) {
     matrixType.getExtensions(extensions, storage);
   } else if (auto ptrType = dyn_cast<PointerType>()) {
@@ -686,6 +688,8 @@
     compositeType.getCapabilities(capabilities, storage);
   } else if (auto imageType = dyn_cast<ImageType>()) {
     imageType.getCapabilities(capabilities, storage);
+  } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
+    sampledImageType.getCapabilities(capabilities, storage);
   } else if (auto matrixType = dyn_cast<MatrixType>()) {
     matrixType.getCapabilities(capabilities, storage);
   } else if (auto ptrType = dyn_cast<PointerType>()) {
@@ -703,6 +707,56 @@
   return llvm::None;
 }
 
+//===----------------------------------------------------------------------===//
+// SampledImageType
+//===----------------------------------------------------------------------===//
+struct spirv::detail::SampledImageTypeStorage : public TypeStorage {
+  using KeyTy = Type;
+
+  SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
+
+  bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
+
+  static SampledImageTypeStorage *construct(TypeStorageAllocator &allocator,
+                                            const KeyTy &key) {
+    return new (allocator.allocate<SampledImageTypeStorage>())
+        SampledImageTypeStorage(key);
+  }
+
+  Type imageType;
+};
+
+SampledImageType SampledImageType::get(Type imageType) {
+  return Base::get(imageType.getContext(), imageType);
+}
+
+SampledImageType SampledImageType::getChecked(Type imageType,
+                                              Location location) {
+  return Base::getChecked(location, imageType);
+}
+
+Type SampledImageType::getImageType() const { return getImpl()->imageType; }
+
+LogicalResult SampledImageType::verifyConstructionInvariants(Location loc,
+                                                             Type imageType) {
+  if (!imageType.isa<ImageType>())
+    return emitError(loc, "expected image type");
+
+  return success();
+}
+
+void SampledImageType::getExtensions(
+    SPIRVType::ExtensionArrayRefVector &extensions,
+    Optional<StorageClass> storage) {
+  getImageType().cast<ImageType>().getExtensions(extensions, storage);
+}
+
+void SampledImageType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    Optional<StorageClass> storage) {
+  getImageType().cast<ImageType>().getCapabilities(capabilities, storage);
+}
+
 //===----------------------------------------------------------------------===//
 // StructType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -158,6 +158,7 @@
   case spirv::Opcode::OpTypeArray:
   case spirv::Opcode::OpTypeFunction:
   case spirv::Opcode::OpTypeImage:
+  case spirv::Opcode::OpTypeSampledImage:
   case spirv::Opcode::OpTypeRuntimeArray:
   case spirv::Opcode::OpTypeStruct:
   case spirv::Opcode::OpTypePointer:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -275,6 +275,8 @@
 
   LogicalResult processImageType(ArrayRef<uint32_t> operands);
 
+  LogicalResult processSampledImageType(ArrayRef<uint32_t> operands);
+
   LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
 
   LogicalResult processStructType(ArrayRef<uint32_t> operands);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -715,6 +715,8 @@
     return processFunctionType(operands);
   case spirv::Opcode::OpTypeImage:
     return processImageType(operands);
+  case spirv::Opcode::OpTypeSampledImage:
+    return processSampledImageType(operands);
   case spirv::Opcode::OpTypeRuntimeArray:
     return processRuntimeArrayType(operands);
   case spirv::Opcode::OpTypeStruct:
@@ -1054,6 +1056,29 @@
   return success();
 }
 
+LogicalResult
+spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
+  if (operands.size() != 2)
+    return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
+
+  Type imageTy = getType(operands[1]);
+  if (!imageTy)
+    return emitError(unknownLoc,
+                     "OpTypeSampledImage references undefined <id>: ")
+           << operands[1];
+
+  // The operand comes from an untrusted binary, so build the type through
+  // getChecked: a non-image operand is then reported as an error instead of
+  // violating SampledImageType's construction invariants.
+  auto sampledImageTy =
+      spirv::SampledImageType::getChecked(imageTy, unknownLoc);
+  if (!sampledImageTy)
+    return failure();
+
+  typeMap[operands[0]] = sampledImageTy;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Constant
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -511,6 +511,17 @@
     return processTypeDecoration(loc, runtimeArrayType, resultID);
   }
 
+  if (auto sampledImageType = type.dyn_cast<spirv::SampledImageType>()) {
+    typeEnum = spirv::Opcode::OpTypeSampledImage;
+    uint32_t imageTypeID = 0;
+    if (failed(
+            processType(loc, sampledImageType.getImageType(), imageTypeID))) {
+      return failure();
+    }
+    operands.push_back(imageTypeID);
+    return success();
+  }
+
   if (auto structType = type.dyn_cast<spirv::StructType>()) {
     if (structType.isIdentified()) {
       (void)processName(resultID, structType.getIdentifier());
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -226,6 +226,20 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// SampledImageType
+//===----------------------------------------------------------------------===//
+
+// CHECK: func private @sampled_image_type(!spv.sampled_image<!spv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>)
+func private @sampled_image_type(!spv.sampled_image<!spv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>) -> ()
+
+// -----
+
+// expected-error @+1 {{sampled image must be composed using image type, got 'f32'}}
+func private @sampled_image_type_invalid_type(!spv.sampled_image<f32>) -> ()
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // StructType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/sampled-image.mlir b/mlir/test/Target/SPIRV/sampled-image.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/SPIRV/sampled-image.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  // CHECK: !spv.ptr<!spv.sampled_image<!spv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>, UniformConstant>
+  spv.globalVariable @var0 bind(0, 1) : !spv.ptr<!spv.sampled_image<!spv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>, UniformConstant>
+
+  // CHECK: !spv.ptr<!spv.sampled_image<!spv.image<i32, Cube, IsDepth, NonArrayed, SingleSampled, NeedSampler, R8ui>>, UniformConstant>
+  spv.globalVariable @var1 bind(0, 0) : !spv.ptr<!spv.sampled_image<!spv.image<i32, Cube, IsDepth, NonArrayed, SingleSampled, NeedSampler, R8ui>>, UniformConstant>
+
+  // CHECK: !spv.ptr<!spv.sampled_image<!spv.image<f32, Rect, DepthUnknown, Arrayed, MultiSampled, SamplerUnknown, Rgba32f>>, UniformConstant>
+  spv.globalVariable @var2 bind(0, 2) : !spv.ptr<!spv.sampled_image<!spv.image<f32, Rect, DepthUnknown, Arrayed, MultiSampled, SamplerUnknown, Rgba32f>>, UniformConstant>
+}