diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3155,6 +3155,7 @@ def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>; def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>; def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>; +def SPV_OC_OpTypeForwardPointer : I32EnumAttrCase<"OpTypeForwardPointer", 39>; def SPV_OC_OpConstantTrue : I32EnumAttrCase<"OpConstantTrue", 41>; def SPV_OC_OpConstantFalse : I32EnumAttrCase<"OpConstantFalse", 42>; def SPV_OC_OpConstant : I32EnumAttrCase<"OpConstant", 43>; @@ -3302,21 +3303,21 @@ SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix, SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, - SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, - SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite, - SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, - SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, - SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, - SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory, - SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, - SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract, - SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU, - SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, - SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, - SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, - SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, - SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, - SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar, 
+ SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpTypeForwardPointer, + SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant, + SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, + SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, + SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, + SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, + SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, + SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct, + SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, + SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, + SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, + SPV_OC_OpBitcast, SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, + SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, + SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, + SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -262,7 +262,24 @@ Optional storage = llvm::None); }; -// SPIR-V struct type +// SPIR-V struct type. Two kinds of struct types are supported: +// - Literal: a literal struct type is uniqued by its fields (types + offset +// info + decoration info). +// - Identified: an identified struct type is uniqued by its string identifier +// (name). This is useful in representing recursive structs.
 For example, the +// following C struct: +// +// struct A { +// A* next; +// }; +// +// would be represented in MLIR as: +// +// !spv.struct<A, (!spv.ptr<!spv.struct<A>, Generic>)> +// +// In the above, expressing recursive struct types is accomplished by giving a +// recursive struct a unique identifier and using that identifier in the struct +// definition for recursive references. class StructType : public Type::TypeBase { public: @@ -297,13 +314,33 @@ } }; - /// Construct a StructType with at least one member. + /// Construct a literal StructType with at least one member. static StructType get(ArrayRef memberTypes, ArrayRef offsetInfo = {}, ArrayRef memberDecorations = {}); - /// Construct a struct with no members. - static StructType getEmpty(MLIRContext *context); + /// Construct an identified StructType. This creates a StructType whose body + /// (member types, offset info, and decorations) is not set yet. A call to + /// StructType::trySetBody(...) must follow when the StructType contents are + /// available (e.g. parsed or deserialized). + /// + /// Note: identified structs are uniqued by identifier; if a struct with the + /// same identifier already exists (e.g. created by another thread), that + /// struct is returned as a result. + static StructType getIdentified(MLIRContext *context, StringRef identifier); + + /// Construct a (possibly identified) StructType with no members. + /// + /// Note: for an identified struct this fails if a struct with the same + /// identifier was already given a different (non-empty) body, e.g. one + /// created concurrently by another thread. In that case, an empty + /// (default-constructed) StructType is returned. + static StructType getEmpty(MLIRContext *context, StringRef identifier = ""); + + /// For literal structs, return an empty string. + /// For identified structs, return the struct's identifier.
+ StringRef getIdentifier() const; + + /// Returns true for identified structs, i.e. structs with a non-empty + /// identifier. + bool isIdentified() const; unsigned getNumElements() const; @@ -346,6 +383,13 @@ SmallVectorImpl &decorationsInfo) const; + /// Sets the contents of an incomplete identified StructType. This method must + /// be called only for identified StructTypes. The body can be set only once: + /// calling it again succeeds only if the new contents are identical to the + /// existing ones. Otherwise (or for literal structs), failure() is returned. + LogicalResult + trySetBody(ArrayRef memberTypes, ArrayRef offsetInfo = {}, + ArrayRef memberDecorations = {}); + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional storage = llvm::None); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, diff --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp --- a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp +++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp @@ -67,7 +67,13 @@ size = llvm::alignTo(structMemberOffset, maxMemberAlignment); alignment = maxMemberAlignment; structType.getMemberDecorations(memberDecorations); - return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations); + + if (!structType.isIdentified()) + return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations); + else + // Identified structs are uniqued by identifier so it is not possible + // to create 2 structs with the same name but different decorations. + // Hence no decorated struct is created and a null type is returned. + return nullptr; } Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size, diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -23,6 +23,7 @@ #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSwitch.h" @@ -589,15 +590,80 @@ } // struct-member-decoration ::= integer-literal?
spirv-decoration* -// struct-type ::= `!spv.struct<` spirv-type (`[` struct-member-decoration `]`)? -// (`, ` spirv-type (`[` struct-member-decoration `]`)? `>` +// struct-type ::= +// `!spv.struct<` (id `,`)? +// `(` +// (spirv-type (`[` struct-member-decoration `]`)?)* +// `)>` static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser) { + // TODO: This function is quite lengthy. Break it down into smaller chunks. + + // To properly resolve recursive references while parsing recursive struct + // types, we need to maintain a list of enclosing struct type names. This set + // maintains the names of struct types in which the type we are about to parse + // is nested. + // + // Note: This has to be thread_local to enable multiple threads to safely + // parse concurrently. + thread_local llvm::SetVector structContext; + + static auto removeIdentifierAndFail = + [](llvm::SetVector &structContext, StringRef identifier) { + if (!identifier.empty()) + structContext.remove(identifier); + + return Type(); + }; + if (parser.parseLess()) return Type(); - if (succeeded(parser.parseOptionalGreater())) - return StructType::getEmpty(dialect.getContext()); + StringRef identifier; + + // Check if this is an idenitifed struct type. + if (succeeded(parser.parseOptionalKeyword(&identifier))) { + // Check if this is a possible recursive reference. 
+ if (succeeded(parser.parseOptionalGreater())) { + if (structContext.count(identifier) == 0) { + parser.emitError( + parser.getNameLoc(), + "recursive struct reference not nested in struct definition"); + + return Type(); + } + + return StructType::getIdentified(dialect.getContext(), identifier); + } + + if (failed(parser.parseComma())) + return Type(); + + if (structContext.count(identifier) != 0) { + parser.emitError(parser.getNameLoc(), + "identifier already used for an enclosing struct"); + + return removeIdentifierAndFail(structContext, identifier); + } + + structContext.insert(identifier); + } + + if (failed(parser.parseLParen())) + return removeIdentifierAndFail(structContext, identifier); + + if (succeeded(parser.parseOptionalRParen()) && + succeeded(parser.parseOptionalGreater())) { + if (!identifier.empty()) + structContext.remove(identifier); + + return StructType::getEmpty(dialect.getContext(), identifier); + } + + StructType idStructTy; + + if (!identifier.empty()) + idStructTy = StructType::getIdentified(dialect.getContext(), identifier); SmallVector memberTypes; SmallVector offsetInfo; @@ -606,24 +672,33 @@ do { Type memberType; if (parser.parseType(memberType)) - return Type(); + return removeIdentifierAndFail(structContext, identifier); memberTypes.push_back(memberType); - if (succeeded(parser.parseOptionalLSquare())) { + if (succeeded(parser.parseOptionalLSquare())) if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo, - memberDecorationInfo)) { - return Type(); - } - } + memberDecorationInfo)) + return removeIdentifierAndFail(structContext, identifier); } while (succeeded(parser.parseOptionalComma())); if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) { parser.emitError(parser.getNameLoc(), "offset specification must be given for all members"); - return Type(); + return removeIdentifierAndFail(structContext, identifier); } - if (parser.parseGreater()) - return Type(); + + if (failed(parser.parseRParen()) || 
failed(parser.parseGreater())) + return removeIdentifierAndFail(structContext, identifier); + + if (!identifier.empty()) { + if (failed(idStructTy.trySetBody(memberTypes, offsetInfo, + memberDecorationInfo))) + return Type(); + + structContext.remove(identifier); + return idStructTy; + } + return StructType::get(memberTypes, offsetInfo, memberDecorationInfo); } @@ -689,7 +764,24 @@ } static void print(StructType type, DialectAsmPrinter &os) { + thread_local llvm::SetVector structContext; + os << "struct<"; + + if (type.isIdentified()) { + os << type.getIdentifier(); + + if (structContext.count(type.getIdentifier()) == 0) { + os << ", "; + structContext.insert(type.getIdentifier()); + } else { + os << ">"; + return; + } + } + + os << "("; + auto printMember = [&](unsigned i) { os << type.getElementType(i); SmallVector decorations; @@ -713,7 +805,10 @@ }; llvm::interleaveComma(llvm::seq(0, type.getNumElements()), os, printMember); - os << ">"; + os << ")>"; + + if (type.isIdentified()) + structContext.remove(type.getIdentifier()); } static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) { diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -759,25 +759,92 @@ // StructType //===----------------------------------------------------------------------===// +/// Type storage for SPIR-V structure types: +/// +/// Structures are uniqued using: +/// - for identified structs: +/// - a string identifier; +/// - for literal structs: +/// - a list of member types; +/// - a list of member offset info; +/// - a list of member decoration info. +/// +/// Identified structures only have a mutable component consisting of: +/// - a list of member types; +/// - a list of member offset info; +/// - a list of member decoration info. 
struct spirv::detail::StructTypeStorage : public TypeStorage { + /// Construct a storage object for an identified struct type. A struct type + /// associated with such storage must call StructType::trySetBody(...) later + /// in order to mutate the storage object providing the actual content. + /// Until then the struct reports zero members. + StructTypeStorage(StringRef identifier) + : memberTypes(nullptr), offsetInfo(nullptr), numMembers(0), + numMemberDecorations(0), memberDecorationsInfo(nullptr), + identifier(identifier), isBodySet(false) {} + + /// Construct a storage object for a literal struct type. A struct type + /// associated with such storage is immutable. StructTypeStorage( unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo) : memberTypes(memberTypes), offsetInfo(layoutInfo), numMembers(numMembers), numMemberDecorations(numMemberDecorations), - memberDecorationsInfo(memberDecorationsInfo) {} - - using KeyTy = std::tuple, ArrayRef, - ArrayRef>; + memberDecorationsInfo(memberDecorationsInfo), identifier(StringRef()), + isBodySet(false) {} + + /// A storage key is divided into 2 parts: + /// - for identified structs: + /// - a StringRef representing the struct identifier; + /// - for literal structs: + /// - an ArrayRef for member types; + /// - an ArrayRef for member offset info; + /// - an ArrayRef for member decoration + /// info. + /// + /// An identified struct type is uniqued only by the first part (field 0) + /// of the key. + /// + /// A literal struct type is uniqued only by the second part (fields 1, 2, and + /// 3) of the key. The identifier field (field 0) must be empty. + using KeyTy = + std::tuple, ArrayRef, + ArrayRef>; + + /// For identified structs, return true if the given key contains the same + /// identifier. + /// + /// For literal structs, return true if the given key contains a matching list + /// of member types + offset info + decoration info.
bool operator==(const KeyTy &key) const { - return key == - KeyTy(getMemberTypes(), getOffsetInfo(), getMemberDecorationsInfo()); + if (isIdentified()) + // Identified types are uniqued by their identifier. + return getIdentifier() == std::get<0>(key); + else + return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(), + getMemberDecorationsInfo()); } + /// If the given key contains a non-empty identifier, this method constructs + /// an identified struct and leaves the rest of the struct type data to be set + /// through a later call to StructType::trySetBody(...). + /// + /// If, on the other hand, the key contains an empty identifier, a literal + /// struct is constructed using the other fields of the key. static StructTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { - ArrayRef keyTypes = std::get<0>(key); + StringRef keyIdentifier = std::get<0>(key); + + if (!keyIdentifier.empty()) { + StringRef identifier = allocator.copyInto(keyIdentifier); + + // Identified StructType body/members will be set through trySetBody(...) + // later.
+ return new (allocator.allocate()) + StructTypeStorage(identifier); + } + + ArrayRef keyTypes = std::get<1>(key); // Copy the member type and layout information into the bump pointer const Type *typesList = nullptr; @@ -786,8 +853,8 @@ } const StructType::OffsetInfo *offsetInfoList = nullptr; - if (!std::get<1>(key).empty()) { - ArrayRef keyOffsetInfo = std::get<1>(key); + if (!std::get<2>(key).empty()) { + ArrayRef keyOffsetInfo = std::get<2>(key); assert(keyOffsetInfo.size() == keyTypes.size() && "size of offset information must be same as the size of number of " "elements"); @@ -796,11 +863,12 @@ const StructType::MemberDecorationInfo *memberDecorationList = nullptr; unsigned numMemberDecorations = 0; - if (!std::get<2>(key).empty()) { - auto keyMemberDecorations = std::get<2>(key); + if (!std::get<3>(key).empty()) { + auto keyMemberDecorations = std::get<3>(key); numMemberDecorations = keyMemberDecorations.size(); memberDecorationList = allocator.copyInto(keyMemberDecorations).data(); } + return new (allocator.allocate()) StructTypeStorage(keyTypes.size(), typesList, offsetInfoList, numMemberDecorations, memberDecorationList); @@ -825,11 +893,61 @@ return {}; } + StringRef getIdentifier() const { return identifier; } + + bool isIdentified() const { return !identifier.empty(); } + + /// Sets the struct type content for identified structs. Calling this method + /// is only valid for identified structs. + /// + /// Fails under the following conditions: + /// - If called for a literal struct; + /// - If called for an identified struct whose body was set before (through a + /// call to this method) but with different contents from the passed + /// arguments.
+ LogicalResult + mutate(TypeStorageAllocator &allocator, ArrayRef memberTypes, + ArrayRef offsetInfo, + ArrayRef memberDecorationsInfo) { + if (!isIdentified()) + return failure(); + + if (isBodySet) { + // The body can be set only once. Re-setting it with identical contents is + // a no-op (nothing is copied into the allocator again); anything else is a + // conflict. + return success(getMemberTypes() == memberTypes && + getOffsetInfo() == offsetInfo && + getMemberDecorationsInfo() == memberDecorationsInfo); + } + + isBodySet = true; + numMembers = memberTypes.size(); + + // Copy the member type and layout information into the bump pointer. + if (!memberTypes.empty()) + this->memberTypes = allocator.copyInto(memberTypes).data(); + + if (!offsetInfo.empty()) { + assert(offsetInfo.size() == memberTypes.size() && + "size of offset information must be same as the size of number of " + "elements"); + this->offsetInfo = allocator.copyInto(offsetInfo).data(); + } + + if (!memberDecorationsInfo.empty()) { + this->numMemberDecorations = memberDecorationsInfo.size(); + this->memberDecorationsInfo = + allocator.copyInto(memberDecorationsInfo).data(); + } + + return success(); + } + Type const *memberTypes; StructType::OffsetInfo const *offsetInfo; unsigned numMembers; unsigned numMemberDecorations; StructType::MemberDecorationInfo const *memberDecorationsInfo; + + StringRef identifier; + bool isBodySet; }; StructType @@ -841,16 +959,39 @@ SmallVector sortedDecorations( memberDecorations.begin(), memberDecorations.end()); llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end()); - return Base::get(memberTypes.vec().front().getContext(), memberTypes, - offsetInfo, sortedDecorations); + return Base::get(memberTypes.front().getContext(), + /*identifier=*/StringRef(), memberTypes, offsetInfo, + sortedDecorations); } -StructType StructType::getEmpty(MLIRContext *context) { - return Base::get(context, ArrayRef(), +StructType StructType::getIdentified(MLIRContext *context, + StringRef identifier) { + assert(!identifier.empty() && + "StructType identifier must be non-empty string"); + + return Base::get(context, identifier,
 ArrayRef(), ArrayRef(), ArrayRef()); } +StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) { + StructType newStructType = Base::get( + context, identifier, ArrayRef(), ArrayRef(), + ArrayRef()); + // Set an empty body in case this is an identified struct. + if (newStructType.isIdentified() && + failed(newStructType.trySetBody( + ArrayRef(), ArrayRef(), + ArrayRef()))) + return StructType(); + + return newStructType; +} + +StringRef StructType::getIdentifier() const { return getImpl()->identifier; } + +bool StructType::isIdentified() const { return getImpl()->isIdentified(); } + unsigned StructType::getNumElements() const { return getImpl()->numMembers; } Type StructType::getElementType(unsigned index) const { @@ -895,6 +1036,13 @@ } } +LogicalResult +StructType::trySetBody(ArrayRef memberTypes, + ArrayRef offsetInfo, + ArrayRef memberDecorations) { + return Base::mutate(memberTypes, offsetInfo, memberDecorations); +} + void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional storage) { for (Type elementType : getElementTypes()) diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -87,6 +87,43 @@ /// Map from a selection/loop's header block to its merge (and continue) target. using BlockMergeInfoMap = DenseMap; +/// A "deferred struct type" is a struct type with one or more member types not +/// known when the Deserializer first encounters the struct. This happens, for +/// example, with recursive structs where a pointer to the struct type is +/// forward declared through OpTypeForwardPointer in the SPIR-V module before +/// the struct declaration; the actual pointer to struct type should be defined +/// later through an OpTypePointer.
For example, the following C struct: +/// +/// struct A { +/// A* next; +/// }; +/// +/// would be represented in the SPIR-V module as: +/// +/// OpName %A "A" +/// OpTypeForwardPointer %APtr Generic +/// %A = OpTypeStruct %APtr +/// %APtr = OpTypePointer Generic %A +/// +/// This means that the spirv::StructType cannot be fully constructed directly +/// when the Deserializer encounters it. Instead we create a +/// DeferredStructTypeInfo that contains all the information we know about the +/// spirv::StructType. Once all forward references for the struct are resolved, +/// the struct's body is set with all member info. +struct DeferredStructTypeInfo { + spirv::StructType deferredStructType; + + // A list of all unresolved member types for the struct. First element of each + // item is operand ID, second element is member index in the struct. + SmallVector, 0> unresolvedMemberTypes; + + // The list of member types. For unresolved members, this list contains + // place-holder empty types that will be updated later. + SmallVector memberTypes; + SmallVector offsetInfo; + SmallVector memberDecorationsInfo; +}; + /// A SPIR-V module serializer. /// /// A SPIR-V binary module is a single linear stream of instructions; each @@ -219,6 +256,8 @@ /// registers the type into `module`. LogicalResult processType(spirv::Opcode opcode, ArrayRef operands); + LogicalResult processOpTypePointer(ArrayRef operands); + LogicalResult processArrayType(ArrayRef operands); LogicalResult processCooperativeMatrixType(ArrayRef operands); @@ -382,6 +421,8 @@ /// insertion point. LogicalResult processUndef(ArrayRef operands); + LogicalResult processTypeForwardPointer(ArrayRef operands); + /// Method to dispatch to the specialized deserialization function for an /// operation in SPIR-V dialect that is a mirror of an instruction in the /// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for @@ -520,6 +561,13 @@ // processed. 
SmallVector>, 4> deferredInstructions; + + // A list of IDs for all types forward-declared through OpTypeForwardPointer + // instructions. + llvm::SetVector typeForwardPointerIDs; + + // A list of all structs which have unresolved member types. + SmallVector deferredStructTypesInfos; }; } // namespace @@ -1157,16 +1205,7 @@ typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy); } break; case spirv::Opcode::OpTypePointer: { - if (operands.size() != 3) { - return emitError(unknownLoc, "OpTypePointer must have two parameters"); - } - auto pointeeType = getType(operands[2]); - if (!pointeeType) { - return emitError(unknownLoc, "unknown OpTypePointer pointee type ") - << operands[2]; - } - auto storageClass = static_cast(operands[1]); - typeMap[operands[0]] = spirv::PointerType::get(pointeeType, storageClass); + return processOpTypePointer(operands); } break; case spirv::Opcode::OpTypeArray: return processArrayType(operands); @@ -1186,6 +1225,60 @@ return success(); } +LogicalResult Deserializer::processOpTypePointer(ArrayRef operands) { + if (operands.size() != 3) + return emitError(unknownLoc, "OpTypePointer must have two parameters"); + + auto pointeeType = getType(operands[2]); + + if (!pointeeType) + return emitError(unknownLoc, "unknown OpTypePointer pointee type ") + << operands[2]; + + uint32_t typePointerID = operands[0]; + auto storageClass = static_cast(operands[1]); + typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass); + + for (auto *deferredStructIt = std::begin(deferredStructTypesInfos); + deferredStructIt != std::end(deferredStructTypesInfos);) { + for (auto *unresolvedMemberIt = + std::begin(deferredStructIt->unresolvedMemberTypes); + unresolvedMemberIt != + std::end(deferredStructIt->unresolvedMemberTypes);) { + if (unresolvedMemberIt->first == typePointerID) { + // The newly constructed pointer type can resolve one of the + // deferred struct type members; update the memberTypes list and + // clean the 
unresolvedMemberTypes list accordingly. + deferredStructIt->memberTypes[unresolvedMemberIt->second] = + typeMap[typePointerID]; + unresolvedMemberIt = + deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt); + } else { + ++unresolvedMemberIt; + } + } + + if (deferredStructIt->unresolvedMemberTypes.empty()) { + // All deferred struct type members are now resolved, set the struct body. + auto structType = deferredStructIt->deferredStructType; + + assert(structType && "expected a spirv::StructType"); + assert(structType.isIdentified() && "expected an indentified struct"); + + if (failed(structType.trySetBody( + deferredStructIt->memberTypes, deferredStructIt->offsetInfo, + deferredStructIt->memberDecorationsInfo))) + return failure(); + + deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt); + } else { + ++deferredStructIt; + } + } + + return success(); +} + LogicalResult Deserializer::processArrayType(ArrayRef operands) { if (operands.size() != 3) { return emitError(unknownLoc, @@ -1289,22 +1382,34 @@ } LogicalResult Deserializer::processStructType(ArrayRef operands) { + // TODO: Find a way to handle identified structs when debug info is stripped. + if (operands.empty()) { return emitError(unknownLoc, "OpTypeStruct must have at least result "); } + if (operands.size() == 1) { // Handle empty struct. - typeMap[operands[0]] = spirv::StructType::getEmpty(context); + typeMap[operands[0]] = + spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str()); return success(); } - SmallVector memberTypes; + // First element is operand ID, second element is member index in the struct. 
+ SmallVector, 0> unresolvedMemberTypes; + SmallVector memberTypes; + for (auto op : llvm::drop_begin(operands, 1)) { Type memberType = getType(op); - if (!memberType) { + bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0); + + if (!memberType && !typeForwardPtr) return emitError(unknownLoc, "OpTypeStruct references undefined ") << op; - } + + if (!memberType) + unresolvedMemberTypes.emplace_back(op, memberTypes.size()); + memberTypes.push_back(memberType); } @@ -1336,8 +1441,28 @@ } } } - typeMap[operands[0]] = - spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); + + uint32_t structID = operands[0]; + std::string structIdentifier = nameMap.lookup(structID).str(); + + if (structIdentifier.empty()) { + assert(unresolvedMemberTypes.empty() && + "didn't expect unresolved member types"); + typeMap[structID] = + spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); + } else { + auto structTy = spirv::StructType::getIdentified(context, structIdentifier); + typeMap[structID] = structTy; + + if (!unresolvedMemberTypes.empty()) + deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes, + memberTypes, offsetInfo, + memberDecorationsInfo}); + else if (failed(structTy.trySetBody(memberTypes, offsetInfo, + memberDecorationsInfo))) + return failure(); + } + // TODO: Update StructType to have member name as attribute as // well. 
return success(); @@ -2343,6 +2468,8 @@ return processPhi(operands); case spirv::Opcode::OpUndef: return processUndef(operands); + case spirv::Opcode::OpTypeForwardPointer: + return processTypeForwardPointer(operands); default: break; } @@ -2361,6 +2488,19 @@ return success(); } +LogicalResult +Deserializer::processTypeForwardPointer(ArrayRef operands) { + if (operands.size() != 2) + return emitError(unknownLoc, + "OpTypeForwardPointer instruction must have two operands"); + + typeForwardPointerIDs.insert(operands[0]); + // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer + // instruction that defines the actual type. + + return success(); +} + LogicalResult Deserializer::processExtInst(ArrayRef operands) { if (operands.size() < 4) { return emitError(unknownLoc, diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -22,6 +22,7 @@ #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" @@ -252,12 +253,16 @@ /// Main dispatch method for serializing a type. The result of the /// serialized type will be returned as `typeID`. LogicalResult processType(Location loc, Type type, uint32_t &typeID); + LogicalResult processTypeImpl(Location loc, Type type, uint32_t &typeID, + llvm::SetVector &serializationCtx); /// Method for preparing basic SPIR-V type serialization. Returns the type's /// opcode and operands for the instruction via `typeEnum` and `operands`. 
LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, - SmallVectorImpl &operands); + SmallVectorImpl &operands, + bool &deferSerialization, + llvm::SetVector &serializationCtx); LogicalResult prepareFunctionType(Location loc, FunctionType type, spirv::Opcode &typeEnum, @@ -424,6 +429,20 @@ SmallVector typesGlobalValues; SmallVector functions; + /// Recursive struct references are serialized as OpTypePointer instructions + /// to the recursive struct type. However, the OpTypePointer instruction + /// cannot be emitted before the recursive struct's OpTypeStruct. + /// RecursiveStructPointerInfo stores the data needed to emit such + /// OpTypePointer instructions after forward references to such types. + struct RecursiveStructPointerInfo { + uint32_t pointerTypeID; + spirv::StorageClass storageClass; + }; + + // Maps spirv::StructType to its recursive reference member info. + DenseMap> + recursiveStructInfos; + /// `functionHeader` contains all the instructions that must be in the first /// block in the function, and `functionBody` contains the rest. After /// processing FuncOp, the encoded instructions of a function are appended to @@ -1013,28 +1032,71 @@ LogicalResult Serializer::processType(Location loc, Type type, uint32_t &typeID) { + // Maintains a set of names for nested identified struct types. This is used + // to properly serialize recursive references.
+ llvm::SetVector serializationCtx; + return processTypeImpl(loc, type, typeID, serializationCtx); +} + +LogicalResult +Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, + llvm::SetVector &serializationCtx) { typeID = getTypeID(type); if (typeID) { return success(); } typeID = getNextID(); SmallVector operands; + operands.push_back(typeID); auto typeEnum = spirv::Opcode::OpTypeVoid; + bool deferSerialization = false; + if ((type.isa() && succeeded(prepareFunctionType(loc, type.cast(), typeEnum, operands))) || - succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands))) { + succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, + deferSerialization, serializationCtx))) { + if (deferSerialization) { + // The OpTypePointer for a recursive reference is emitted later, once the + // enclosing struct has been emitted (see the loop below). + // NOTE(review): the deferred pointer type is not recorded in typeIDMap, + // so a later request for the same pointer type allocates a fresh ID; + // confirm this cannot produce duplicate OpTypePointer declarations. + return success(); + } + typeIDMap[type] = typeID; - return encodeInstructionInto(typesGlobalValues, typeEnum, operands); + + if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands))) + return failure(); + + if (recursiveStructInfos.count(type) != 0) { + // This recursive struct type is emitted already, now the OpTypePointer + // instructions referring to recursive references are emitted as well. + for (auto &ptrInfo : recursiveStructInfos[type]) { + // TODO: This might not work if more than 1 recursive reference is + // present in the struct.
+ // Operands: result ID, storage class, pointee (this struct's) type ID. + SmallVector ptrOperands; + ptrOperands.push_back(ptrInfo.pointerTypeID); + ptrOperands.push_back(static_cast(ptrInfo.storageClass)); + ptrOperands.push_back(typeIDMap[type]); + + if (failed(encodeInstructionInto( + typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands))) + return failure(); + } + + recursiveStructInfos[type].clear(); + } + + return success(); } + return failure(); } -LogicalResult -Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID, - spirv::Opcode &typeEnum, - SmallVectorImpl &operands) { +LogicalResult Serializer::prepareBasicType( + Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, + SmallVectorImpl &operands, bool &deferSerialization, + llvm::SetVector &serializationCtx) { + deferSerialization = false; + if (isVoidType(type)) { typeEnum = spirv::Opcode::OpTypeVoid; return success(); @@ -1064,7 +1126,8 @@ if (auto vectorType = type.dyn_cast()) { uint32_t elementTypeID = 0; - if (failed(processType(loc, vectorType.getElementType(), elementTypeID))) { + if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID, + serializationCtx))) { return failure(); } typeEnum = spirv::Opcode::OpTypeVector; @@ -1076,7 +1139,8 @@ if (auto arrayType = type.dyn_cast()) { typeEnum = spirv::Opcode::OpTypeArray; uint32_t elementTypeID = 0; - if (failed(processType(loc, arrayType.getElementType(), elementTypeID))) { + if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID, + serializationCtx))) { return failure(); } operands.push_back(elementTypeID); @@ -1089,9 +1153,45 @@ if (auto ptrType = type.dyn_cast()) { uint32_t pointeeTypeID = 0; - if (failed(processType(loc, ptrType.getPointeeType(), pointeeTypeID))) { - return failure(); + spirv::StructType pointeeStruct = + ptrType.getPointeeType().dyn_cast(); + + if (pointeeStruct && pointeeStruct.isIdentified() && + serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { + // A recursive reference to an enclosing struct is found.
+ // + // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage + // class as operands. + SmallVector forwardPtrOperands; + forwardPtrOperands.push_back(resultID); + forwardPtrOperands.push_back( + static_cast(ptrType.getStorageClass())); + + encodeInstructionInto(typesGlobalValues, + spirv::Opcode::OpTypeForwardPointer, + forwardPtrOperands); + + // 2. Find the the pointee (enclosing) struct. + auto structType = spirv::StructType::getIdentified( + module.getContext(), pointeeStruct.getIdentifier()); + + if (!structType) + return failure(); + + // 3. Mark the OpTypePointer that is supposed to be emitted by this call + // as deferred. + deferSerialization = true; + + // 4. Record the info needed to emit the deferred OpTypePointer + // instruction when the enclosing struct is completely serialized. + recursiveStructInfos[structType].push_back( + {resultID, ptrType.getStorageClass()}); + } else { + if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID, + serializationCtx))) + return failure(); } + typeEnum = spirv::Opcode::OpTypePointer; operands.push_back(static_cast(ptrType.getStorageClass())); operands.push_back(pointeeTypeID); @@ -1100,8 +1200,8 @@ if (auto runtimeArrayType = type.dyn_cast()) { uint32_t elementTypeID = 0; - if (failed(processType(loc, runtimeArrayType.getElementType(), - elementTypeID))) { + if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(), + elementTypeID, serializationCtx))) { return failure(); } typeEnum = spirv::Opcode::OpTypeRuntimeArray; @@ -1110,12 +1210,17 @@ } if (auto structType = type.dyn_cast()) { + if (structType.isIdentified()) { + processName(resultID, structType.getIdentifier()); + serializationCtx.insert(structType.getIdentifier()); + } + bool hasOffset = structType.hasOffset(); for (auto elementIndex : llvm::seq(0, structType.getNumElements())) { uint32_t elementTypeID = 0; - if (failed(processType(loc, structType.getElementType(elementIndex), - elementTypeID))) { + if 
(failed(processTypeImpl(loc, structType.getElementType(elementIndex), + elementTypeID, serializationCtx))) { return failure(); } operands.push_back(elementTypeID); @@ -1133,6 +1238,7 @@ } SmallVector memberDecorations; structType.getMemberDecorations(memberDecorations); + for (auto &memberDecoration : memberDecorations) { if (failed(processMemberDecoration(resultID, memberDecoration))) { return emitError(loc, "cannot decorate ") @@ -1141,15 +1247,20 @@ << stringifyDecoration(memberDecoration.decoration); } } + typeEnum = spirv::Opcode::OpTypeStruct; + + if (structType.isIdentified()) + serializationCtx.remove(structType.getIdentifier()); + return success(); } if (auto cooperativeMatrixType = type.dyn_cast()) { uint32_t elementTypeID = 0; - if (failed(processType(loc, cooperativeMatrixType.getElementType(), - elementTypeID))) { + if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), + elementTypeID, serializationCtx))) { return failure(); } typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; @@ -1167,7 +1278,8 @@ if (auto matrixType = type.dyn_cast()) { uint32_t elementTypeID = 0; - if (failed(processType(loc, matrixType.getColumnType(), elementTypeID))) { + if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, + serializationCtx))) { return failure(); } typeEnum = spirv::Opcode::OpTypeMatrix; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp @@ -35,6 +35,10 @@ auto ptrType = op.type().cast(); auto structType = VulkanLayoutUtils::decorateType( ptrType.getPointeeType().cast()); + + if (!structType) + return failure(); + auto decoratedType = spirv::PointerType::get(structType, ptrType.getStorageClass()); diff --git 
a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -53,6 +53,10 @@ // Set the offset information. varPointeeType = VulkanLayoutUtils::decorateType(varPointeeType).cast(); + + if (!varPointeeType) + return nullptr; + varType = spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass()); diff --git a/mlir/test/Conversion/GPUToSPIRV/if.mlir b/mlir/test/Conversion/GPUToSPIRV/if.mlir --- a/mlir/test/Conversion/GPUToSPIRV/if.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/if.mlir @@ -133,20 +133,20 @@ // VariablePointer capability is supported. This test is still useful to // make sure we can handle scf op result with type change. // CHECK-LABEL: @simple_if_yield_type_change - // CHECK: %[[VAR:.*]] = spv.Variable : !spv.ptr [0]>, StorageBuffer>, Function> + // CHECK: %[[VAR:.*]] = spv.Variable : !spv.ptr [0])>, StorageBuffer>, Function> // CHECK: spv.selection { // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]] // CHECK-NEXT: [[TRUE]]: - // CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr [0]>, StorageBuffer> + // CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr [0])>, StorageBuffer> // CHECK: spv.Branch ^[[MERGE:.*]] // CHECK-NEXT: [[FALSE]]: - // CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr [0]>, StorageBuffer> + // CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr [0])>, StorageBuffer> // CHECK: spv.Branch ^[[MERGE]] // CHECK-NEXT: ^[[MERGE]]: // CHECK: spv._merge // CHECK-NEXT: } - // CHECK: %[[OUT:.*]] = spv.Load "Function" %[[VAR]] : !spv.ptr [0]>, StorageBuffer> - // CHECK: %[[ADD:.*]] = spv.AccessChain %[[OUT]][{{%.*}}, {{%.*}}] : !spv.ptr [0]>, StorageBuffer> + // CHECK: %[[OUT:.*]] = spv.Load "Function" %[[VAR]] : !spv.ptr [0])>, StorageBuffer> + // CHECK: %[[ADD:.*]] = spv.AccessChain 
%[[OUT]][{{%.*}}, {{%.*}}] : !spv.ptr [0])>, StorageBuffer> // CHECK: spv.Store "StorageBuffer" %[[ADD]], {{%.*}} : f32 // CHECK: spv.Return gpu.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>, %arg4 : i1) kernel diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir --- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir @@ -25,9 +25,9 @@ // CHECK-DAG: spv.globalVariable @[[$LOCALINVOCATIONIDVAR:.*]] built_in("LocalInvocationId") : !spv.ptr, Input> // CHECK-DAG: spv.globalVariable @[[$WORKGROUPIDVAR:.*]] built_in("WorkgroupId") : !spv.ptr, Input> // CHECK-LABEL: spv.func @load_store_kernel - // CHECK-SAME: %[[ARG0:.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 0)>} - // CHECK-SAME: %[[ARG1:.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>} - // CHECK-SAME: %[[ARG2:.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 2)>} + // CHECK-SAME: %[[ARG0:.*]]: !spv.ptr [0])>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 0)>} + // CHECK-SAME: %[[ARG1:.*]]: !spv.ptr [0])>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>} + // CHECK-SAME: %[[ARG2:.*]]: !spv.ptr [0])>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 2)>} // CHECK-SAME: %[[ARG3:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 3), StorageBuffer>} // CHECK-SAME: %[[ARG4:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 4), StorageBuffer>} // CHECK-SAME: %[[ARG5:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 5), StorageBuffer>} diff --git a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir --- a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir +++ 
b/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir @@ -9,7 +9,7 @@ // CHECK: spv.func // CHECK-SAME: {{%.*}}: f32 // CHECK-NOT: spv.interface_var_abi - // CHECK-SAME: {{%.*}}: !spv.ptr [0]>, CrossWorkgroup> + // CHECK-SAME: {{%.*}}: !spv.ptr [0])>, CrossWorkgroup> // CHECK-NOT: spv.interface_var_abi // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>} gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, 11>) kernel diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir --- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir @@ -5,7 +5,7 @@ // CHECK: spv.module @{{.*}} Logical GLSL450 { // CHECK-LABEL: spv.func @basic_module_structure // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 0), StorageBuffer>} - // CHECK-SAME: {{%.*}}: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>} + // CHECK-SAME: {{%.*}}: !spv.ptr [0])>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>} // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>} gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32>) kernel attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} { @@ -32,7 +32,7 @@ // CHECK-LABEL: spv.func @basic_module_structure_preset_ABI // CHECK-SAME: {{%[a-zA-Z0-9_]*}}: f32 // CHECK-SAME: spv.interface_var_abi = #spv.interface_var_abi<(1, 2), StorageBuffer> - // CHECK-SAME: !spv.ptr [0]>, StorageBuffer> + // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // CHECK-SAME: spv.interface_var_abi = #spv.interface_var_abi<(3, 0)> // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>} gpu.func @basic_module_structure_preset_ABI( diff --git a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir 
b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir --- a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir +++ b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir @@ -6,12 +6,12 @@ module attributes {gpu.container_module} { spv.module Logical GLSL450 requires #spv.vce { - spv.globalVariable @kernel_arg_0 bind(0, 0) : !spv.ptr [0]>, StorageBuffer> + spv.globalVariable @kernel_arg_0 bind(0, 0) : !spv.ptr [0])>, StorageBuffer> spv.func @kernel() "None" attributes {workgroup_attributions = 0 : i64} { - %0 = spv._address_of @kernel_arg_0 : !spv.ptr [0]>, StorageBuffer> + %0 = spv._address_of @kernel_arg_0 : !spv.ptr [0])>, StorageBuffer> %2 = spv.constant 0 : i32 - %3 = spv._address_of @kernel_arg_0 : !spv.ptr [0]>, StorageBuffer> - %4 = spv.AccessChain %0[%2, %2] : !spv.ptr [0]>, StorageBuffer>, i32, i32 + %3 = spv._address_of @kernel_arg_0 : !spv.ptr [0])>, StorageBuffer> + %4 = spv.AccessChain %0[%2, %2] : !spv.ptr [0])>, StorageBuffer>, i32, i32 %5 = spv.Load "StorageBuffer" %4 : f32 spv.Return } diff --git a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir @@ -8,10 +8,10 @@ spv.func @access_chain() "None" { // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 %0 = spv.constant 1: i32 - %1 = spv.Variable : !spv.ptr>, Function> + %1 = spv.Variable : !spv.ptr)>, Function> // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 // CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %[[ONE]], %[[ONE]]] : (!llvm.ptr)>>, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm.ptr - %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function>, i32, i32 + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr)>, Function>, i32, i32 spv.Return } @@ -38,9 +38,9 @@ // CHECK: llvm.mlir.global private @struct() : !llvm.struct)> // CHECK-LABEL: @func 
// CHECK: llvm.mlir.addressof @struct : !llvm.ptr)>> - spv.globalVariable @struct : !spv.ptr>, Private> + spv.globalVariable @struct : !spv.ptr)>, Private> spv.func @func() "None" { - %0 = spv._address_of @struct : !spv.ptr>, Private> + %0 = spv._address_of @struct : !spv.ptr)>, Private> spv.Return } } @@ -124,10 +124,10 @@ } // CHECK-LABEL: @store_composite -spv.func @store_composite(%arg0 : !spv.struct) "None" { - %0 = spv.Variable : !spv.ptr, Function> +spv.func @store_composite(%arg0 : !spv.struct<(f64)>) "None" { + %0 = spv.Variable : !spv.ptr, Function> // CHECK: llvm.store %{{.*}}, %{{.*}} : !llvm.ptr> - spv.Store "Function" %0, %arg0 : !spv.struct + spv.Store "Function" %0, %arg0 : !spv.struct<(f64)> spv.Return } diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir @@ -8,13 +8,13 @@ // ----- // expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}} -spv.func @struct_with_unnatural_offset(%arg: !spv.struct) -> () "None" { +spv.func @struct_with_unnatural_offset(%arg: !spv.struct<(i32[0], i32[8])>) -> () "None" { spv.Return } // ----- // expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}} -spv.func @struct_with_decorations(%arg: !spv.struct) -> () "None" { +spv.func @struct_with_decorations(%arg: !spv.struct<(f32 [RelaxedPrecision])>) -> () "None" { spv.Return } diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir @@ -35,10 +35,10 @@ //===----------------------------------------------------------------------===// // CHECK-LABEL: 
@struct(!llvm.struct) -spv.func @struct(!spv.struct) "None" +spv.func @struct(!spv.struct<(f64)>) "None" // CHECK-LABEL: @struct_nested(!llvm.struct)>) -spv.func @struct_nested(!spv.struct>) "None" +spv.func @struct_nested(!spv.struct<(i32, !spv.struct<(i64, i32)>)>) "None" // CHECK-LABEL: @struct_with_natural_offset(!llvm.struct<(i8, i32)>) -spv.func @struct_with_natural_offset(!spv.struct) "None" +spv.func @struct_with_natural_offset(!spv.struct<(i8[0], i32[4])>) "None" diff --git a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir --- a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir @@ -17,7 +17,7 @@ return } } -// CHECK: spv.globalVariable @[[VAR:.+]] : !spv.ptr>, Workgroup> +// CHECK: spv.globalVariable @[[VAR:.+]] : !spv.ptr)>, Workgroup> // CHECK: func @alloc_dealloc_workgroup_mem // CHECK-NOT: alloc // CHECK: %[[PTR:.+]] = spv._address_of @[[VAR]] @@ -45,7 +45,7 @@ } // CHECK: spv.globalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr>, Workgroup> +// CHECK-SAME: !spv.ptr)>, Workgroup> // CHECK_LABEL: spv.func @alloc_dealloc_workgroup_mem // CHECK: %[[VAR:.+]] = spv._address_of @__workgroup_mem__0 // CHECK: %[[LOC:.+]] = spv.SDiv @@ -72,9 +72,9 @@ } // CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr>, Workgroup> +// CHECK-SAME: !spv.ptr)>, Workgroup> // CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr>, Workgroup> +// CHECK-SAME: !spv.ptr)>, Workgroup> // CHECK: spv.func @two_allocs() // CHECK: spv.Return @@ -93,9 +93,9 @@ } // CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr, stride=8>>, Workgroup> +// CHECK-SAME: !spv.ptr, stride=8>)>, Workgroup> // CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr, stride=16>>, Workgroup> +// CHECK-SAME: !spv.ptr, stride=16>)>, Workgroup> // CHECK: spv.func 
@two_allocs_vector() // CHECK: spv.Return diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -691,8 +691,8 @@ //===----------------------------------------------------------------------===// // CHECK-LABEL: @load_store_zero_rank_float -// CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, -// CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) +// CHECK: [[ARG0:%.*]]: !spv.ptr [0])>, StorageBuffer>, +// CHECK: [[ARG1:%.*]]: !spv.ptr [0])>, StorageBuffer>) func @load_store_zero_rank_float(%arg0: memref, %arg1: memref) { // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32 // CHECK: spv.AccessChain [[ARG0]][ @@ -710,8 +710,8 @@ } // CHECK-LABEL: @load_store_zero_rank_int -// CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, -// CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) +// CHECK: [[ARG0:%.*]]: !spv.ptr [0])>, StorageBuffer>, +// CHECK: [[ARG1:%.*]]: !spv.ptr [0])>, StorageBuffer>) func @load_store_zero_rank_int(%arg0: memref, %arg1: memref) { // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32 // CHECK: spv.AccessChain [[ARG0]][ diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir @@ -274,35 +274,35 @@ } { // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return } // CHECK-LABEL: spv.func @memref_8bit_Uniform -// CHECK-SAME: !spv.ptr [0]>, Uniform> +// CHECK-SAME: !spv.ptr [0])>, Uniform> func @memref_8bit_Uniform(%arg0: memref<16xsi8, 4>) { return } // CHECK-LABEL: spv.func 
@memref_8bit_PushConstant -// CHECK-SAME: !spv.ptr [0]>, PushConstant> +// CHECK-SAME: !spv.ptr [0])>, PushConstant> func @memref_8bit_PushConstant(%arg0: memref<16xui8, 7>) { return } // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, 0>) { return } // CHECK-LABEL: spv.func @memref_16bit_Uniform -// CHECK-SAME: !spv.ptr [0]>, Uniform> +// CHECK-SAME: !spv.ptr [0])>, Uniform> func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return } // CHECK-LABEL: spv.func @memref_16bit_PushConstant -// CHECK-SAME: !spv.ptr [0]>, PushConstant> +// CHECK-SAME: !spv.ptr [0])>, PushConstant> func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return } // CHECK-LABEL: spv.func @memref_16bit_Input -// CHECK-SAME: !spv.ptr [0]>, Input> +// CHECK-SAME: !spv.ptr [0])>, Input> func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return } // CHECK-LABEL: spv.func @memref_16bit_Output -// CHECK-SAME: !spv.ptr [0]>, Output> +// CHECK-SAME: !spv.ptr [0])>, Output> func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return } } // end module @@ -319,12 +319,12 @@ } { // CHECK-LABEL: spv.func @memref_8bit_PushConstant -// CHECK-SAME: !spv.ptr [0]>, PushConstant> +// CHECK-SAME: !spv.ptr [0])>, PushConstant> func @memref_8bit_PushConstant(%arg0: memref<16xi8, 7>) { return } // CHECK-LABEL: spv.func @memref_16bit_PushConstant -// CHECK-SAME: !spv.ptr [0]>, PushConstant> -// CHECK-SAME: !spv.ptr [0]>, PushConstant> +// CHECK-SAME: !spv.ptr [0])>, PushConstant> +// CHECK-SAME: !spv.ptr [0])>, PushConstant> func @memref_16bit_PushConstant( %arg0: memref<16xi16, 7>, %arg1: memref<16xf16, 7> @@ -344,12 +344,12 @@ } { // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return } // 
CHECK-LABEL: spv.func @memref_16bit_StorageBuffer -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> func @memref_16bit_StorageBuffer( %arg0: memref<16xi16, 0>, %arg1: memref<16xf16, 0> @@ -369,12 +369,12 @@ } { // CHECK-LABEL: spv.func @memref_8bit_Uniform -// CHECK-SAME: !spv.ptr [0]>, Uniform> +// CHECK-SAME: !spv.ptr [0])>, Uniform> func @memref_8bit_Uniform(%arg0: memref<16xi8, 4>) { return } // CHECK-LABEL: spv.func @memref_16bit_Uniform -// CHECK-SAME: !spv.ptr [0]>, Uniform> -// CHECK-SAME: !spv.ptr [0]>, Uniform> +// CHECK-SAME: !spv.ptr [0])>, Uniform> +// CHECK-SAME: !spv.ptr [0])>, Uniform> func @memref_16bit_Uniform( %arg0: memref<16xi16, 4>, %arg1: memref<16xf16, 4> @@ -393,11 +393,11 @@ } { // CHECK-LABEL: spv.func @memref_16bit_Input -// CHECK-SAME: !spv.ptr [0]>, Input> +// CHECK-SAME: !spv.ptr [0])>, Input> func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return } // CHECK-LABEL: spv.func @memref_16bit_Output -// CHECK-SAME: !spv.ptr [0]>, Output> +// CHECK-SAME: !spv.ptr [0])>, Output> func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return } } // end module @@ -412,22 +412,22 @@ // CHECK-LABEL: spv.func @memref_offset_strides func @memref_offset_strides( -// CHECK-SAME: !spv.array<64 x f32, stride=4> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<72 x f32, stride=4> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<256 x f32, stride=4> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<64 x f32, stride=4> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<88 x f32, stride=4> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<64 x f32, stride=4> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<72 x f32, stride=4> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<256 x f32, stride=4> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<64 x f32, stride=4> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<88 x f32, 
stride=4> [0])>, StorageBuffer> %arg0: memref<16x4xf32, offset: 0, strides: [4, 1]>, // tightly packed; row major %arg1: memref<16x4xf32, offset: 8, strides: [4, 1]>, // offset 8 %arg2: memref<16x4xf32, offset: 0, strides: [16, 1]>, // pad 12 after each row %arg3: memref<16x4xf32, offset: 0, strides: [1, 16]>, // tightly packed; col major %arg4: memref<16x4xf32, offset: 0, strides: [1, 22]>, // pad 4 after each col -// CHECK-SAME: !spv.array<64 x f16, stride=2> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<72 x f16, stride=2> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<256 x f16, stride=2> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<64 x f16, stride=2> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<88 x f16, stride=2> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<64 x f16, stride=2> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<72 x f16, stride=2> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<256 x f16, stride=2> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<64 x f16, stride=2> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<88 x f16, stride=2> [0])>, StorageBuffer> %arg5: memref<16x4xf16, offset: 0, strides: [4, 1]>, %arg6: memref<16x4xf16, offset: 8, strides: [4, 1]>, %arg7: memref<16x4xf16, offset: 0, strides: [16, 1]>, @@ -450,8 +450,8 @@ func @unranked_memref(%arg0: memref<*xi32>) { return } // CHECK-LABEL: func @dynamic_dim_memref -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> func @dynamic_dim_memref(%arg0: memref<8x?xi32>, %arg1: memref) { return } @@ -466,16 +466,16 @@ } { // CHECK-LABEL: func @memref_vector -// CHECK-SAME: !spv.ptr, stride=8> [0]>, StorageBuffer> -// CHECK-SAME: !spv.ptr, stride=16> [0]>, Uniform> +// CHECK-SAME: !spv.ptr, stride=8> [0])>, StorageBuffer> +// CHECK-SAME: !spv.ptr, stride=16> [0])>, Uniform> func @memref_vector( %arg0: memref<4xvector<2xf32>, 0>, %arg1: 
memref<4xvector<4xf32>, 4>) { return } // CHECK-LABEL: func @dynamic_dim_memref_vector -// CHECK-SAME: !spv.ptr, stride=16> [0]>, StorageBuffer> -// CHECK-SAME: !spv.ptr, stride=8> [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr, stride=16> [0])>, StorageBuffer> +// CHECK-SAME: !spv.ptr, stride=8> [0])>, StorageBuffer> func @dynamic_dim_memref_vector(%arg0: memref<8x?xvector<4xi32>>, %arg1: memref>) { return } diff --git a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir @@ -1,10 +1,10 @@ // RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s spv.module Logical GLSL450 requires #spv.vce { - spv.func @composite_insert(%arg0 : !spv.struct, f32>>, %arg1: !spv.array<4xf32>) -> !spv.struct, f32>> "None" { - // CHECK: spv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32, 0 : i32] : !spv.array<4 x f32> into !spv.struct, f32>> - %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct, f32>> - spv.ReturnValue %0: !spv.struct, f32>> + spv.func @composite_insert(%arg0 : !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)>, %arg1: !spv.array<4xf32>) -> !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)> "None" { + // CHECK: spv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32, 0 : i32] : !spv.array<4 x f32> into !spv.struct<(f32, !spv.struct<(!spv.array<4 x f32>, f32)>)> + %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)> + spv.ReturnValue %0: !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)> } spv.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> "None" { // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32> diff --git a/mlir/test/Dialect/SPIRV/Serialization/debug.mlir 
b/mlir/test/Dialect/SPIRV/Serialization/debug.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/debug.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/debug.mlir @@ -29,9 +29,9 @@ spv.Return } - spv.func @composite(%arg0 : !spv.struct, f32>>, %arg1: !spv.array<4xf32>, %arg2 : f32, %arg3 : f32) "None" { + spv.func @composite(%arg0 : !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)>, %arg1: !spv.array<4xf32>, %arg2 : f32, %arg3 : f32) "None" { // CHECK: loc({{".*debug.mlir"}}:34:10) - %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct, f32>> + %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)> // CHECK: loc({{".*debug.mlir"}}:36:10) %1 = spv.CompositeConstruct %arg2, %arg3 : vector<2xf32> spv.Return diff --git a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir @@ -60,14 +60,14 @@ // ----- spv.module Logical GLSL450 requires #spv.vce { - spv.globalVariable @GV1 bind(0, 0) : !spv.ptr [0]>, StorageBuffer> - spv.globalVariable @GV2 bind(0, 1) : !spv.ptr [0]>, StorageBuffer> + spv.globalVariable @GV1 bind(0, 0) : !spv.ptr [0])>, StorageBuffer> + spv.globalVariable @GV2 bind(0, 1) : !spv.ptr [0])>, StorageBuffer> spv.func @loop_kernel() "None" { - %0 = spv._address_of @GV1 : !spv.ptr [0]>, StorageBuffer> + %0 = spv._address_of @GV1 : !spv.ptr [0])>, StorageBuffer> %1 = spv.constant 0 : i32 - %2 = spv.AccessChain %0[%1] : !spv.ptr [0]>, StorageBuffer>, i32 - %3 = spv._address_of @GV2 : !spv.ptr [0]>, StorageBuffer> - %5 = spv.AccessChain %3[%1] : !spv.ptr [0]>, StorageBuffer>, i32 + %2 = spv.AccessChain %0[%1] : !spv.ptr [0])>, StorageBuffer>, i32 + %3 = spv._address_of @GV2 : !spv.ptr [0])>, StorageBuffer> + %5 = spv.AccessChain %3[%1] : !spv.ptr [0])>, StorageBuffer>, i32 %6 = 
spv.constant 4 : i32 %7 = spv.constant 42 : i32 %8 = spv.constant 2 : i32 diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir @@ -27,32 +27,32 @@ // ----- spv.module Logical GLSL450 requires #spv.vce { - spv.func @load_store_zero_rank_float(%arg0: !spv.ptr [0]>, StorageBuffer>, %arg1: !spv.ptr [0]>, StorageBuffer>) "None" { - // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + spv.func @load_store_zero_rank_float(%arg0: !spv.ptr [0])>, StorageBuffer>, %arg1: !spv.ptr [0])>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0])> // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : f32 %0 = spv.constant 0 : i32 - %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer>, i32, i32 + %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0])>, StorageBuffer>, i32, i32 %2 = spv.Load "StorageBuffer" %1 : f32 - // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0])> // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32 %3 = spv.constant 0 : i32 - %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer>, i32, i32 + %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0])>, StorageBuffer>, i32, i32 spv.Store "StorageBuffer" %4, %2 : f32 spv.Return } - spv.func @load_store_zero_rank_int(%arg0: !spv.ptr [0]>, StorageBuffer>, %arg1: !spv.ptr [0]>, StorageBuffer>) "None" { - // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + spv.func @load_store_zero_rank_int(%arg0: !spv.ptr [0])>, StorageBuffer>, %arg1: !spv.ptr [0])>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain 
{{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0])> // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : i32 %0 = spv.constant 0 : i32 - %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer>, i32, i32 + %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0])>, StorageBuffer>, i32, i32 %2 = spv.Load "StorageBuffer" %1 : i32 - // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0])> // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32 %3 = spv.constant 0 : i32 - %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer>, i32, i32 + %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0])>, StorageBuffer>, i32, i32 spv.Store "StorageBuffer" %4, %2 : i32 spv.Return } diff --git a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir @@ -39,8 +39,8 @@ // CHECK: spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32> spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32> - // CHECK: spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct - spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct + // CHECK: spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<(i32, f32, f32)> + spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<(i32, f32, f32)> // CHECK: spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3xf32> spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3 x f32> diff --git a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir 
b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir @@ -1,36 +1,52 @@ // RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s spv.module Logical GLSL450 requires #spv.vce { - // CHECK: !spv.ptr [0]>, Input> - spv.globalVariable @var0 bind(0, 1) : !spv.ptr [0]>, Input> + // CHECK: !spv.ptr [0])>, Input> + spv.globalVariable @var0 bind(0, 1) : !spv.ptr [0])>, Input> - // CHECK: !spv.ptr [4]> [4]>, Input> - spv.globalVariable @var1 bind(0, 2) : !spv.ptr [4]> [4]>, Input> + // CHECK: !spv.ptr [4])> [4])>, Input> + spv.globalVariable @var1 bind(0, 2) : !spv.ptr [4])> [4])>, Input> - // CHECK: !spv.ptr, StorageBuffer> - spv.globalVariable @var2 : !spv.ptr, StorageBuffer> + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @var2 : !spv.ptr, StorageBuffer> - // CHECK: !spv.ptr [0]>, stride=512> [0]>, StorageBuffer> - spv.globalVariable @var3 : !spv.ptr [0]>, stride=512> [0]>, StorageBuffer> + // CHECK: !spv.ptr [0])>, stride=512> [0])>, StorageBuffer> + spv.globalVariable @var3 : !spv.ptr [0])>, stride=512> [0])>, StorageBuffer> - // CHECK: !spv.ptr, StorageBuffer> - spv.globalVariable @var4 : !spv.ptr, StorageBuffer> + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @var4 : !spv.ptr, StorageBuffer> - // CHECK: !spv.ptr, StorageBuffer> - spv.globalVariable @var5 : !spv.ptr, StorageBuffer> + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @var5 : !spv.ptr, StorageBuffer> - // CHECK: !spv.ptr, StorageBuffer> - spv.globalVariable @var6 : !spv.ptr, StorageBuffer> + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @var6 : !spv.ptr, StorageBuffer> - // CHECK: !spv.ptr> [0, ColMajor, MatrixStride=16]>, StorageBuffer> - spv.globalVariable @var7 : !spv.ptr> [0, ColMajor, MatrixStride=16]>, StorageBuffer> + // CHECK: !spv.ptr> [0, ColMajor, MatrixStride=16])>, StorageBuffer> + spv.globalVariable @var7 : !spv.ptr> [0, ColMajor, 
MatrixStride=16])>, StorageBuffer> - // CHECK: !spv.ptr, StorageBuffer> - spv.globalVariable @empty : !spv.ptr, StorageBuffer> + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @empty : !spv.ptr, StorageBuffer> - // CHECK: !spv.ptr [0]>, Input>, - // CHECK-SAME: !spv.ptr [0]>, Output> - spv.func @kernel(%arg0: !spv.ptr [0]>, Input>, %arg1: !spv.ptr [0]>, Output>) -> () "None" { + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @id_empty : !spv.ptr, StorageBuffer> + + // CHECK: !spv.ptr [0])>, Input> + spv.globalVariable @id_var0 : !spv.ptr [0])>, Input> + + + // CHECK: !spv.ptr, StorageBuffer>)>, StorageBuffer> + spv.globalVariable @recursive_simple : !spv.ptr, StorageBuffer>)>, StorageBuffer> + + // CHECK: !spv.ptr, Uniform>)>, Uniform>)>, Uniform> + spv.globalVariable @recursive_2 : !spv.ptr, Uniform>)>, Uniform>)>, Uniform> + + // CHECK: !spv.ptr, Uniform>, !spv.ptr, Uniform>)>, Uniform>)>, Uniform> + spv.globalVariable @recursive_3 : !spv.ptr, Uniform>, !spv.ptr, Uniform>)>, Uniform>)>, Uniform> + + // CHECK: !spv.ptr [0])>, Input>, + // CHECK-SAME: !spv.ptr [0])>, Output> + spv.func @kernel(%arg0: !spv.ptr [0])>, Input>, %arg1: !spv.ptr [0])>, Output>) -> () "None" { spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/Serialization/undef.mlir b/mlir/test/Dialect/SPIRV/Serialization/undef.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/undef.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/undef.mlir @@ -13,10 +13,10 @@ // CHECK: {{%.*}} = spv.undef : !spv.array<4 x !spv.array<4 x i32>> %5 = spv.undef : !spv.array<4x!spv.array<4xi32>> %6 = spv.CompositeExtract %5[1 : i32, 2 : i32] : !spv.array<4x!spv.array<4xi32>> - // CHECK: {{%.*}} = spv.undef : !spv.ptr, StorageBuffer> - %7 = spv.undef : !spv.ptr, StorageBuffer> + // CHECK: {{%.*}} = spv.undef : !spv.ptr, StorageBuffer> + %7 = spv.undef : !spv.ptr, StorageBuffer> %8 = spv.constant 0 : i32 - %9 = spv.AccessChain %7[%8] : !spv.ptr, StorageBuffer>, i32 + %9 = spv.AccessChain %7[%8] : !spv.ptr, 
StorageBuffer>, i32 spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface-opencl.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface-opencl.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface-opencl.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface-opencl.mlir @@ -5,14 +5,12 @@ } { spv.module Physical64 OpenCL { // CHECK-LABEL: spv.module - // CHECK: spv.func [[FN:@.*]]( - // CHECK-SAME: {{%.*}}: f32 - // CHECK-SAME: {{%.*}}: !spv.ptr>, CrossWorkgroup> + // CHECK: spv.func [[FN:@.*]]({{%.*}}: f32, {{%.*}}: !spv.ptr)>, CrossWorkgroup> // CHECK: spv.EntryPoint "Kernel" [[FN]] // CHECK: spv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1 spv.func @kernel( %arg0: f32, - %arg1: !spv.ptr>, CrossWorkgroup>) "None" + %arg1: !spv.ptr)>, CrossWorkgroup>) "None" attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} { spv.Return } diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir @@ -7,13 +7,13 @@ // CHECK-LABEL: spv.module spv.module Logical GLSL450 { - // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr [0]>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr [0])>, StorageBuffer> // CHECK: spv.func [[FN:@.*]]() spv.func @kernel( %arg0: f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 0), StorageBuffer>}, - %arg1: !spv.ptr>, StorageBuffer> + %arg1: !spv.ptr)>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}) "None" attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} { // CHECK: [[ARG1:%.*]] = spv._address_of [[VAR1]] diff --git 
a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir @@ -15,20 +15,20 @@ spv.globalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spv.ptr, Input> // CHECK-DAG: spv.globalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId") spv.globalVariable @__builtin_var_WorkgroupId__ built_in("WorkgroupId") : !spv.ptr, Input> - // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr, stride=16> [0]>, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr, stride=16> [0]>, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr, stride=16> [0]>, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR3:@.*]] bind(0, 3) : !spv.ptr, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR6:@.*]] bind(0, 6) : !spv.ptr, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr, stride=16> [0])>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr, stride=16> [0])>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr, stride=16> [0])>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR3:@.*]] bind(0, 3) : !spv.ptr, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR6:@.*]] bind(0, 6) : !spv.ptr, StorageBuffer> // CHECK: spv.func [[FN:@.*]]() spv.func @load_store_kernel( - %arg0: !spv.ptr>>, StorageBuffer> + %arg0: !spv.ptr>)>, StorageBuffer> {spv.interface_var_abi = 
#spv.interface_var_abi<(0, 0)>}, - %arg1: !spv.ptr>>, StorageBuffer> + %arg1: !spv.ptr>)>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}, - %arg2: !spv.ptr>>, StorageBuffer> + %arg2: !spv.ptr>)>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 2)>}, %arg3: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 3), StorageBuffer>}, @@ -103,14 +103,14 @@ %37 = spv.IAdd %arg4, %11 : i32 // CHECK: spv.AccessChain [[ARG0]] %c0 = spv.constant 0 : i32 - %38 = spv.AccessChain %arg0[%c0, %36, %37] : !spv.ptr>>, StorageBuffer>, i32, i32, i32 + %38 = spv.AccessChain %arg0[%c0, %36, %37] : !spv.ptr>)>, StorageBuffer>, i32, i32, i32 %39 = spv.Load "StorageBuffer" %38 : f32 // CHECK: spv.AccessChain [[ARG1]] - %40 = spv.AccessChain %arg1[%c0, %36, %37] : !spv.ptr>>, StorageBuffer>, i32, i32, i32 + %40 = spv.AccessChain %arg1[%c0, %36, %37] : !spv.ptr>)>, StorageBuffer>, i32, i32, i32 %41 = spv.Load "StorageBuffer" %40 : f32 %42 = spv.FAdd %39, %41 : f32 // CHECK: spv.AccessChain [[ARG2]] - %43 = spv.AccessChain %arg2[%c0, %36, %37] : !spv.ptr>>, StorageBuffer>, i32, i32, i32 + %43 = spv.AccessChain %arg2[%c0, %36, %37] : !spv.ptr>)>, StorageBuffer>, i32, i32, i32 spv.Store "StorageBuffer" %43, %42 : f32 spv.Return } diff --git a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir @@ -33,11 +33,11 @@ // ----- spv.module Logical GLSL450 { - spv.globalVariable @data bind(0, 0) : !spv.ptr [0]>, StorageBuffer> + spv.globalVariable @data bind(0, 0) : !spv.ptr [0])>, StorageBuffer> spv.func @callee() "None" { - %0 = spv._address_of @data : !spv.ptr [0]>, StorageBuffer> + %0 = spv._address_of @data : !spv.ptr [0])>, StorageBuffer> %1 = spv.constant 0: i32 - %2 = spv.AccessChain %0[%1, %1] : !spv.ptr [0]>, StorageBuffer>, i32, i32 + %2 = spv.AccessChain %0[%1, %1] : !spv.ptr [0])>, 
StorageBuffer>, i32, i32 spv.Branch ^next ^next: @@ -184,8 +184,8 @@ // ----- spv.module Logical GLSL450 { - spv.globalVariable @arg_0 bind(0, 0) : !spv.ptr, StorageBuffer> - spv.globalVariable @arg_1 bind(0, 1) : !spv.ptr, StorageBuffer> + spv.globalVariable @arg_0 bind(0, 0) : !spv.ptr, StorageBuffer> + spv.globalVariable @arg_1 bind(0, 1) : !spv.ptr, StorageBuffer> // CHECK: @inline_into_selection_region spv.func @inline_into_selection_region() "None" { @@ -194,9 +194,9 @@ // CHECK-DAG: [[ADDRESS_ARG1:%.*]] = spv._address_of @arg_1 // CHECK-DAG: [[LOADPTR:%.*]] = spv.AccessChain [[ADDRESS_ARG0]] // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOADPTR]] - %2 = spv._address_of @arg_0 : !spv.ptr, StorageBuffer> - %3 = spv._address_of @arg_1 : !spv.ptr, StorageBuffer> - %4 = spv.AccessChain %2[%1] : !spv.ptr, StorageBuffer>, i32 + %2 = spv._address_of @arg_0 : !spv.ptr, StorageBuffer> + %3 = spv._address_of @arg_1 : !spv.ptr, StorageBuffer> + %4 = spv.AccessChain %2[%1] : !spv.ptr, StorageBuffer>, i32 %5 = spv.Load "StorageBuffer" %4 : i32 %6 = spv.SGreaterThan %5, %1 : i32 // CHECK: spv.selection @@ -204,7 +204,7 @@ spv.BranchConditional %6, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: [[STOREPTR:%.*]] = spv.AccessChain [[ADDRESS_ARG1]] - %7 = spv.AccessChain %3[%1] : !spv.ptr, StorageBuffer>, i32 + %7 = spv.AccessChain %3[%1] : !spv.ptr, StorageBuffer>, i32 // CHECK-NOT: spv.FunctionCall // CHECK: spv.AtomicIAdd "Device" "AcquireRelease" [[STOREPTR]], [[VAL]] // CHECK: spv.Branch diff --git a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir @@ -1,30 +1,30 @@ // RUN: mlir-opt -decorate-spirv-composite-type-layout -split-input-file -verify-diagnostics %s -o - | FileCheck %s spv.module Logical GLSL450 { - // CHECK: spv.globalVariable @var0 bind(0, 1) : !spv.ptr [4], f32 
[12]>, Uniform> - spv.globalVariable @var0 bind(0,1) : !spv.ptr, f32>, Uniform> + // CHECK: spv.globalVariable @var0 bind(0, 1) : !spv.ptr [4], f32 [12])>, Uniform> + spv.globalVariable @var0 bind(0,1) : !spv.ptr, f32)>, Uniform> - // CHECK: spv.globalVariable @var1 bind(0, 2) : !spv.ptr [0], f32 [256]>, StorageBuffer> - spv.globalVariable @var1 bind(0,2) : !spv.ptr, f32>, StorageBuffer> + // CHECK: spv.globalVariable @var1 bind(0, 2) : !spv.ptr [0], f32 [256])>, StorageBuffer> + spv.globalVariable @var1 bind(0,2) : !spv.ptr, f32)>, StorageBuffer> - // CHECK: spv.globalVariable @var2 bind(1, 0) : !spv.ptr [0], f32 [256]> [0], i32 [260]>, StorageBuffer> - spv.globalVariable @var2 bind(1,0) : !spv.ptr, f32>, i32>, StorageBuffer> + // CHECK: spv.globalVariable @var2 bind(1, 0) : !spv.ptr [0], f32 [256])> [0], i32 [260])>, StorageBuffer> + spv.globalVariable @var2 bind(1,0) : !spv.ptr, f32)>, i32)>, StorageBuffer> - // CHECK: spv.globalVariable @var3 : !spv.ptr [8]>, stride=72> [0], f32 [1152]>, StorageBuffer> - spv.globalVariable @var3 : !spv.ptr>>, f32>, StorageBuffer> + // CHECK: spv.globalVariable @var3 : !spv.ptr [8])>, stride=72> [0], f32 [1152])>, StorageBuffer> + spv.globalVariable @var3 : !spv.ptr)>>, f32)>, StorageBuffer> - // CHECK: spv.globalVariable @var4 bind(1, 2) : !spv.ptr [0], f32 [16], i1 [20]> [0], i1 [24]>, StorageBuffer> - spv.globalVariable @var4 bind(1,2) : !spv.ptr, f32, i1>, i1>, StorageBuffer> + // CHECK: spv.globalVariable @var4 bind(1, 2) : !spv.ptr [0], f32 [16], i1 [20])> [0], i1 [24])>, StorageBuffer> + spv.globalVariable @var4 bind(1,2) : !spv.ptr, f32, i1)>, i1)>, StorageBuffer> - // CHECK: spv.globalVariable @var5 bind(1, 3) : !spv.ptr [0]>, StorageBuffer> - spv.globalVariable @var5 bind(1,3) : !spv.ptr>, StorageBuffer> + // CHECK: spv.globalVariable @var5 bind(1, 3) : !spv.ptr [0])>, StorageBuffer> + spv.globalVariable @var5 bind(1,3) : !spv.ptr)>, StorageBuffer> spv.func @kernel() -> () "None" { %c0 = spv.constant 0 : i32 - // 
CHECK: {{%.*}} = spv._address_of @var0 : !spv.ptr [4], f32 [12]>, Uniform> - %0 = spv._address_of @var0 : !spv.ptr, f32>, Uniform> - // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr [4], f32 [12]>, Uniform> - %1 = spv.AccessChain %0[%c0] : !spv.ptr, f32>, Uniform>, i32 + // CHECK: {{%.*}} = spv._address_of @var0 : !spv.ptr [4], f32 [12])>, Uniform> + %0 = spv._address_of @var0 : !spv.ptr, f32)>, Uniform> + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr [4], f32 [12])>, Uniform> + %1 = spv.AccessChain %0[%c0] : !spv.ptr, f32)>, Uniform>, i32 spv.Return } } @@ -32,68 +32,68 @@ // ----- spv.module Logical GLSL450 { - // CHECK: spv.globalVariable @var0 : !spv.ptr [0], i1 [16]> [0], i1 [24]> [0], i1 [32]> [0], i1 [40]>, Uniform> - spv.globalVariable @var0 : !spv.ptr, i1>, i1>, i1>, i1>, Uniform> + // CHECK: spv.globalVariable @var0 : !spv.ptr [0], i1 [16])> [0], i1 [24])> [0], i1 [32])> [0], i1 [40])>, Uniform> + spv.globalVariable @var0 : !spv.ptr, i1)>, i1)>, i1)>, i1)>, Uniform> - // CHECK: spv.globalVariable @var1 : !spv.ptr [8], f32 [24]> [0], f32 [32]>, Uniform> - spv.globalVariable @var1 : !spv.ptr, f32>, f32>, Uniform> + // CHECK: spv.globalVariable @var1 : !spv.ptr [8], f32 [24])> [0], f32 [32])>, Uniform> + spv.globalVariable @var1 : !spv.ptr, f32)>, f32)>, Uniform> - // CHECK: spv.globalVariable @var2 : !spv.ptr, stride=128> [8]> [8], f32 [2064]> [0], f32 [2072]>, Uniform> - spv.globalVariable @var2 : !spv.ptr>>, f32>, f32>, Uniform> + // CHECK: spv.globalVariable @var2 : !spv.ptr, stride=128> [8])> [8], f32 [2064])> [0], f32 [2072])>, Uniform> + spv.globalVariable @var2 : !spv.ptr>)>, f32)>, f32)>, Uniform> - // CHECK: spv.globalVariable @var3 : !spv.ptr [0], i1 [512]> [0], i1 [520]>, Uniform> - spv.globalVariable @var3 : !spv.ptr, i1>, i1>, Uniform> + // CHECK: spv.globalVariable @var3 : !spv.ptr [0], i1 [512])> [0], i1 [520])>, Uniform> + spv.globalVariable @var3 : !spv.ptr, i1)>, i1)>, Uniform> - // CHECK: 
spv.globalVariable @var4 : !spv.ptr [8], i1 [24]>, Uniform> - spv.globalVariable @var4 : !spv.ptr, i1>, Uniform> + // CHECK: spv.globalVariable @var4 : !spv.ptr [8], i1 [24])>, Uniform> + spv.globalVariable @var4 : !spv.ptr, i1)>, Uniform> - // CHECK: spv.globalVariable @var5 : !spv.ptr [8], i1 [24]>, Uniform> - spv.globalVariable @var5 : !spv.ptr, i1>, Uniform> + // CHECK: spv.globalVariable @var5 : !spv.ptr [8], i1 [24])>, Uniform> + spv.globalVariable @var5 : !spv.ptr, i1)>, Uniform> - // CHECK: spv.globalVariable @var6 : !spv.ptr [8], i1 [24]>, Uniform> - spv.globalVariable @var6 : !spv.ptr, i1>, Uniform> + // CHECK: spv.globalVariable @var6 : !spv.ptr [8], i1 [24])>, Uniform> + spv.globalVariable @var6 : !spv.ptr, i1)>, Uniform> - // CHECK: spv.globalVariable @var7 : !spv.ptr [0], i1 [16]> [8], i1 [32]>, Uniform> - spv.globalVariable @var7 : !spv.ptr, i1>, i1>, Uniform> + // CHECK: spv.globalVariable @var7 : !spv.ptr [0], i1 [16])> [8], i1 [32])>, Uniform> + spv.globalVariable @var7 : !spv.ptr, i1)>, i1)>, Uniform> } // ----- spv.module Logical GLSL450 { - // CHECK: spv.globalVariable @var0 : !spv.ptr [0], f32 [8]>, StorageBuffer> - spv.globalVariable @var0 : !spv.ptr, f32>, StorageBuffer> + // CHECK: spv.globalVariable @var0 : !spv.ptr [0], f32 [8])>, StorageBuffer> + spv.globalVariable @var0 : !spv.ptr, f32)>, StorageBuffer> - // CHECK: spv.globalVariable @var1 : !spv.ptr [0], f32 [12]>, StorageBuffer> - spv.globalVariable @var1 : !spv.ptr, f32>, StorageBuffer> + // CHECK: spv.globalVariable @var1 : !spv.ptr [0], f32 [12])>, StorageBuffer> + spv.globalVariable @var1 : !spv.ptr, f32)>, StorageBuffer> - // CHECK: spv.globalVariable @var2 : !spv.ptr [0], f32 [16]>, StorageBuffer> - spv.globalVariable @var2 : !spv.ptr, f32>, StorageBuffer> + // CHECK: spv.globalVariable @var2 : !spv.ptr [0], f32 [16])>, StorageBuffer> + spv.globalVariable @var2 : !spv.ptr, f32)>, StorageBuffer> } // ----- spv.module Logical GLSL450 { - // CHECK: spv.globalVariable 
@emptyStructAsMember : !spv.ptr [0]>, StorageBuffer> - spv.globalVariable @emptyStructAsMember : !spv.ptr>, StorageBuffer> + // CHECK: spv.globalVariable @emptyStructAsMember : !spv.ptr [0])>, StorageBuffer> + spv.globalVariable @emptyStructAsMember : !spv.ptr)>, StorageBuffer> // CHECK: spv.globalVariable @arrayType : !spv.ptr>, StorageBuffer> spv.globalVariable @arrayType : !spv.ptr>, StorageBuffer> - // CHECK: spv.globalVariable @InputStorage : !spv.ptr>, Input> - spv.globalVariable @InputStorage : !spv.ptr>, Input> + // CHECK: spv.globalVariable @InputStorage : !spv.ptr)>, Input> + spv.globalVariable @InputStorage : !spv.ptr)>, Input> - // CHECK: spv.globalVariable @customLayout : !spv.ptr, Uniform> - spv.globalVariable @customLayout : !spv.ptr, Uniform> + // CHECK: spv.globalVariable @customLayout : !spv.ptr, Uniform> + spv.globalVariable @customLayout : !spv.ptr, Uniform> - // CHECK: spv.globalVariable @emptyStruct : !spv.ptr, Uniform> - spv.globalVariable @emptyStruct : !spv.ptr, Uniform> + // CHECK: spv.globalVariable @emptyStruct : !spv.ptr, Uniform> + spv.globalVariable @emptyStruct : !spv.ptr, Uniform> } // ----- spv.module Logical GLSL450 { - // CHECK: spv.globalVariable @var0 : !spv.ptr, PushConstant> - spv.globalVariable @var0 : !spv.ptr, PushConstant> - // CHECK: spv.globalVariable @var1 : !spv.ptr, PhysicalStorageBuffer> - spv.globalVariable @var1 : !spv.ptr, PhysicalStorageBuffer> + // CHECK: spv.globalVariable @var0 : !spv.ptr, PushConstant> + spv.globalVariable @var0 : !spv.ptr, PushConstant> + // CHECK: spv.globalVariable @var1 : !spv.ptr, PhysicalStorageBuffer> + spv.globalVariable @var1 : !spv.ptr, PhysicalStorageBuffer> } diff --git a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir @@ -15,16 +15,16 @@ %7 = spv.CompositeInsert %value2, %6[2 : i32] : f32 
into !spv.array<4xf32> %8 = spv.CompositeInsert %value0, %7[3 : i32] : f32 into !spv.array<4xf32> - %9 = spv.undef : !spv.struct - // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : !spv.struct - %10 = spv.CompositeInsert %value0, %9[0 : i32] : f32 into !spv.struct - %11 = spv.CompositeInsert %value3, %10[1 : i32] : i32 into !spv.struct - %12 = spv.CompositeInsert %value1, %11[2 : i32] : f32 into !spv.struct + %9 = spv.undef : !spv.struct<(f32, i32, f32)> + // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : !spv.struct<(f32, i32, f32)> + %10 = spv.CompositeInsert %value0, %9[0 : i32] : f32 into !spv.struct<(f32, i32, f32)> + %11 = spv.CompositeInsert %value3, %10[1 : i32] : i32 into !spv.struct<(f32, i32, f32)> + %12 = spv.CompositeInsert %value1, %11[2 : i32] : f32 into !spv.struct<(f32, i32, f32)> - %13 = spv.undef : !spv.struct> - // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}} : !spv.struct> - %14 = spv.CompositeInsert %value0, %13[0 : i32] : f32 into !spv.struct> - %15 = spv.CompositeInsert %value4, %14[1 : i32] : !spv.array<3xf32> into !spv.struct> + %13 = spv.undef : !spv.struct<(f32, !spv.array<3xf32>)> + // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}} : !spv.struct<(f32, !spv.array<3 x f32>)> + %14 = spv.CompositeInsert %value0, %13[0 : i32] : f32 into !spv.struct<(f32, !spv.array<3xf32>)> + %15 = spv.CompositeInsert %value4, %14[1 : i32] : !spv.array<3xf32> into !spv.struct<(f32, !spv.array<3xf32>)> spv.ReturnValue %3 : vector<3xf32> } diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -180,6 +180,6 @@ #spv.vce, {}> } { - spv.globalVariable @data : !spv.ptr, Uniform> + spv.globalVariable @data : !spv.ptr, Uniform> spv.globalVariable @img : !spv.ptr, UniformConstant> } diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir 
b/mlir/test/Dialect/SPIRV/canonicalize.mlir --- a/mlir/test/Dialect/SPIRV/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir @@ -10,8 +10,8 @@ // CHECK-NEXT: %[[PTR:.*]] = spv.AccessChain %[[VAR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] // CHECK-NEXT: spv.Load "Function" %[[PTR]] %c0 = spv.constant 0: i32 - %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> - %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 + %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>)>, Function> + %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>)>, Function>, i32 %2 = spv.AccessChain %1[%c0, %c0] : !spv.ptr>, Function>, i32, i32 %3 = spv.Load "Function" %2 : f32 spv.ReturnValue %3 : f32 @@ -27,8 +27,8 @@ // CHECK-NEXT: spv.Load "Function" %[[PTR_0]] // CHECK-NEXT: spv.Load "Function" %[[PTR_1]] %c0 = spv.constant 0: i32 - %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> - %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 + %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>)>, Function> + %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>)>, Function>, i32 %2 = spv.AccessChain %1[%c0] : !spv.ptr>, Function>, i32 %3 = spv.AccessChain %2[%c0] : !spv.ptr, Function>, i32 %4 = spv.Load "Function" %2 : !spv.array<4xf32> @@ -47,10 +47,10 @@ // CHECK-NEXT: spv.Load "Function" %[[VAR_0_PTR]] // CHECK-NEXT: spv.Load "Function" %[[VAR_1_PTR]] %c1 = spv.constant 1: i32 - %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> - %1 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> - %2 = spv.AccessChain %0[%c1] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 - %3 = spv.AccessChain %1[%c1] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 + %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>)>, Function> + %1 = spv.Variable : !spv.ptr>, !spv.array<4xi32>)>, Function> + %2 = spv.AccessChain %0[%c1] : !spv.ptr>, !spv.array<4xi32>)>, Function>, i32 + %3 = spv.AccessChain %1[%c1] : !spv.ptr>, 
!spv.array<4xi32>)>, Function>, i32 %4 = spv.Load "Function" %2 : !spv.array<4xi32> %5 = spv.Load "Function" %3 : !spv.array<4xi32> spv.ReturnValue %4 : !spv.array<4xi32> diff --git a/mlir/test/Dialect/SPIRV/composite-ops.mlir b/mlir/test/Dialect/SPIRV/composite-ops.mlir --- a/mlir/test/Dialect/SPIRV/composite-ops.mlir +++ b/mlir/test/Dialect/SPIRV/composite-ops.mlir @@ -12,10 +12,10 @@ // ----- -func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spv.array<4xf32>, %arg2 : !spv.struct) -> !spv.struct, !spv.array<4xf32>, !spv.struct> { - // CHECK: spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct, !spv.array<4 x f32>, !spv.struct> - %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct, !spv.array<4xf32>, !spv.struct> - return %0: !spv.struct, !spv.array<4xf32>, !spv.struct> +func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spv.array<4xf32>, %arg2 : !spv.struct<(f32)>) -> !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)> { + // CHECK: spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<(vector<3xf32>, !spv.array<4 x f32>, !spv.struct<(f32)>)> + %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)> + return %0: !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)> } // ----- @@ -28,10 +28,10 @@ // ----- -func @composite_construct_empty_struct() -> !spv.struct<> { - // CHECK: spv.CompositeConstruct : !spv.struct<> - %0 = spv.CompositeConstruct : !spv.struct<> - return %0: !spv.struct<> +func @composite_construct_empty_struct() -> !spv.struct<()> { + // CHECK: spv.CompositeConstruct : !spv.struct<()> + %0 = spv.CompositeConstruct : !spv.struct<()> + return %0: !spv.struct<()> } // ----- @@ -80,9 +80,9 @@ // ----- -func @composite_extract_struct(%arg0 : !spv.struct>) -> f32 { - // CHECK: {{%.*}} = spv.CompositeExtract {{%.*}}[1 : i32, 2 : i32] : !spv.struct> - %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.struct> +func 
@composite_extract_struct(%arg0 : !spv.struct<(f32, !spv.array<4xf32>)>) -> f32 { + // CHECK: {{%.*}} = spv.CompositeExtract {{%.*}}[1 : i32, 2 : i32] : !spv.struct<(f32, !spv.array<4 x f32>)> + %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.struct<(f32, !spv.array<4xf32>)> return %0 : f32 } @@ -156,9 +156,9 @@ // ----- -func @composite_extract_struct_element_out_of_bounds_access(%arg0 : !spv.struct>) -> () { - // expected-error @+1 {{index 2 out of bounds for '!spv.struct>'}} - %0 = spv.CompositeExtract %arg0[2 : i32, 0 : i32] : !spv.struct> +func @composite_extract_struct_element_out_of_bounds_access(%arg0 : !spv.struct<(f32, !spv.array<4xf32>)>) -> () { + // expected-error @+1 {{index 2 out of bounds for '!spv.struct<(f32, !spv.array<4 x f32>)>'}} + %0 = spv.CompositeExtract %arg0[2 : i32, 0 : i32] : !spv.struct<(f32, !spv.array<4xf32>)> return } @@ -216,10 +216,10 @@ // ----- -func @composite_insert_struct(%arg0: !spv.struct, f32>, %arg1: !spv.array<4xf32>) -> !spv.struct, f32> { - // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : !spv.array<4 x f32> into !spv.struct, f32> - %0 = spv.CompositeInsert %arg1, %arg0[0 : i32] : !spv.array<4xf32> into !spv.struct, f32> - return %0: !spv.struct, f32> +func @composite_insert_struct(%arg0: !spv.struct<(!spv.array<4xf32>, f32)>, %arg1: !spv.array<4xf32>) -> !spv.struct<(!spv.array<4xf32>, f32)> { + // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : !spv.array<4 x f32> into !spv.struct<(!spv.array<4 x f32>, f32)> + %0 = spv.CompositeInsert %arg1, %arg0[0 : i32] : !spv.array<4xf32> into !spv.struct<(!spv.array<4xf32>, f32)> + return %0: !spv.struct<(!spv.array<4xf32>, f32)> } // ----- diff --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir --- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir +++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir @@ -143,9 +143,9 @@ // ----- -spv.func @cooperative_matrix_load_memaccess(%ptr : 
!spv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { +spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { // expected-error @+1 {{Pointer must point to a scalar or vector type}} - %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup> + %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup> spv.Return } diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -6,9 +6,9 @@ func @access_chain_struct() -> () { %0 = spv.constant 1: i32 - %1 = spv.Variable : !spv.ptr>, Function> - // CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr>, Function> - %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function>, i32, i32 + %1 = spv.Variable : !spv.ptr)>, Function> + // CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr)>, Function> + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr)>, Function>, i32, i32 return } @@ -111,9 +111,9 @@ // ----- func @access_chain_invalid_index_2(%index0 : i32) -> () { - %0 = spv.Variable : !spv.ptr>, Function> + %0 = spv.Variable : !spv.ptr)>, Function> // expected-error @+1 {{index must be an integer spv.constant to access element of spv.struct}} - %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr>, Function>, i32, i32 + %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr)>, Function>, i32, i32 return } @@ -121,9 +121,9 @@ func @access_chain_invalid_constant_type_1() -> () { %0 = std.constant 1: i32 - %1 = spv.Variable : !spv.ptr>, Function> + %1 = spv.Variable : !spv.ptr)>, Function> // expected-error @+1 {{index must be an integer spv.constant to access element of spv.struct, but provided std.constant}} - %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function>, i32, i32 + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr)>, Function>, i32, i32 return } @@ 
-131,9 +131,9 @@ func @access_chain_out_of_bounds() -> () { %index0 = "spv.constant"() { value = 12: i32} : () -> i32 - %0 = spv.Variable : !spv.ptr>, Function> - // expected-error @+1 {{'spv.AccessChain' op index 12 out of bounds for '!spv.struct>'}} - %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr>, Function>, i32, i32 + %0 = spv.Variable : !spv.ptr)>, Function> + // expected-error @+1 {{'spv.AccessChain' op index 12 out of bounds for '!spv.struct<(f32, !spv.array<4 x f32>)>'}} + %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr)>, Function>, i32, i32 return } diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -5,13 +5,13 @@ //===----------------------------------------------------------------------===// spv.module Logical GLSL450 { - spv.globalVariable @var1 : !spv.ptr>, Input> + spv.globalVariable @var1 : !spv.ptr)>, Input> spv.func @access_chain() -> () "None" { %0 = spv.constant 1: i32 - // CHECK: [[VAR1:%.*]] = spv._address_of @var1 : !spv.ptr>, Input> - // CHECK-NEXT: spv.AccessChain [[VAR1]][{{.*}}, {{.*}}] : !spv.ptr>, Input> - %1 = spv._address_of @var1 : !spv.ptr>, Input> - %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Input>, i32, i32 + // CHECK: [[VAR1:%.*]] = spv._address_of @var1 : !spv.ptr)>, Input> + // CHECK-NEXT: spv.AccessChain [[VAR1]][{{.*}}, {{.*}}] : !spv.ptr)>, Input> + %1 = spv._address_of @var1 : !spv.ptr)>, Input> + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr)>, Input>, i32, i32 spv.Return } } @@ -19,27 +19,27 @@ // ----- // Allow taking address of global variables in other module-like ops -spv.globalVariable @var : !spv.ptr>, Input> +spv.globalVariable @var : !spv.ptr)>, Input> func @address_of() -> () { // CHECK: spv._address_of @var - %1 = spv._address_of @var : !spv.ptr>, Input> + %1 = spv._address_of @var : !spv.ptr)>, Input> return } // ----- spv.module Logical GLSL450 { - 
spv.globalVariable @var1 : !spv.ptr>, Input> + spv.globalVariable @var1 : !spv.ptr)>, Input> spv.func @foo() -> () "None" { // expected-error @+1 {{expected spv.globalVariable symbol}} - %0 = spv._address_of @var2 : !spv.ptr>, Input> + %0 = spv._address_of @var2 : !spv.ptr)>, Input> } } // ----- spv.module Logical GLSL450 { - spv.globalVariable @var1 : !spv.ptr>, Input> + spv.globalVariable @var1 : !spv.ptr)>, Input> spv.func @foo() -> () "None" { // expected-error @+1 {{result type mismatch with the referenced global variable's type}} %0 = spv._address_of @var1 : !spv.ptr @@ -653,8 +653,8 @@ spv.specConstant @sc1 = 1 : i32 spv.specConstant @sc2 = 2.5 : f32 spv.specConstant @sc3 = 3.5 : f32 - // CHECK: spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct - spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct + // CHECK: spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<(i32, f32, f32)> + spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<(i32, f32, f32)> } // ----- @@ -664,7 +664,7 @@ spv.specConstant @sc2 = 2.5 : f32 spv.specConstant @sc3 = 3.5 : f32 // expected-error @+1 {{has incorrect number of operands: expected 2, but provided 3}} - spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct + spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<(i32, f32)> } // ----- @@ -674,7 +674,7 @@ spv.specConstant @sc2 = 2.5 : f32 spv.specConstant @sc3 = 3.5 : f32 // expected-error @+1 {{has incorrect types of operands: expected 'i32', but provided 'f32'}} - spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct + spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<(i32, f32, f32)> } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir --- a/mlir/test/Dialect/SPIRV/types.mlir +++ b/mlir/test/Dialect/SPIRV/types.mlir @@ -230,119 +230,194 @@ // StructType 
//===----------------------------------------------------------------------===// -// CHECK: func @struct_type(!spv.struct) -func @struct_type(!spv.struct) -> () +// CHECK: func @struct_type(!spv.struct<(f32)>) +func @struct_type(!spv.struct<(f32)>) -> () -// CHECK: func @struct_type2(!spv.struct) -func @struct_type2(!spv.struct) -> () +// CHECK: func @struct_type2(!spv.struct<(f32 [0])>) +func @struct_type2(!spv.struct<(f32 [0])>) -> () -// CHECK: func @struct_type_simple(!spv.struct>) -func @struct_type_simple(!spv.struct>) -> () +// CHECK: func @struct_type_simple(!spv.struct<(f32, !spv.image)>) +func @struct_type_simple(!spv.struct<(f32, !spv.image)>) -> () -// CHECK: func @struct_type_with_offset(!spv.struct) -func @struct_type_with_offset(!spv.struct) -> () +// CHECK: func @struct_type_with_offset(!spv.struct<(f32 [0], i32 [4])>) +func @struct_type_with_offset(!spv.struct<(f32 [0], i32 [4])>) -> () -// CHECK: func @nested_struct(!spv.struct>) -func @nested_struct(!spv.struct>) +// CHECK: func @nested_struct(!spv.struct<(f32, !spv.struct<(f32, i32)>)>) +func @nested_struct(!spv.struct<(f32, !spv.struct<(f32, i32)>)>) -// CHECK: func @nested_struct_with_offset(!spv.struct [4]>) -func @nested_struct_with_offset(!spv.struct [4]>) +// CHECK: func @nested_struct_with_offset(!spv.struct<(f32 [0], !spv.struct<(f32 [0], i32 [4])> [4])>) +func @nested_struct_with_offset(!spv.struct<(f32 [0], !spv.struct<(f32 [0], i32 [4])> [4])>) -// CHECK: func @struct_type_with_decoration(!spv.struct) -func @struct_type_with_decoration(!spv.struct) +// CHECK: func @struct_type_with_decoration(!spv.struct<(f32 [NonWritable])>) +func @struct_type_with_decoration(!spv.struct<(f32 [NonWritable])>) -// CHECK: func @struct_type_with_decoration_and_offset(!spv.struct) -func @struct_type_with_decoration_and_offset(!spv.struct) +// CHECK: func @struct_type_with_decoration_and_offset(!spv.struct<(f32 [0, NonWritable])>) +func @struct_type_with_decoration_and_offset(!spv.struct<(f32 [0, 
NonWritable])>) -// CHECK: func @struct_type_with_decoration2(!spv.struct) -func @struct_type_with_decoration2(!spv.struct) +// CHECK: func @struct_type_with_decoration2(!spv.struct<(f32 [NonWritable], i32 [NonReadable])>) +func @struct_type_with_decoration2(!spv.struct<(f32 [NonWritable], i32 [NonReadable])>) -// CHECK: func @struct_type_with_decoration3(!spv.struct) -func @struct_type_with_decoration3(!spv.struct) +// CHECK: func @struct_type_with_decoration3(!spv.struct<(f32, i32 [NonReadable])>) +func @struct_type_with_decoration3(!spv.struct<(f32, i32 [NonReadable])>) -// CHECK: func @struct_type_with_decoration4(!spv.struct) -func @struct_type_with_decoration4(!spv.struct) +// CHECK: func @struct_type_with_decoration4(!spv.struct<(f32 [0], i32 [4, NonReadable])>) +func @struct_type_with_decoration4(!spv.struct<(f32 [0], i32 [4, NonReadable])>) -// CHECK: func @struct_type_with_decoration5(!spv.struct) -func @struct_type_with_decoration5(!spv.struct) +// CHECK: func @struct_type_with_decoration5(!spv.struct<(f32 [NonWritable, NonReadable])>) +func @struct_type_with_decoration5(!spv.struct<(f32 [NonWritable, NonReadable])>) -// CHECK: func @struct_type_with_decoration6(!spv.struct>) -func @struct_type_with_decoration6(!spv.struct>) +// CHECK: func @struct_type_with_decoration6(!spv.struct<(f32, !spv.struct<(i32 [NonWritable, NonReadable])>)>) +func @struct_type_with_decoration6(!spv.struct<(f32, !spv.struct<(i32 [NonWritable, NonReadable])>)>) -// CHECK: func @struct_type_with_decoration7(!spv.struct [4]>) -func @struct_type_with_decoration7(!spv.struct [4]>) +// CHECK: func @struct_type_with_decoration7(!spv.struct<(f32 [0], !spv.struct<(i32, f32 [NonReadable])> [4])>) +func @struct_type_with_decoration7(!spv.struct<(f32 [0], !spv.struct<(i32, f32 [NonReadable])> [4])>) -// CHECK: func @struct_type_with_decoration8(!spv.struct>) -func @struct_type_with_decoration8(!spv.struct>) +// CHECK: func @struct_type_with_decoration8(!spv.struct<(f32, !spv.struct<(i32 
[0], f32 [4, NonReadable])>)>) +func @struct_type_with_decoration8(!spv.struct<(f32, !spv.struct<(i32 [0], f32 [4, NonReadable])>)>) -// CHECK: func @struct_type_with_matrix_1(!spv.struct> [0, ColMajor, MatrixStride=16]>) -func @struct_type_with_matrix_1(!spv.struct> [0, ColMajor, MatrixStride=16]>) +// CHECK: func @struct_type_with_matrix_1(!spv.struct<(!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>) +func @struct_type_with_matrix_1(!spv.struct<(!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>) -// CHECK: func @struct_type_with_matrix_2(!spv.struct> [0, RowMajor, MatrixStride=16]>) -func @struct_type_with_matrix_2(!spv.struct> [0, RowMajor, MatrixStride=16]>) +// CHECK: func @struct_type_with_matrix_2(!spv.struct<(!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=16])>) +func @struct_type_with_matrix_2(!spv.struct<(!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=16])>) -// CHECK: func @struct_empty(!spv.struct<>) -func @struct_empty(!spv.struct<>) +// CHECK: func @struct_empty(!spv.struct<()>) +func @struct_empty(!spv.struct<()>) // ----- // expected-error @+1 {{offset specification must be given for all members}} -func @struct_type_missing_offset1((!spv.struct) -> () +func @struct_type_missing_offset1((!spv.struct<(f32, i32 [4])>) -> () // ----- // expected-error @+1 {{offset specification must be given for all members}} -func @struct_type_missing_offset2(!spv.struct) -> () +func @struct_type_missing_offset2(!spv.struct<(f32 [3], i32)>) -> () // ----- -// expected-error @+1 {{expected '>'}} -func @struct_type_missing_comma1(!spv.struct) -> () +// expected-error @+1 {{expected ')'}} +func @struct_type_missing_comma1(!spv.struct<(f32 i32)>) -> () // ----- -// expected-error @+1 {{expected '>'}} -func @struct_type_missing_comma2(!spv.struct) -> () +// expected-error @+1 {{expected ')'}} +func @struct_type_missing_comma2(!spv.struct<(f32 [0] i32)>) -> () // ----- -// expected-error @+1 {{unbalanced '>' character in 
pretty dialect name}} -func @struct_type_neg_offset(!spv.struct) -> () +// expected-error @+1 {{unbalanced ')' character in pretty dialect name}} +func @struct_type_neg_offset(!spv.struct<(f32 [0)>) -> () // ----- // expected-error @+1 {{unbalanced ']' character in pretty dialect name}} -func @struct_type_neg_offset(!spv.struct) -> () +func @struct_type_neg_offset(!spv.struct<(f32 0])>) -> () // ----- // expected-error @+1 {{expected ']'}} -func @struct_type_neg_offset(!spv.struct) -> () +func @struct_type_neg_offset(!spv.struct<(f32 [NonWritable 0])>) -> () // ----- // expected-error @+1 {{expected valid keyword}} -func @struct_type_neg_offset(!spv.struct) -> () +func @struct_type_neg_offset(!spv.struct<(f32 [NonWritable, 0])>) -> () // ----- // expected-error @+1 {{expected ','}} -func @struct_type_missing_comma(!spv.struct) +func @struct_type_missing_comma(!spv.struct<(f32 [0 NonWritable], i32 [4])>) // ----- // expected-error @+1 {{expected ']'}} -func @struct_type_missing_comma(!spv.struct) +func @struct_type_missing_comma(!spv.struct<(f32 [0, NonWritable NonReadable], i32 [4])>) // ----- // expected-error @+1 {{expected ']'}} -func @struct_type_missing_comma(!spv.struct> [0, RowMajor MatrixStride=16]>) +func @struct_type_missing_comma(!spv.struct<(!spv.matrix<3 x vector<3xf32>> [0, RowMajor MatrixStride=16])>) // ----- // expected-error @+1 {{expected integer value}} -func @struct_missing_member_decorator_value(!spv.struct> [0, RowMajor, MatrixStride=]>) +func @struct_missing_member_decorator_value(!spv.struct<(!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=])>) + +// ----- + +//===----------------------------------------------------------------------===// +// StructType (identified) +//===----------------------------------------------------------------------===// + +// CHECK: func @id_struct_empty(!spv.struct) +func @id_struct_empty(!spv.struct) -> () + +// ----- + +// CHECK: func @id_struct_simple(!spv.struct) +func @id_struct_simple(!spv.struct) 
-> () + +// ----- + +// CHECK: func @id_struct_multiple_elements(!spv.struct) +func @id_struct_multiple_elements(!spv.struct) -> () + +// ----- + +// CHECK: func @id_struct_nested_literal(!spv.struct)>) +func @id_struct_nested_literal(!spv.struct)>) -> () + +// ----- + +// CHECK: func @id_struct_nested_id(!spv.struct)>) +func @id_struct_nested_id(!spv.struct)>) -> () + +// ----- + +// CHECK: func @literal_struct_nested_id(!spv.struct<(!spv.struct)>) +func @literal_struct_nested_id(!spv.struct<(!spv.struct)>) -> () + +// ----- + +// CHECK: func @id_struct_self_recursive(!spv.struct, Uniform>)>) +func @id_struct_self_recursive(!spv.struct, Uniform>)>) -> () + +// ----- + +// CHECK: func @id_struct_self_recursive2(!spv.struct, Uniform>)>) +func @id_struct_self_recursive2(!spv.struct, Uniform>)>) -> () + +// ----- + +// expected-error @+1 {{recursive struct reference not nested in struct definition}} +func @id_wrong_recursive_reference(!spv.struct) -> () + +// ----- + +// expected-error @+1 {{recursive struct reference not nested in struct definition}} +func @id_struct_recursive_invalid(!spv.struct, Uniform>)>) -> () + +// ----- + +// expected-error @+1 {{identifier already used for an enclosing struct}} +func @id_struct_redefinition(!spv.struct, Uniform>)>, Uniform>)>) -> () + +// ----- + +// Equivalent to: +// struct a { struct b *bPtr; }; +// struct b { struct a *aPtr; }; +// CHECK: func @id_struct_recursive(!spv.struct, Uniform>)>, Uniform>)>) +func @id_struct_recursive(!spv.struct, Uniform>)>, Uniform>)>) -> () + +// ----- + +// Equivalent to: +// struct a { struct b *bPtr; }; +// struct b { struct a *aPtr, struct b *bPtr; }; +// CHECK: func @id_struct_recursive(!spv.struct, Uniform>, !spv.ptr, Uniform>)>, Uniform>)>) +func @id_struct_recursive(!spv.struct, Uniform>, !spv.ptr, Uniform>)>, Uniform>)>) -> () // ----- @@ -446,4 +521,4 @@ // expected-error @+1 {{expected single unsigned integer for number of columns}} func @matrix_size_type(!spv.matrix<2.0 x 
vector<3xi32>>) -> () -// ----- \ No newline at end of file +// -----