diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3136,6 +3136,7 @@ def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>; def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>; def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>; +def SPV_OC_OpTypeForwardPointer : I32EnumAttrCase<"OpTypeForwardPointer", 39>; def SPV_OC_OpConstantTrue : I32EnumAttrCase<"OpConstantTrue", 41>; def SPV_OC_OpConstantFalse : I32EnumAttrCase<"OpConstantFalse", 42>; def SPV_OC_OpConstant : I32EnumAttrCase<"OpConstant", 43>; @@ -3282,21 +3283,21 @@ SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix, SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, - SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, - SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite, - SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, - SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, - SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, - SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory, - SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, - SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract, - SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU, - SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, - SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, - SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, - SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, - SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, - SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar, 
+ SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpTypeForwardPointer, + SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant, + SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, + SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, + SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, + SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, + SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, + SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct, + SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, + SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, + SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, + SPV_OC_OpBitcast, SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, + SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, + SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, + SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -302,8 +302,17 @@ ArrayRef offsetInfo = {}, ArrayRef memberDecorations = {}); - /// Construct a struct with no members. - static StructType getEmpty(MLIRContext *context); + /// Lookup an identified struct. + static StructType lookupIdentified(MLIRContext *context, + StringRef identifier); + + /// Construct an identified struct. 
+ static StructType getIdentified(MLIRContext *context, StringRef identifier); + + /// Construct a (possibly identified) struct with no members. + static StructType getEmpty(MLIRContext *context, StringRef identifier = ""); + + StringRef getIdentifier() const; unsigned getNumElements() const; @@ -346,6 +355,10 @@ SmallVectorImpl &decorationsInfo) const; + LogicalResult + trySetBody(ArrayRef memberTypes, ArrayRef offsetInfo = {}, + ArrayRef memberDecorations = {}); + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional storage = llvm::None); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h --- a/mlir/include/mlir/IR/DialectImplementation.h +++ b/mlir/include/mlir/IR/DialectImplementation.h @@ -15,6 +15,7 @@ #define MLIR_IR_DIALECTIMPLEMENTATION_H #include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/SMLoc.h" #include "llvm/Support/raw_ostream.h" @@ -47,6 +48,8 @@ /// Print the given type to the stream. virtual void printType(Type type) = 0; + virtual llvm::SetVector &getStructContext() = 0; + private: DialectAsmPrinter(const DialectAsmPrinter &) = delete; void operator=(const DialectAsmPrinter &) = delete; @@ -135,7 +138,8 @@ virtual ParseResult parseFloat(double &result) = 0; /// Parse an integer value from the stream. - template ParseResult parseInteger(IntT &result) { + template + ParseResult parseInteger(IntT &result) { auto loc = getCurrentLocation(); OptionalParseResult parseResult = parseOptionalInteger(result); if (!parseResult.hasValue()) @@ -311,7 +315,8 @@ virtual ParseResult parseType(Type &result) = 0; /// Parse a type of a specific kind, e.g. a FunctionType. - template ParseResult parseType(TypeType &result) { + template + ParseResult parseType(TypeType &result) { llvm::SMLoc loc = getCurrentLocation(); // Parse any kind of type. 
@@ -341,6 +346,8 @@ /// static-dimension-list ::= (integer `x`)* virtual ParseResult parseDimensionList(SmallVectorImpl &dimensions, bool allowDynamic = true) = 0; + + virtual llvm::SetVector &getStructContext() = 0; }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -11,6 +11,7 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/TypeID.h" +#include "llvm/ADT/SetVector.h" #include #include #include @@ -156,6 +157,8 @@ /// instances. This should not be used directly. StorageUniquer &getAttributeUniquer(); + llvm::SetVector &getStructContext(); + /// These APIs are tracking whether the context will be used in a /// multithreading environment: this has no effect other than enabling /// assertions on misuses of some APIs. diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -82,6 +82,14 @@ return detail::InterfaceMap::template get...>(); } + template + static ConcreteT lookup(MLIRContext *ctx, Args... args) { + // Ensure that the invariants are correct for construction. + assert(succeeded(ConcreteT::verifyConstructionInvariants( + generateUnknownStorageLocation(ctx), args...))); + return UniquerT::template lookup(ctx, args...); + } + /// Get or create a new ConcreteT instance within the ctx. This /// function is guaranteed to return a non null object and will assert if /// the arguments provided are invalid. diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -121,6 +121,12 @@ /// A utility class to get, or create, unique instances of types within an /// MLIRContext. This class manages all creation and uniquing of types. 
struct TypeUniquer { + template + static T lookup(MLIRContext *ctx, Args &&... args) { + return ctx->getTypeUniquer().lookup( + T::getTypeID(), std::forward(args)...); + } + /// Get an uniqued instance of a parametric type T. template static typename std::enable_if_t< diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -162,6 +162,24 @@ registerSingletonStorageType(TypeID::get(), initFn); } + template + Storage *lookup(const TypeID &id, Arg &&arg, Args &&... args) { + // Construct a value of the derived key type. + auto derivedKey = + getKey(std::forward(arg), std::forward(args)...); + + // Create a hash of the derived key. + unsigned hashValue = getHash(derivedKey); + + // Generate an equality function for the derived storage. + auto isEqual = [&derivedKey](const BaseStorage *existing) { + return static_cast(*existing) == derivedKey; + }; + + // Get an instance for the derived storage. + return static_cast(lookupImpl(id, hashValue, isEqual)); + } + /// Gets a uniqued instance of 'Storage'. 'id' is the type id used when /// registering the storage instance. 'initFn' is an optional parameter that /// can be used to initialize a newly inserted storage instance. This function @@ -244,6 +262,9 @@ } private: + BaseStorage *lookupImpl(const TypeID &id, unsigned hashValue, + function_ref isEqual); + /// Implementation for getting/creating an instance of a derived type with /// parametric storage. 
BaseStorage *getParametricStorageTypeImpl( diff --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp --- a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp +++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp @@ -67,7 +67,18 @@ size = llvm::alignTo(structMemberOffset, maxMemberAlignment); alignment = maxMemberAlignment; structType.getMemberDecorations(memberDecorations); - return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations); + + if (structType.getIdentifier().empty()) + return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations); + + // Identified structs are uniqued by their identifier, so it is not possible + // to create a second struct with the same name but different decorations. + // TODO: Support this, e.g. by deriving a new, unique identifier. + // + // Return a null type instead of falling off the end of this non-void + // function, which is undefined behavior when assertions are disabled. + assert(false && "Identified structs are not supported."); + return {}; } Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size, diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -589,15 +589,70 @@ } // struct-member-decoration ::= integer-literal? spirv-decoration* -// struct-type ::= `!spv.struct<` spirv-type (`[` struct-member-decoration `]`)? -// (`, ` spirv-type (`[` struct-member-decoration `]`)? `>` +// struct-type ::= `!spv.struct<` (identifier `,`)? `(` (struct-member +// (`, ` struct-member)*)? `)>` | `!spv.struct<` identifier `>`
+// struct-member ::= spirv-type (`[` struct-member-decoration `]`)? static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser) { if (parser.parseLess()) return Type(); - if (succeeded(parser.parseOptionalGreater())) - return StructType::getEmpty(dialect.getContext()); + StringRef identifier; + + // Check if this is an identified struct. + if (succeeded(parser.parseOptionalKeyword(&identifier))) { + // Check if this is a possible recursive reference. + if (succeeded(parser.parseOptionalGreater())) { + if (parser.getStructContext().count(identifier) == 0) { + parser.emitError( + parser.getNameLoc(), + "recursive struct reference not nested in struct definition"); + + return Type(); + } + + StructType lookupResult = + StructType::lookupIdentified(dialect.getContext(), identifier); + + return lookupResult; + } + + if (parser.parseComma()) + return Type(); + } + + if (parser.parseLParen()) + return Type(); + + if (!identifier.empty() && !parser.getStructContext().insert(identifier)) { + parser.emitError(parser.getNameLoc(), + "identifier already used for an enclosing struct"); + + return Type(); + } + + // From here on, `identifier` (if any) belongs to this struct definition. + // Pop it from the enclosing-struct context on every exit path, including + // parse errors, so a failed parse cannot leave a stale (possibly dangling) + // StringRef behind that would reject later, valid uses of the name. + struct StructContextGuard { + DialectAsmParser &parser; + StringRef identifier; + ~StructContextGuard() { + if (!identifier.empty()) + parser.getStructContext().remove(identifier); + } + } structContextGuard{parser, identifier}; + + if (succeeded(parser.parseOptionalRParen())) { + if (parser.parseGreater()) + return Type(); + return StructType::getEmpty(dialect.getContext(), identifier); + } + + StructType idStructTy; + if (!identifier.empty()) + idStructTy = StructType::getIdentified(dialect.getContext(), identifier); SmallVector memberTypes; SmallVector offsetInfo; @@ -622,8 +677,18 @@ "offset specification must be given for all members"); return Type(); } - if (parser.parseGreater()) + + if (parser.parseRParen() || parser.parseGreater()) return Type(); + + if (!identifier.empty()) { + // The result is intentionally ignored: the printer emits the full body at + // every top-level use, so the same definition may be parsed repeatedly. + // TODO: Diagnose conflicting redefinitions of the same identifier. + (void)idStructTy.trySetBody(memberTypes, offsetInfo, memberDecorationInfo); + return idStructTy; + } + return StructType::get(memberTypes, offsetInfo, memberDecorationInfo); } @@ -690,6 +755,21 @@ static void
print(StructType type, DialectAsmPrinter &os) { os << "struct<"; + + if (!type.getIdentifier().empty()) { + os << type.getIdentifier(); + + if (os.getStructContext().count(type.getIdentifier()) == 0) { + os << ", "; + os.getStructContext().insert(type.getIdentifier()); + } else { + os << ">"; + return; + } + } + + os << "("; + auto printMember = [&](unsigned i) { os << type.getElementType(i); SmallVector decorations; @@ -713,7 +781,11 @@ }; llvm::interleaveComma(llvm::seq(0, type.getNumElements()), os, printMember); - os << ">"; + os << ")>"; + + if (!type.getIdentifier().empty()) { + os.getStructContext().remove(type.getIdentifier()); + } } static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) { diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -760,24 +760,47 @@ //===----------------------------------------------------------------------===// struct spirv::detail::StructTypeStorage : public TypeStorage { + /// Constructs an identified struct; its body is set later by trySetBody(). + StructTypeStorage(StringRef identifier, TypeStorageAllocator &allocator) + : memberTypes(nullptr), offsetInfo(nullptr), numMembers(0), + numMemberDecorations(0), memberDecorationsInfo(nullptr), + isBodySet(false), identifier(identifier), allocator(&allocator) {} + + /// Constructs a literal struct; its body is fixed at construction time. StructTypeStorage( unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo) : memberTypes(memberTypes), offsetInfo(layoutInfo), numMembers(numMembers), numMemberDecorations(numMemberDecorations), - memberDecorationsInfo(memberDecorationsInfo) {} + memberDecorationsInfo(memberDecorationsInfo), isBodySet(true), + identifier(StringRef()), allocator(nullptr) {} + + using KeyTy = + std::tuple, ArrayRef, + ArrayRef>; - using KeyTy = std::tuple, ArrayRef, - ArrayRef>; bool operator==(const KeyTy &key) const { - return key == - KeyTy(getMemberTypes(),
getOffsetInfo(), getMemberDecorationsInfo()); + if (isIdentified()) + // Identified types are uniqued by their identifier. + return getIdentifier() == std::get<0>(key); + else + return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(), + getMemberDecorationsInfo()); } static StructTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { - ArrayRef keyTypes = std::get<0>(key); + StringRef keyIdentifier = std::get<0>(key); + + if (!keyIdentifier.empty()) { + StringRef identifier = allocator.copyInto(keyIdentifier); + + return new (allocator.allocate()) + StructTypeStorage(identifier, allocator); + } + + ArrayRef keyTypes = std::get<1>(key); // Copy the member type and layout information into the bump pointer const Type *typesList = nullptr; @@ -786,8 +809,8 @@ } const StructType::OffsetInfo *offsetInfoList = nullptr; - if (!std::get<1>(key).empty()) { - ArrayRef keyOffsetInfo = std::get<1>(key); + if (!std::get<2>(key).empty()) { + ArrayRef keyOffsetInfo = std::get<2>(key); assert(keyOffsetInfo.size() == keyTypes.size() && "size of offset information must be same as the size of number of " "elements"); @@ -796,11 +819,12 @@ const StructType::MemberDecorationInfo *memberDecorationList = nullptr; unsigned numMemberDecorations = 0; - if (!std::get<2>(key).empty()) { - auto keyMemberDecorations = std::get<2>(key); + if (!std::get<3>(key).empty()) { + auto keyMemberDecorations = std::get<3>(key); numMemberDecorations = keyMemberDecorations.size(); memberDecorationList = allocator.copyInto(keyMemberDecorations).data(); } + return new (allocator.allocate()) StructTypeStorage(keyTypes.size(), typesList, offsetInfoList, numMemberDecorations, memberDecorationList); @@ -825,11 +849,57 @@ return {}; } + StringRef getIdentifier() const { return identifier; } + + bool isIdentified() const { return !identifier.empty(); } + + bool hasBody() const { return isBodySet; } + + LogicalResult + trySetBody(ArrayRef memberTypes, + ArrayRef
offsetInfo, + ArrayRef memberDecorations) { + // Only identified structs have a settable body, and it can be set only + // once. Setting an identical body again is accepted so that the same + // definition can be processed repeatedly (e.g. when re-parsing printed IR). + if (!isIdentified() || isBodySet) + return success(isBodySet && getMemberTypes() == memberTypes && + getOffsetInfo() == offsetInfo && + getMemberDecorationsInfo() == memberDecorations); + + isBodySet = true; + numMembers = memberTypes.size(); + + // Copy the member type and layout information into the bump pointer + if (!memberTypes.empty()) { + this->memberTypes = allocator->copyInto(memberTypes).data(); + } + + if (!offsetInfo.empty()) { + assert(offsetInfo.size() == memberTypes.size() && + "size of offset information must be same as the size of number of " + "elements"); + this->offsetInfo = allocator->copyInto(offsetInfo).data(); + } + + if (!memberDecorations.empty()) { + this->numMemberDecorations = memberDecorations.size(); + this->memberDecorationsInfo = + allocator->copyInto(memberDecorations).data(); + } + + return success(); + } + Type const *memberTypes; StructType::OffsetInfo const *offsetInfo; unsigned numMembers; unsigned numMemberDecorations; StructType::MemberDecorationInfo const *memberDecorationsInfo; + + bool isBodySet; + StringRef identifier; + TypeStorageAllocator *allocator; }; StructType @@ -841,16 +911,40 @@ SmallVector sortedDecorations( memberDecorations.begin(), memberDecorations.end()); llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end()); - return Base::get(memberTypes.vec().front().getContext(), memberTypes, - offsetInfo, sortedDecorations); + return Base::get(memberTypes.vec().front().getContext(), StringRef(), + memberTypes, offsetInfo, sortedDecorations); } -StructType StructType::getEmpty(MLIRContext *context) { - return Base::get(context, ArrayRef(), +StructType StructType::lookupIdentified(MLIRContext *context, + StringRef identifier) { + assert(!identifier.empty() && "Struct identifier must be non-empty string"); + + return Base::lookup(context, identifier, ArrayRef(), + ArrayRef(), + ArrayRef()); +} + +StructType StructType::getIdentified(MLIRContext *context, + StringRef identifier) { + assert(!identifier.empty() && "Struct identifier must be non-empty string"); + + return Base::get(context,
identifier, ArrayRef(), ArrayRef(), ArrayRef()); } +StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) { + StructType newStructType = Base::get( + context, identifier, ArrayRef(), ArrayRef(), + ArrayRef()); + // Set an empty body in case this is a identified struct. + newStructType.trySetBody(ArrayRef(), ArrayRef(), + ArrayRef()); + return newStructType; +} + +StringRef StructType::getIdentifier() const { return getImpl()->identifier; } + unsigned StructType::getNumElements() const { return getImpl()->numMembers; } Type StructType::getElementType(unsigned index) const { @@ -895,6 +983,13 @@ } } +LogicalResult +StructType::trySetBody(ArrayRef memberTypes, + ArrayRef offsetInfo, + ArrayRef memberDecorations) { + return getImpl()->trySetBody(memberTypes, offsetInfo, memberDecorations); +} + void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional storage) { for (Type elementType : getElementTypes()) diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -87,6 +87,44 @@ /// Map from a selection/loop's header block to its merge (and continue) target. using BlockMergeInfoMap = DenseMap; +/// A "deferred struct type" is a struct type with one or more member types not +/// known when the Deserializer first encounters the struct. This happens, for +/// example, with recursive structs where a pointer to the struct type is +/// forward declared through OpTypeForwardPointer in the SPIR-V module before +/// the struct declaration; the actual pointer to struct type should be defined +/// later through an OpTypePointer. 
For example, the following C struct: +/// +/// struct A { +/// A* next; +/// }; +/// +/// would be represented in the SPIR-V module as: +/// +/// OpName %A "A" +/// OpTypeForwardPointer %APtr Generic +/// %A = OpTypeStruct %APtr +/// %APtr = OpTypePointer Generic %A +/// +/// This means that the spirv::StructType cannot be fully constructed directly +/// when the Deserializer encounters it. Instead we create a +/// DeferredStructTypeInfo that contains all the information we know about the +/// spirv::StructType. Once all forward references for the struct are resolved, +/// the struct's body is set with all member info. +struct DeferredStructTypeInfo { + // The ID of the deferred struct type. + uint32_t structID; + + // A list of all unresolved member types for the struct. First element of each + // item is operand ID, second element is member index in the struct. + SmallVector, 0> unresolvedMemberTypes; + + // The list of member types. For unresolved members, this list contains + // null placeholder types that processOpTypePointer() fills in later. + SmallVector memberTypes; + SmallVector offsetInfo; + SmallVector memberDecorationsInfo; +}; + /// A SPIR-V module serializer. /// /// A SPIR-V binary module is a single linear stream of instructions; each @@ -219,6 +257,8 @@ /// registers the type into `module`. LogicalResult processType(spirv::Opcode opcode, ArrayRef operands); + LogicalResult processOpTypePointer(ArrayRef operands); + LogicalResult processArrayType(ArrayRef operands); LogicalResult processCooperativeMatrixType(ArrayRef operands); @@ -380,6 +420,8 @@ /// insertion point. LogicalResult processUndef(ArrayRef operands); + LogicalResult processTypeForwardPointer(ArrayRef operands); + /// Method to dispatch to the specialized deserialization function for an /// operation in SPIR-V dialect that is a mirror of an instruction in the /// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for @@ -518,6 +560,13 @@ // processed.
SmallVector>, 4> deferredInstructions; + + // A list of IDs for all types forward-declared through OpTypeForwardPointer + // instructions. + llvm::SetVector typeForwardPointerIDs; + + // A list of all structs which have unresolved member types. + llvm::SmallVector deferredStructTypesInfos; }; } // namespace @@ -1161,16 +1210,7 @@ typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy); } break; case spirv::Opcode::OpTypePointer: { - if (operands.size() != 3) { - return emitError(unknownLoc, "OpTypePointer must have two parameters"); - } - auto pointeeType = getType(operands[2]); - if (!pointeeType) { - return emitError(unknownLoc, "unknown OpTypePointer pointee type ") - << operands[2]; - } - auto storageClass = static_cast(operands[1]); - typeMap[operands[0]] = spirv::PointerType::get(pointeeType, storageClass); + return processOpTypePointer(operands); } break; case spirv::Opcode::OpTypeArray: return processArrayType(operands); @@ -1190,6 +1230,62 @@ return success(); } +LogicalResult Deserializer::processOpTypePointer(ArrayRef operands) { + if (operands.size() != 3) { + return emitError(unknownLoc, "OpTypePointer must have two parameters"); + } + + auto pointeeType = getType(operands[2]); + + if (!pointeeType) { + return emitError(unknownLoc, "unknown OpTypePointer pointee type ") + << operands[2]; + } + + uint32_t typePointerID = operands[0]; + auto storageClass = static_cast(operands[1]); + typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass); + + for (auto *deferredStructIt = std::begin(deferredStructTypesInfos); + deferredStructIt != std::end(deferredStructTypesInfos);) { + for (auto *unresolvedMemberIt = + std::begin(deferredStructIt->unresolvedMemberTypes); + unresolvedMemberIt != + std::end(deferredStructIt->unresolvedMemberTypes);) { + if (unresolvedMemberIt->first == typePointerID) { + // The newly constructed pointer type can resolve one of the + // deferred struct type members; update the memberTypes list and + // clean 
the unresolvedMemberTypes list accordingly. + deferredStructIt->memberTypes[unresolvedMemberIt->second] = + typeMap[typePointerID]; + unresolvedMemberIt = + deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt); + } else { + ++unresolvedMemberIt; + } + } + + if (deferredStructIt->unresolvedMemberTypes.empty()) { + // All deferred struct type members are now resolved, set the struct body. + auto structType = + typeMap[deferredStructIt->structID].dyn_cast(); + + assert(structType && "Expected a spirv::StructType."); + assert(!structType.getIdentifier().empty() && + "Expected an indentified struct."); + + structType.trySetBody(deferredStructIt->memberTypes, + deferredStructIt->offsetInfo, + deferredStructIt->memberDecorationsInfo); + deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt); + } else { + ++deferredStructIt; + } + } + + return success(); +} + LogicalResult Deserializer::processArrayType(ArrayRef operands) { if (operands.size() != 3) { return emitError(unknownLoc, @@ -1293,22 +1389,36 @@ } LogicalResult Deserializer::processStructType(ArrayRef operands) { + // TODO Find a way to handle identified structs when debug info is stripped. + if (operands.empty()) { return emitError(unknownLoc, "OpTypeStruct must have at least result "); } + if (operands.size() == 1) { // Handle empty struct. - typeMap[operands[0]] = spirv::StructType::getEmpty(context); + typeMap[operands[0]] = + spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str()); return success(); } + // First element is operand ID, second element is member index in the struct. 
+ SmallVector, 0> unresolvedMemberTypes; SmallVector memberTypes; + for (auto op : llvm::drop_begin(operands, 1)) { Type memberType = getType(op); - if (!memberType) { + bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0); + + if (!memberType && !typeForwardPtr) { return emitError(unknownLoc, "OpTypeStruct references undefined ") << op; } + + if (!memberType) { + unresolvedMemberTypes.emplace_back(op, memberTypes.size()); + } + memberTypes.push_back(memberType); } @@ -1340,8 +1450,32 @@ } } } - typeMap[operands[0]] = - spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); + + uint32_t structID = operands[0]; + std::string structIdentifier = nameMap.lookup(structID).str(); + + if (structIdentifier.empty()) { + // Forward-referenced (recursive) members can only be resolved later on an + // identified struct, whose identifier comes from OpName debug info. Fail + // here instead of building a struct that has null member types. + if (!unresolvedMemberTypes.empty()) + return emitError(unknownLoc, "OpTypeStruct with forward-referenced " + "members requires an OpName"); + typeMap[structID] = + spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); + } else { + auto structTy = spirv::StructType::getIdentified(context, structIdentifier); + typeMap[structID] = structTy; + + if (!unresolvedMemberTypes.empty()) { + deferredStructTypesInfos.push_back({structID, unresolvedMemberTypes, + memberTypes, offsetInfo, + memberDecorationsInfo}); + } else { + structTy.trySetBody(memberTypes, offsetInfo, memberDecorationsInfo); + } + } + // TODO: Update StructType to have member name as attribute as // well.
return success(); @@ -2312,6 +2442,8 @@ return processPhi(operands); case spirv::Opcode::OpUndef: return processUndef(operands); + case spirv::Opcode::OpTypeForwardPointer: + return processTypeForwardPointer(operands); default: break; } @@ -2330,6 +2462,20 @@ return success(); } +LogicalResult +Deserializer::processTypeForwardPointer(ArrayRef operands) { + if (operands.size() != 2) { + return emitError(unknownLoc, + "OpTypeForwardPointer instruction must have two operands"); + } + + typeForwardPointerIDs.insert(operands[0]); + // TODO Use the 2nd operand (Storage Class) to validate the OpTypePointer + // instruction that defines the actual type. + + return success(); +} + LogicalResult Deserializer::processExtInst(ArrayRef operands) { if (operands.size() < 4) { return emitError(unknownLoc, diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -116,6 +116,16 @@ namespace { +/// Recursive struct references are serialized as OpTypePointer instructions to +/// the recursive struct type. However, the OpTypePointer instruction cannot be +/// emitted before the recursive struct's OpTypeStruct. +/// RecursiveStructPointerInfo stores the data needed to emit such OpTypePointer +/// instructions after forward references to such types. +struct RecursiveStructPointerInfo { + uint32_t pointerTypeID; + spirv::StorageClass storageClass; +}; + /// A SPIR-V module serializer. /// /// A SPIR-V binary module is a single linear stream of instructions; each @@ -247,13 +257,16 @@ /// Main dispatch method for serializing a type. The result of the /// serialized type will be returned as `typeID`. 
- LogicalResult processType(Location loc, Type type, uint32_t &typeID); + LogicalResult processType(Location loc, Type type, uint32_t &typeID, + llvm::SetVector &serializationCtx); /// Method for preparing basic SPIR-V type serialization. Returns the type's /// opcode and operands for the instruction via `typeEnum` and `operands`. LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, - SmallVectorImpl &operands); + SmallVectorImpl &operands, + bool &deferSerialization, + llvm::SetVector &serializationCtx); LogicalResult prepareFunctionType(Location loc, FunctionType type, spirv::Opcode &typeEnum, @@ -420,6 +433,10 @@ SmallVector typesGlobalValues; SmallVector functions; + // Maps spirv::StructType to its recursive reference member info. + DenseMap> + recursiveStructInfos; + /// `functionHeader` contains all the instructions that must be in the first /// block in the function, and `functionBody` contains the rest. After /// processing FuncOp, the encoded instructions of a function are appended to @@ -650,7 +667,8 @@ if (!id) { id = getNextID(); uint32_t typeID = 0; - if (failed(processType(op.getLoc(), undefType, typeID)) || + llvm::SetVector serializationCtx; + if (failed(processType(op.getLoc(), undefType, typeID, serializationCtx)) || failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef, {typeID, id}))) { return failure(); @@ -757,7 +775,8 @@ uint32_t fnTypeID = 0; // Generate type of the function. - processType(op.getLoc(), op.getType(), fnTypeID); + llvm::SetVector serializationCtx; + processType(op.getLoc(), op.getType(), fnTypeID, serializationCtx); // Add the function definition. SmallVector operands; @@ -768,7 +787,7 @@ } if (failed(processType(op.getLoc(), (resultTypes.empty() ? getVoidType() : resultTypes[0]), - resTypeID))) { + resTypeID, serializationCtx))) { return failure(); } operands.push_back(resTypeID); @@ -787,7 +806,8 @@ // Declare the parameters. 
for (auto arg : op.getArguments()) { uint32_t argTypeID = 0; - if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { + if (failed(processType(op.getLoc(), arg.getType(), argTypeID, + serializationCtx))) { return failure(); } auto argValueID = getNextID(); @@ -847,7 +867,9 @@ SmallVector elidedAttrs; uint32_t resultID = 0; uint32_t resultTypeID = 0; - if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) { + llvm::SetVector serializationCtx; + if (failed(processType(op.getLoc(), op.getType(), resultTypeID, + serializationCtx))) { return failure(); } operands.push_back(resultTypeID); @@ -886,7 +908,9 @@ // Get TypeID. uint32_t resultTypeID = 0; SmallVector elidedAttrs; - if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) { + llvm::SetVector serializationCtx; + if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID, + serializationCtx))) { return failure(); } @@ -972,30 +996,69 @@ return false; } -LogicalResult Serializer::processType(Location loc, Type type, - uint32_t &typeID) { +LogicalResult +Serializer::processType(Location loc, Type type, uint32_t &typeID, + llvm::SetVector &serializationCtx) { typeID = getTypeID(type); if (typeID) { return success(); } typeID = getNextID(); SmallVector operands; + operands.push_back(typeID); auto typeEnum = spirv::Opcode::OpTypeVoid; + bool deferSerialization = false; + if ((type.isa() && succeeded(prepareFunctionType(loc, type.cast(), typeEnum, operands))) || - succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands))) { - typeIDMap[type] = typeID; - return encodeInstructionInto(typesGlobalValues, typeEnum, operands); + succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, + deferSerialization, serializationCtx))) { + // Record the ID even when serialization is deferred (a pointer to an + // enclosing recursive struct): later references to the same pointer type + // then reuse the forward-declared ID instead of emitting another + // OpTypeForwardPointer/OpTypePointer pair. + typeIDMap[type] = typeID; + if (deferSerialization) + return success(); + + if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands))) + return failure(); + + auto recursiveInfoIt = recursiveStructInfos.find(type); + if (recursiveInfoIt != recursiveStructInfos.end()) {
+ // The recursive struct type has just been emitted, so the deferred + // OpTypePointer instructions that refer back to it can be emitted now. + for (auto &ptrInfo : recursiveInfoIt->second) { + SmallVector ptrOperands; + ptrOperands.push_back(ptrInfo.pointerTypeID); + ptrOperands.push_back(static_cast(ptrInfo.storageClass)); + ptrOperands.push_back(typeID); + + if (failed(encodeInstructionInto(typesGlobalValues, + spirv::Opcode::OpTypePointer, + ptrOperands))) { + return failure(); + } + } + + // Every deferred pointer to this struct is now resolved. + recursiveStructInfos.erase(recursiveInfoIt); + } + + return success(); } + return failure(); } -LogicalResult -Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID, - spirv::Opcode &typeEnum, - SmallVectorImpl &operands) { +LogicalResult Serializer::prepareBasicType( + Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, + SmallVectorImpl &operands, bool &deferSerialization, + llvm::SetVector &serializationCtx) { + deferSerialization = false; + if (isVoidType(type)) { typeEnum = spirv::Opcode::OpTypeVoid; return success(); @@ -1025,7 +1088,8 @@ if (auto vectorType = type.dyn_cast()) { uint32_t elementTypeID = 0; - if (failed(processType(loc, vectorType.getElementType(), elementTypeID))) { + if (failed(processType(loc, vectorType.getElementType(), elementTypeID, + serializationCtx))) { return failure(); } typeEnum = spirv::Opcode::OpTypeVector; @@ -1037,7 +1101,8 @@ if (auto arrayType = type.dyn_cast()) { typeEnum = spirv::Opcode::OpTypeArray; uint32_t elementTypeID = 0; - if (failed(processType(loc, arrayType.getElementType(), elementTypeID))) { + if (failed(processType(loc, arrayType.getElementType(), elementTypeID, + serializationCtx))) { return failure(); } operands.push_back(elementTypeID); @@ -1050,9 +1115,47 @@ if (auto ptrType = type.dyn_cast()) { uint32_t pointeeTypeID = 0; - if (failed(processType(loc,
ptrType.getPointeeType(), pointeeTypeID))) { - return failure(); + spirv::StructType pointeeStruct = + ptrType.getPointeeType().dyn_cast(); + + if (pointeeStruct && !pointeeStruct.getIdentifier().empty() && + serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { + // A recursive reference to an enclosing struct is found. + // + // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage + // class as operands. + SmallVector forwardPtrOperands; + forwardPtrOperands.push_back(resultID); + forwardPtrOperands.push_back( + static_cast(ptrType.getStorageClass())); + + encodeInstructionInto(typesGlobalValues, + spirv::Opcode::OpTypeForwardPointer, + forwardPtrOperands); + + // 2. Find the the pointee (enclosing) struct. + auto structType = spirv::StructType::lookupIdentified( + module.getContext(), pointeeStruct.getIdentifier()); + + if (!structType) { + return failure(); + } + + // 3. Mark the OpTypePointer that is supposed to be emitted by this call + // as deferred. + deferSerialization = true; + + // 4. Record the info needed to emit the deferred OpTypePointer + // instruction when the enclosing struct is completely serialized. 
+ recursiveStructInfos[structType].push_back( + {resultID, ptrType.getStorageClass()}); + } else { + if (failed(processType(loc, ptrType.getPointeeType(), pointeeTypeID, + serializationCtx))) { + return failure(); + } } + typeEnum = spirv::Opcode::OpTypePointer; operands.push_back(static_cast(ptrType.getStorageClass())); operands.push_back(pointeeTypeID); @@ -1062,7 +1165,7 @@ if (auto runtimeArrayType = type.dyn_cast()) { uint32_t elementTypeID = 0; if (failed(processType(loc, runtimeArrayType.getElementType(), - elementTypeID))) { + elementTypeID, serializationCtx))) { return failure(); } typeEnum = spirv::Opcode::OpTypeRuntimeArray; @@ -1071,12 +1174,17 @@ } if (auto structType = type.dyn_cast()) { + if (!structType.getIdentifier().empty()) { + processName(resultID, structType.getIdentifier()); + serializationCtx.insert(structType.getIdentifier()); + } + bool hasOffset = structType.hasOffset(); for (auto elementIndex : llvm::seq(0, structType.getNumElements())) { uint32_t elementTypeID = 0; if (failed(processType(loc, structType.getElementType(elementIndex), - elementTypeID))) { + elementTypeID, serializationCtx))) { return failure(); } operands.push_back(elementTypeID); @@ -1094,6 +1202,7 @@ } SmallVector memberDecorations; structType.getMemberDecorations(memberDecorations); + for (auto &memberDecoration : memberDecorations) { if (failed(processMemberDecoration(resultID, memberDecoration))) { return emitError(loc, "cannot decorate ") @@ -1102,7 +1211,13 @@ << stringifyDecoration(memberDecoration.decoration); } } + typeEnum = spirv::Opcode::OpTypeStruct; + + if (!structType.getIdentifier().empty()) { + serializationCtx.remove(structType.getIdentifier()); + } + return success(); } @@ -1110,7 +1225,7 @@ type.dyn_cast()) { uint32_t elementTypeID = 0; if (failed(processType(loc, cooperativeMatrixType.getElementType(), - elementTypeID))) { + elementTypeID, serializationCtx))) { return failure(); } typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; @@ -1128,7 
+1243,8 @@ if (auto matrixType = type.dyn_cast()) { uint32_t elementTypeID = 0; - if (failed(processType(loc, matrixType.getColumnType(), elementTypeID))) { + if (failed(processType(loc, matrixType.getColumnType(), elementTypeID, + serializationCtx))) { return failure(); } typeEnum = spirv::Opcode::OpTypeMatrix; @@ -1149,15 +1265,16 @@ assert(type.getNumResults() <= 1 && "serialization supports only a single return value"); uint32_t resultID = 0; + llvm::SetVector serializationCtx; if (failed(processType( loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), - resultID))) { + resultID, serializationCtx))) { return failure(); } operands.push_back(resultID); for (auto &res : type.getInputs()) { uint32_t argTypeID = 0; - if (failed(processType(loc, res, argTypeID))) { + if (failed(processType(loc, res, argTypeID, serializationCtx))) { return failure(); } operands.push_back(argTypeID); @@ -1183,7 +1300,8 @@ } uint32_t typeID = 0; - if (failed(processType(loc, constType, typeID))) { + llvm::SetVector serializationCtx; + if (failed(processType(loc, constType, typeID, serializationCtx))) { return 0; } @@ -1209,7 +1327,8 @@ uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, ArrayAttr attr) { uint32_t typeID = 0; - if (failed(processType(loc, constType, typeID))) { + llvm::SetVector serializationCtx; + if (failed(processType(loc, constType, typeID, serializationCtx))) { return 0; } @@ -1251,7 +1370,8 @@ } uint32_t typeID = 0; - if (failed(processType(loc, constType, typeID))) { + llvm::SetVector serializationCtx; + if (failed(processType(loc, constType, typeID, serializationCtx))) { return 0; } @@ -1300,7 +1420,8 @@ // Process the type for this bool literal uint32_t typeID = 0; - if (failed(processType(loc, boolAttr.getType(), typeID))) { + llvm::SetVector serializationCtx; + if (failed(processType(loc, boolAttr.getType(), typeID, serializationCtx))) { return 0; } @@ -1329,7 +1450,8 @@ // Process the type for this integer literal uint32_t 
typeID = 0; - if (failed(processType(loc, intAttr.getType(), typeID))) { + llvm::SetVector serializationCtx; + if (failed(processType(loc, intAttr.getType(), typeID, serializationCtx))) { return 0; } @@ -1395,7 +1517,8 @@ // Process the type for this float literal uint32_t typeID = 0; - if (failed(processType(loc, floatAttr.getType(), typeID))) { + llvm::SetVector serializationCtx; + if (failed(processType(loc, floatAttr.getType(), typeID, serializationCtx))) { return 0; } @@ -1515,7 +1638,9 @@ // Get the type and result for this OpPhi instruction. uint32_t phiTypeID = 0; - if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID))) + llvm::SetVector serializationCtx; + if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID, + serializationCtx))) return failure(); uint32_t phiID = getNextID(); @@ -1883,7 +2008,8 @@ uint32_t resTypeID = 0; Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType(); - if (failed(processType(op.getLoc(), resultTy, resTypeID))) + llvm::SetVector serializationCtx; + if (failed(processType(op.getLoc(), resultTy, resTypeID, serializationCtx))) return failure(); auto funcID = getOrCreateFunctionID(funcName); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -893,6 +893,7 @@ if (locationMap) (*locationMap)[op] = std::make_pair(line, col); } + llvm::SetVector &getStructContext() { return structContext; } private: /// Collection of OpAsm interfaces implemented in the context. @@ -906,6 +907,8 @@ /// An optional location map to be populated. 
AsmState::LocationMap *locationMap; + + llvm::SetVector structContext; }; } // end namespace detail } // end namespace mlir @@ -962,6 +965,10 @@ void printAffineConstraint(AffineExpr expr, bool isEq); void printIntegerSet(IntegerSet set); + llvm::SetVector &getStructContext() { + return state->getStructContext(); + } + protected: void printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs = {}, @@ -1741,6 +1748,10 @@ /// Print the given type to the stream. void printType(Type type) override { printer.printType(type); } + llvm::SetVector &getStructContext() override { + return printer.getStructContext(); + } + /// The main module printer. ModulePrinter &printer; }; diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -345,6 +345,8 @@ UnknownLoc unknownLocAttr; DictionaryAttr emptyDictionaryAttr; + llvm::SetVector structContext; + public: MLIRContextImpl() : identifiers(identifierAllocator) {} ~MLIRContextImpl() { @@ -953,3 +955,7 @@ return reinterpret_cast( ctx->getImpl().unknownLocAttr.getAsOpaquePointer()); } + +llvm::SetVector &MLIRContext::getStructContext() { + return impl->structContext; +} diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -308,6 +308,10 @@ return parser.parseDimensionListRanked(dimensions, allowDynamic); } + llvm::SetVector &getStructContext() override { + return parser.getStructContext(); + } + OptionalParseResult parseOptionalType(Type &result) override { return parser.parseOptionalType(result); } diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -279,6 +279,10 @@ function_ref parseElement, OpAsmParser::Delimiter delimiter); + llvm::SetVector &getStructContext() { + return state.getStructContext(); + } + private: /// The Parser is 
subclassed and reinstantiated. Do not add additional /// non-trivial state here, add it to the ParserState class. diff --git a/mlir/lib/Parser/ParserState.h b/mlir/lib/Parser/ParserState.h --- a/mlir/lib/Parser/ParserState.h +++ b/mlir/lib/Parser/ParserState.h @@ -11,6 +11,7 @@ #include "Lexer.h" #include "mlir/IR/Attributes.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringMap.h" namespace mlir { @@ -63,6 +64,10 @@ ParserState(const ParserState &) = delete; void operator=(const ParserState &) = delete; + llvm::SetVector &getStructContext() { + return context->getStructContext(); + } + /// The context we're parsing into. MLIRContext *const context; diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -89,6 +89,33 @@ // Parametric Storage //===--------------------------------------------------------------------===// + BaseStorage *lookup(TypeID id, unsigned hashValue, + function_ref isEqual) { + ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual}; + ParametricStorageUniquer &storageUniquer = *parametricUniquers[id]; + if (!threadingIsEnabled) + return lookupUnsafe(storageUniquer, lookupKey); + + // Check for an existing instance in read-only mode. + { + llvm::sys::SmartScopedReader typeLock(storageUniquer.mutex); + auto it = storageUniquer.instances.find_as(lookupKey); + if (it != storageUniquer.instances.end()) + return it->storage; + } + + return nullptr; + } + + BaseStorage *lookupUnsafe(ParametricStorageUniquer &storageUniquer, + ParametricStorageUniquer::LookupKey &lookupKey) { + auto it = storageUniquer.instances.find_as(lookupKey); + if (it != storageUniquer.instances.end()) + return it->storage; + + return nullptr; + } + /// Get or create an instance of a parametric type. 
BaseStorage * getOrCreate(TypeID id, unsigned hashValue, @@ -205,6 +232,12 @@ impl->threadingIsEnabled = !disable; } +auto StorageUniquer::lookupImpl(const TypeID &id, unsigned hashValue, + function_ref isEqual) + -> BaseStorage * { + return impl->lookup(id, hashValue, isEqual); +} + /// Implementation for getting/creating an instance of a derived type with /// parametric storage. auto StorageUniquer::getParametricStorageTypeImpl( diff --git a/mlir/test/Conversion/GPUToSPIRV/if.mlir b/mlir/test/Conversion/GPUToSPIRV/if.mlir --- a/mlir/test/Conversion/GPUToSPIRV/if.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/if.mlir @@ -133,20 +133,20 @@ // VariablePointer capability is supported. This test is still useful to // make sure we can handle scf op result with type change. // CHECK-LABEL: @simple_if_yield_type_change - // CHECK: %[[VAR:.*]] = spv.Variable : !spv.ptr [0]>, StorageBuffer>, Function> + // CHECK: %[[VAR:.*]] = spv.Variable : !spv.ptr [0])>, StorageBuffer>, Function> // CHECK: spv.selection { // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]] // CHECK-NEXT: [[TRUE]]: - // CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr [0]>, StorageBuffer> + // CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr [0])>, StorageBuffer> // CHECK: spv.Branch ^[[MERGE:.*]] // CHECK-NEXT: [[FALSE]]: - // CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr [0]>, StorageBuffer> + // CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr [0])>, StorageBuffer> // CHECK: spv.Branch ^[[MERGE]] // CHECK-NEXT: ^[[MERGE]]: // CHECK: spv._merge // CHECK-NEXT: } - // CHECK: %[[OUT:.*]] = spv.Load "Function" %[[VAR]] : !spv.ptr [0]>, StorageBuffer> - // CHECK: %[[ADD:.*]] = spv.AccessChain %[[OUT]][{{%.*}}, {{%.*}}] : !spv.ptr [0]>, StorageBuffer> + // CHECK: %[[OUT:.*]] = spv.Load "Function" %[[VAR]] : !spv.ptr [0])>, StorageBuffer> + // CHECK: %[[ADD:.*]] = spv.AccessChain %[[OUT]][{{%.*}}, {{%.*}}] : !spv.ptr [0])>, StorageBuffer> // CHECK: 
spv.Store "StorageBuffer" %[[ADD]], {{%.*}} : f32 // CHECK: spv.Return gpu.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>, %arg4 : i1) kernel diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir --- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir @@ -25,9 +25,9 @@ // CHECK-DAG: spv.globalVariable @[[$LOCALINVOCATIONIDVAR:.*]] built_in("LocalInvocationId") : !spv.ptr, Input> // CHECK-DAG: spv.globalVariable @[[$WORKGROUPIDVAR:.*]] built_in("WorkgroupId") : !spv.ptr, Input> // CHECK-LABEL: spv.func @load_store_kernel - // CHECK-SAME: %[[ARG0:.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 0)>} - // CHECK-SAME: %[[ARG1:.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>} - // CHECK-SAME: %[[ARG2:.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 2)>} + // CHECK-SAME: %[[ARG0:.*]]: !spv.ptr [0])>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 0)>} + // CHECK-SAME: %[[ARG1:.*]]: !spv.ptr [0])>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>} + // CHECK-SAME: %[[ARG2:.*]]: !spv.ptr [0])>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 2)>} // CHECK-SAME: %[[ARG3:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 3), StorageBuffer>} // CHECK-SAME: %[[ARG4:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 4), StorageBuffer>} // CHECK-SAME: %[[ARG5:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 5), StorageBuffer>} diff --git a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir --- a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir @@ -9,7 +9,7 @@ // CHECK: 
spv.func // CHECK-SAME: {{%.*}}: f32 // CHECK-NOT: spv.interface_var_abi - // CHECK-SAME: {{%.*}}: !spv.ptr [0]>, CrossWorkgroup> + // CHECK-SAME: {{%.*}}: !spv.ptr [0])>, CrossWorkgroup> // CHECK-NOT: spv.interface_var_abi // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>} gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, 11>) kernel diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir --- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir @@ -5,7 +5,7 @@ // CHECK: spv.module @{{.*}} Logical GLSL450 { // CHECK-LABEL: spv.func @basic_module_structure // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 0), StorageBuffer>} - // CHECK-SAME: {{%.*}}: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>} + // CHECK-SAME: {{%.*}}: !spv.ptr [0])>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>} // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>} gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32>) kernel attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} { @@ -32,7 +32,7 @@ // CHECK-LABEL: spv.func @basic_module_structure_preset_ABI // CHECK-SAME: {{%[a-zA-Z0-9_]*}}: f32 // CHECK-SAME: spv.interface_var_abi = #spv.interface_var_abi<(1, 2), StorageBuffer> - // CHECK-SAME: !spv.ptr [0]>, StorageBuffer> + // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // CHECK-SAME: spv.interface_var_abi = #spv.interface_var_abi<(3, 0)> // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>} gpu.func @basic_module_structure_preset_ABI( diff --git a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir --- 
a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir +++ b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir @@ -6,12 +6,12 @@ module attributes {gpu.container_module} { spv.module Logical GLSL450 requires #spv.vce { - spv.globalVariable @kernel_arg_0 bind(0, 0) : !spv.ptr [0]>, StorageBuffer> + spv.globalVariable @kernel_arg_0 bind(0, 0) : !spv.ptr [0])>, StorageBuffer> spv.func @kernel() "None" attributes {workgroup_attributions = 0 : i64} { - %0 = spv._address_of @kernel_arg_0 : !spv.ptr [0]>, StorageBuffer> + %0 = spv._address_of @kernel_arg_0 : !spv.ptr [0])>, StorageBuffer> %2 = spv.constant 0 : i32 - %3 = spv._address_of @kernel_arg_0 : !spv.ptr [0]>, StorageBuffer> - %4 = spv.AccessChain %0[%2, %2] : !spv.ptr [0]>, StorageBuffer>, i32, i32 + %3 = spv._address_of @kernel_arg_0 : !spv.ptr [0])>, StorageBuffer> + %4 = spv.AccessChain %0[%2, %2] : !spv.ptr [0])>, StorageBuffer>, i32, i32 %5 = spv.Load "StorageBuffer" %4 : f32 spv.Return } diff --git a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir @@ -8,10 +8,10 @@ spv.func @access_chain() "None" { // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 %0 = spv.constant 1: i32 - %1 = spv.Variable : !spv.ptr>, Function> + %1 = spv.Variable : !spv.ptr)>, Function> // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 // CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %[[ONE]], %[[ONE]]] : (!llvm.ptr)>>, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm.ptr - %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function>, i32, i32 + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr)>, Function>, i32, i32 spv.Return } @@ -38,9 +38,9 @@ // CHECK: llvm.mlir.global private @struct() : !llvm.struct)> // CHECK-LABEL: @func // CHECK: llvm.mlir.addressof @struct : !llvm.ptr)>> - spv.globalVariable 
@struct : !spv.ptr>, Private> + spv.globalVariable @struct : !spv.ptr)>, Private> spv.func @func() "None" { - %0 = spv._address_of @struct : !spv.ptr>, Private> + %0 = spv._address_of @struct : !spv.ptr)>, Private> spv.Return } } @@ -124,10 +124,10 @@ } // CHECK-LABEL: @store_composite -spv.func @store_composite(%arg0 : !spv.struct) "None" { - %0 = spv.Variable : !spv.ptr, Function> +spv.func @store_composite(%arg0 : !spv.struct<(f64)>) "None" { + %0 = spv.Variable : !spv.ptr, Function> // CHECK: llvm.store %{{.*}}, %{{.*}} : !llvm.ptr> - spv.Store "Function" %0, %arg0 : !spv.struct + spv.Store "Function" %0, %arg0 : !spv.struct<(f64)> spv.Return } diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir @@ -8,13 +8,13 @@ // ----- // expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}} -spv.func @struct_with_unnatural_offset(%arg: !spv.struct) -> () "None" { +spv.func @struct_with_unnatural_offset(%arg: !spv.struct<(i32[0], i32[8])>) -> () "None" { spv.Return } // ----- // expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}} -spv.func @struct_with_decorations(%arg: !spv.struct) -> () "None" { +spv.func @struct_with_decorations(%arg: !spv.struct<(f32 [RelaxedPrecision])>) -> () "None" { spv.Return } diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir @@ -35,10 +35,10 @@ //===----------------------------------------------------------------------===// // CHECK-LABEL: @struct(!llvm.struct) -spv.func @struct(!spv.struct) "None" +spv.func 
@struct(!spv.struct<(f64)>) "None" // CHECK-LABEL: @struct_nested(!llvm.struct)>) -spv.func @struct_nested(!spv.struct>) "None" +spv.func @struct_nested(!spv.struct<(i32, !spv.struct<(i64, i32)>)>) "None" // CHECK-LABEL: @struct_with_natural_offset(!llvm.struct<(i8, i32)>) -spv.func @struct_with_natural_offset(!spv.struct) "None" +spv.func @struct_with_natural_offset(!spv.struct<(i8[0], i32[4])>) "None" diff --git a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir --- a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir @@ -17,7 +17,7 @@ return } } -// CHECK: spv.globalVariable @[[VAR:.+]] : !spv.ptr>, Workgroup> +// CHECK: spv.globalVariable @[[VAR:.+]] : !spv.ptr)>, Workgroup> // CHECK: func @alloc_dealloc_workgroup_mem // CHECK-NOT: alloc // CHECK: %[[PTR:.+]] = spv._address_of @[[VAR]] @@ -45,7 +45,7 @@ } // CHECK: spv.globalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr>, Workgroup> +// CHECK-SAME: !spv.ptr)>, Workgroup> // CHECK_LABEL: spv.func @alloc_dealloc_workgroup_mem // CHECK: %[[VAR:.+]] = spv._address_of @__workgroup_mem__0 // CHECK: %[[LOC:.+]] = spv.SDiv @@ -72,9 +72,9 @@ } // CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr>, Workgroup> +// CHECK-SAME: !spv.ptr)>, Workgroup> // CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr>, Workgroup> +// CHECK-SAME: !spv.ptr)>, Workgroup> // CHECK: spv.func @two_allocs() // CHECK: spv.Return @@ -93,9 +93,9 @@ } // CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr, stride=8>>, Workgroup> +// CHECK-SAME: !spv.ptr, stride=8>)>, Workgroup> // CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr, stride=16>>, Workgroup> +// CHECK-SAME: !spv.ptr, stride=16>)>, Workgroup> // CHECK: spv.func @two_allocs_vector() // CHECK: spv.Return diff --git 
a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -682,8 +682,8 @@ //===----------------------------------------------------------------------===// // CHECK-LABEL: @load_store_zero_rank_float -// CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, -// CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) +// CHECK: [[ARG0:%.*]]: !spv.ptr [0])>, StorageBuffer>, +// CHECK: [[ARG1:%.*]]: !spv.ptr [0])>, StorageBuffer>) func @load_store_zero_rank_float(%arg0: memref, %arg1: memref) { // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32 // CHECK: spv.AccessChain [[ARG0]][ @@ -701,8 +701,8 @@ } // CHECK-LABEL: @load_store_zero_rank_int -// CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, -// CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) +// CHECK: [[ARG0:%.*]]: !spv.ptr [0])>, StorageBuffer>, +// CHECK: [[ARG1:%.*]]: !spv.ptr [0])>, StorageBuffer>) func @load_store_zero_rank_int(%arg0: memref, %arg1: memref) { // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32 // CHECK: spv.AccessChain [[ARG0]][ diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir @@ -274,35 +274,35 @@ } { // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return } // CHECK-LABEL: spv.func @memref_8bit_Uniform -// CHECK-SAME: !spv.ptr [0]>, Uniform> +// CHECK-SAME: !spv.ptr [0])>, Uniform> func @memref_8bit_Uniform(%arg0: memref<16xsi8, 4>) { return } // CHECK-LABEL: spv.func @memref_8bit_PushConstant -// CHECK-SAME: !spv.ptr [0]>, 
PushConstant> +// CHECK-SAME: !spv.ptr [0])>, PushConstant> func @memref_8bit_PushConstant(%arg0: memref<16xui8, 7>) { return } // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, 0>) { return } // CHECK-LABEL: spv.func @memref_16bit_Uniform -// CHECK-SAME: !spv.ptr [0]>, Uniform> +// CHECK-SAME: !spv.ptr [0])>, Uniform> func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return } // CHECK-LABEL: spv.func @memref_16bit_PushConstant -// CHECK-SAME: !spv.ptr [0]>, PushConstant> +// CHECK-SAME: !spv.ptr [0])>, PushConstant> func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return } // CHECK-LABEL: spv.func @memref_16bit_Input -// CHECK-SAME: !spv.ptr [0]>, Input> +// CHECK-SAME: !spv.ptr [0])>, Input> func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return } // CHECK-LABEL: spv.func @memref_16bit_Output -// CHECK-SAME: !spv.ptr [0]>, Output> +// CHECK-SAME: !spv.ptr [0])>, Output> func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return } } // end module @@ -319,12 +319,12 @@ } { // CHECK-LABEL: spv.func @memref_8bit_PushConstant -// CHECK-SAME: !spv.ptr [0]>, PushConstant> +// CHECK-SAME: !spv.ptr [0])>, PushConstant> func @memref_8bit_PushConstant(%arg0: memref<16xi8, 7>) { return } // CHECK-LABEL: spv.func @memref_16bit_PushConstant -// CHECK-SAME: !spv.ptr [0]>, PushConstant> -// CHECK-SAME: !spv.ptr [0]>, PushConstant> +// CHECK-SAME: !spv.ptr [0])>, PushConstant> +// CHECK-SAME: !spv.ptr [0])>, PushConstant> func @memref_16bit_PushConstant( %arg0: memref<16xi16, 7>, %arg1: memref<16xf16, 7> @@ -344,12 +344,12 @@ } { // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return } // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer -// CHECK-SAME: 
!spv.ptr [0]>, StorageBuffer> -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> func @memref_16bit_StorageBuffer( %arg0: memref<16xi16, 0>, %arg1: memref<16xf16, 0> @@ -369,12 +369,12 @@ } { // CHECK-LABEL: spv.func @memref_8bit_Uniform -// CHECK-SAME: !spv.ptr [0]>, Uniform> +// CHECK-SAME: !spv.ptr [0])>, Uniform> func @memref_8bit_Uniform(%arg0: memref<16xi8, 4>) { return } // CHECK-LABEL: spv.func @memref_16bit_Uniform -// CHECK-SAME: !spv.ptr [0]>, Uniform> -// CHECK-SAME: !spv.ptr [0]>, Uniform> +// CHECK-SAME: !spv.ptr [0])>, Uniform> +// CHECK-SAME: !spv.ptr [0])>, Uniform> func @memref_16bit_Uniform( %arg0: memref<16xi16, 4>, %arg1: memref<16xf16, 4> @@ -393,11 +393,11 @@ } { // CHECK-LABEL: spv.func @memref_16bit_Input -// CHECK-SAME: !spv.ptr [0]>, Input> +// CHECK-SAME: !spv.ptr [0])>, Input> func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return } // CHECK-LABEL: spv.func @memref_16bit_Output -// CHECK-SAME: !spv.ptr [0]>, Output> +// CHECK-SAME: !spv.ptr [0])>, Output> func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return } } // end module @@ -412,22 +412,22 @@ // CHECK-LABEL: spv.func @memref_offset_strides func @memref_offset_strides( -// CHECK-SAME: !spv.array<64 x f32, stride=4> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<72 x f32, stride=4> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<256 x f32, stride=4> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<64 x f32, stride=4> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<88 x f32, stride=4> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<64 x f32, stride=4> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<72 x f32, stride=4> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<256 x f32, stride=4> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<64 x f32, stride=4> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<88 x f32, stride=4> [0])>, StorageBuffer> %arg0: memref<16x4xf32, offset: 0, 
strides: [4, 1]>, // tightly packed; row major %arg1: memref<16x4xf32, offset: 8, strides: [4, 1]>, // offset 8 %arg2: memref<16x4xf32, offset: 0, strides: [16, 1]>, // pad 12 after each row %arg3: memref<16x4xf32, offset: 0, strides: [1, 16]>, // tightly packed; col major %arg4: memref<16x4xf32, offset: 0, strides: [1, 22]>, // pad 4 after each col -// CHECK-SAME: !spv.array<64 x f16, stride=2> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<72 x f16, stride=2> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<256 x f16, stride=2> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<64 x f16, stride=2> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<88 x f16, stride=2> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<64 x f16, stride=2> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<72 x f16, stride=2> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<256 x f16, stride=2> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<64 x f16, stride=2> [0])>, StorageBuffer> +// CHECK-SAME: !spv.array<88 x f16, stride=2> [0])>, StorageBuffer> %arg5: memref<16x4xf16, offset: 0, strides: [4, 1]>, %arg6: memref<16x4xf16, offset: 8, strides: [4, 1]>, %arg7: memref<16x4xf16, offset: 0, strides: [16, 1]>, @@ -450,8 +450,8 @@ func @unranked_memref(%arg0: memref<*xi32>) { return } // CHECK-LABEL: func @dynamic_dim_memref -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> func @dynamic_dim_memref(%arg0: memref<8x?xi32>, %arg1: memref) { return } @@ -466,16 +466,16 @@ } { // CHECK-LABEL: func @memref_vector -// CHECK-SAME: !spv.ptr, stride=8> [0]>, StorageBuffer> -// CHECK-SAME: !spv.ptr, stride=16> [0]>, Uniform> +// CHECK-SAME: !spv.ptr, stride=8> [0])>, StorageBuffer> +// CHECK-SAME: !spv.ptr, stride=16> [0])>, Uniform> func @memref_vector( %arg0: memref<4xvector<2xf32>, 0>, %arg1: memref<4xvector<4xf32>, 4>) { return } // CHECK-LABEL: func 
@dynamic_dim_memref_vector -// CHECK-SAME: !spv.ptr, stride=16> [0]>, StorageBuffer> -// CHECK-SAME: !spv.ptr, stride=8> [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr, stride=16> [0])>, StorageBuffer> +// CHECK-SAME: !spv.ptr, stride=8> [0])>, StorageBuffer> func @dynamic_dim_memref_vector(%arg0: memref<8x?xvector<4xi32>>, %arg1: memref>) { return } diff --git a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir @@ -1,10 +1,10 @@ // RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s spv.module Logical GLSL450 requires #spv.vce { - spv.func @composite_insert(%arg0 : !spv.struct, f32>>, %arg1: !spv.array<4xf32>) -> !spv.struct, f32>> "None" { - // CHECK: spv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32, 0 : i32] : !spv.array<4 x f32> into !spv.struct, f32>> - %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct, f32>> - spv.ReturnValue %0: !spv.struct, f32>> + spv.func @composite_insert(%arg0 : !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)>, %arg1: !spv.array<4xf32>) -> !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)> "None" { + // CHECK: spv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32, 0 : i32] : !spv.array<4 x f32> into !spv.struct<(f32, !spv.struct<(!spv.array<4 x f32>, f32)>)> + %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)> + spv.ReturnValue %0: !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)> } spv.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> "None" { // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32> diff --git a/mlir/test/Dialect/SPIRV/Serialization/debug.mlir b/mlir/test/Dialect/SPIRV/Serialization/debug.mlir --- 
a/mlir/test/Dialect/SPIRV/Serialization/debug.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/debug.mlir @@ -29,9 +29,9 @@ spv.Return } - spv.func @composite(%arg0 : !spv.struct, f32>>, %arg1: !spv.array<4xf32>, %arg2 : f32, %arg3 : f32) "None" { + spv.func @composite(%arg0 : !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)>, %arg1: !spv.array<4xf32>, %arg2 : f32, %arg3 : f32) "None" { // CHECK: loc({{".*debug.mlir"}}:34:10) - %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct, f32>> + %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)> // CHECK: loc({{".*debug.mlir"}}:36:10) %1 = spv.CompositeConstruct %arg2, %arg3 : vector<2xf32> spv.Return diff --git a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir @@ -60,14 +60,14 @@ // ----- spv.module Logical GLSL450 requires #spv.vce { - spv.globalVariable @GV1 bind(0, 0) : !spv.ptr [0]>, StorageBuffer> - spv.globalVariable @GV2 bind(0, 1) : !spv.ptr [0]>, StorageBuffer> + spv.globalVariable @GV1 bind(0, 0) : !spv.ptr [0])>, StorageBuffer> + spv.globalVariable @GV2 bind(0, 1) : !spv.ptr [0])>, StorageBuffer> spv.func @loop_kernel() "None" { - %0 = spv._address_of @GV1 : !spv.ptr [0]>, StorageBuffer> + %0 = spv._address_of @GV1 : !spv.ptr [0])>, StorageBuffer> %1 = spv.constant 0 : i32 - %2 = spv.AccessChain %0[%1] : !spv.ptr [0]>, StorageBuffer>, i32 - %3 = spv._address_of @GV2 : !spv.ptr [0]>, StorageBuffer> - %5 = spv.AccessChain %3[%1] : !spv.ptr [0]>, StorageBuffer>, i32 + %2 = spv.AccessChain %0[%1] : !spv.ptr [0])>, StorageBuffer>, i32 + %3 = spv._address_of @GV2 : !spv.ptr [0])>, StorageBuffer> + %5 = spv.AccessChain %3[%1] : !spv.ptr [0])>, StorageBuffer>, i32 %6 = spv.constant 4 : i32 %7 = spv.constant 42 : i32 %8 = 
spv.constant 2 : i32 diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir @@ -27,32 +27,32 @@ // ----- spv.module Logical GLSL450 requires #spv.vce { - spv.func @load_store_zero_rank_float(%arg0: !spv.ptr [0]>, StorageBuffer>, %arg1: !spv.ptr [0]>, StorageBuffer>) "None" { - // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + spv.func @load_store_zero_rank_float(%arg0: !spv.ptr [0])>, StorageBuffer>, %arg1: !spv.ptr [0])>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0])> // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : f32 %0 = spv.constant 0 : i32 - %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer>, i32, i32 + %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0])>, StorageBuffer>, i32, i32 %2 = spv.Load "StorageBuffer" %1 : f32 - // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0])> // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32 %3 = spv.constant 0 : i32 - %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer>, i32, i32 + %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0])>, StorageBuffer>, i32, i32 spv.Store "StorageBuffer" %4, %2 : f32 spv.Return } - spv.func @load_store_zero_rank_int(%arg0: !spv.ptr [0]>, StorageBuffer>, %arg1: !spv.ptr [0]>, StorageBuffer>) "None" { - // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + spv.func @load_store_zero_rank_int(%arg0: !spv.ptr [0])>, StorageBuffer>, %arg1: !spv.ptr [0])>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0])> // CHECK-NEXT: 
[[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : i32 %0 = spv.constant 0 : i32 - %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer>, i32, i32 + %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0])>, StorageBuffer>, i32, i32 %2 = spv.Load "StorageBuffer" %1 : i32 - // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0])> // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32 %3 = spv.constant 0 : i32 - %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer>, i32, i32 + %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0])>, StorageBuffer>, i32, i32 spv.Store "StorageBuffer" %4, %2 : i32 spv.Return } diff --git a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir @@ -1,36 +1,52 @@ // RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s spv.module Logical GLSL450 requires #spv.vce { - // CHECK: !spv.ptr [0]>, Input> - spv.globalVariable @var0 bind(0, 1) : !spv.ptr [0]>, Input> + // CHECK: !spv.ptr [0])>, Input> + spv.globalVariable @var0 bind(0, 1) : !spv.ptr [0])>, Input> - // CHECK: !spv.ptr [4]> [4]>, Input> - spv.globalVariable @var1 bind(0, 2) : !spv.ptr [4]> [4]>, Input> + // CHECK: !spv.ptr [4])> [4])>, Input> + spv.globalVariable @var1 bind(0, 2) : !spv.ptr [4])> [4])>, Input> - // CHECK: !spv.ptr, StorageBuffer> - spv.globalVariable @var2 : !spv.ptr, StorageBuffer> + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @var2 : !spv.ptr, StorageBuffer> - // CHECK: !spv.ptr [0]>, stride=512> [0]>, StorageBuffer> - spv.globalVariable @var3 : !spv.ptr [0]>, stride=512> [0]>, StorageBuffer> + // CHECK: !spv.ptr [0])>, stride=512> [0])>, StorageBuffer> + spv.globalVariable @var3 : !spv.ptr [0])>, stride=512> [0])>, StorageBuffer> - // 
CHECK: !spv.ptr, StorageBuffer> - spv.globalVariable @var4 : !spv.ptr, StorageBuffer> + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @var4 : !spv.ptr, StorageBuffer> - // CHECK: !spv.ptr, StorageBuffer> - spv.globalVariable @var5 : !spv.ptr, StorageBuffer> + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @var5 : !spv.ptr, StorageBuffer> - // CHECK: !spv.ptr, StorageBuffer> - spv.globalVariable @var6 : !spv.ptr, StorageBuffer> + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @var6 : !spv.ptr, StorageBuffer> - // CHECK: !spv.ptr> [0, ColMajor, MatrixStride=16]>, StorageBuffer> - spv.globalVariable @var7 : !spv.ptr> [0, ColMajor, MatrixStride=16]>, StorageBuffer> + // CHECK: !spv.ptr> [0, ColMajor, MatrixStride=16])>, StorageBuffer> + spv.globalVariable @var7 : !spv.ptr> [0, ColMajor, MatrixStride=16])>, StorageBuffer> - // CHECK: !spv.ptr, StorageBuffer> - spv.globalVariable @empty : !spv.ptr, StorageBuffer> + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @empty : !spv.ptr, StorageBuffer> - // CHECK: !spv.ptr [0]>, Input>, - // CHECK-SAME: !spv.ptr [0]>, Output> - spv.func @kernel(%arg0: !spv.ptr [0]>, Input>, %arg1: !spv.ptr [0]>, Output>) -> () "None" { + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @id_empty : !spv.ptr, StorageBuffer> + + // CHECK: !spv.ptr [0])>, Input> + spv.globalVariable @id_var0 : !spv.ptr [0])>, Input> + + + // CHECK: !spv.ptr, StorageBuffer>)>, StorageBuffer> + spv.globalVariable @recursive_simple : !spv.ptr, StorageBuffer>)>, StorageBuffer> + + // CHECK: !spv.ptr, Uniform>)>, Uniform>)>, Uniform> + spv.globalVariable @recursive_2 : !spv.ptr, Uniform>)>, Uniform>)>, Uniform> + + // CHECK: !spv.ptr, Uniform>, !spv.ptr, Uniform>)>, Uniform>)>, Uniform> + spv.globalVariable @recursive_3 : !spv.ptr, Uniform>, !spv.ptr, Uniform>)>, Uniform>)>, Uniform> + + // CHECK: !spv.ptr [0])>, Input>, + // CHECK-SAME: !spv.ptr [0])>, Output> + spv.func @kernel(%arg0: !spv.ptr [0])>, Input>, %arg1: !spv.ptr 
[0])>, Output>) -> () "None" { spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/Serialization/undef.mlir b/mlir/test/Dialect/SPIRV/Serialization/undef.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/undef.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/undef.mlir @@ -13,10 +13,10 @@ // CHECK: {{%.*}} = spv.undef : !spv.array<4 x !spv.array<4 x i32>> %5 = spv.undef : !spv.array<4x!spv.array<4xi32>> %6 = spv.CompositeExtract %5[1 : i32, 2 : i32] : !spv.array<4x!spv.array<4xi32>> - // CHECK: {{%.*}} = spv.undef : !spv.ptr, StorageBuffer> - %7 = spv.undef : !spv.ptr, StorageBuffer> + // CHECK: {{%.*}} = spv.undef : !spv.ptr, StorageBuffer> + %7 = spv.undef : !spv.ptr, StorageBuffer> %8 = spv.constant 0 : i32 - %9 = spv.AccessChain %7[%8] : !spv.ptr, StorageBuffer>, i32 + %9 = spv.AccessChain %7[%8] : !spv.ptr, StorageBuffer>, i32 spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface-opencl.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface-opencl.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface-opencl.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface-opencl.mlir @@ -5,14 +5,12 @@ } { spv.module Physical64 OpenCL { // CHECK-LABEL: spv.module - // CHECK: spv.func [[FN:@.*]]( - // CHECK-SAME: {{%.*}}: f32 - // CHECK-SAME: {{%.*}}: !spv.ptr>, CrossWorkgroup> + // CHECK: spv.func [[FN:@.*]]({{%.*}}: f32, {{%.*}}: !spv.ptr)>, CrossWorkgroup> // CHECK: spv.EntryPoint "Kernel" [[FN]] // CHECK: spv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1 spv.func @kernel( %arg0: f32, - %arg1: !spv.ptr>, CrossWorkgroup>) "None" + %arg1: !spv.ptr)>, CrossWorkgroup>) "None" attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} { spv.Return } diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir @@ -7,13 +7,13 @@ // 
CHECK-LABEL: spv.module spv.module Logical GLSL450 { - // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr [0]>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr [0])>, StorageBuffer> // CHECK: spv.func [[FN:@.*]]() spv.func @kernel( %arg0: f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 0), StorageBuffer>}, - %arg1: !spv.ptr>, StorageBuffer> + %arg1: !spv.ptr)>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}) "None" attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} { // CHECK: [[ARG1:%.*]] = spv._address_of [[VAR1]] diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir @@ -15,20 +15,20 @@ spv.globalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spv.ptr, Input> // CHECK-DAG: spv.globalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId") spv.globalVariable @__builtin_var_WorkgroupId__ built_in("WorkgroupId") : !spv.ptr, Input> - // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr, stride=16> [0]>, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr, stride=16> [0]>, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr, stride=16> [0]>, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR3:@.*]] bind(0, 3) : !spv.ptr, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR6:@.*]] bind(0, 6) : !spv.ptr, 
StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr, stride=16> [0])>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr, stride=16> [0])>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr, stride=16> [0])>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR3:@.*]] bind(0, 3) : !spv.ptr, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR6:@.*]] bind(0, 6) : !spv.ptr, StorageBuffer> // CHECK: spv.func [[FN:@.*]]() spv.func @load_store_kernel( - %arg0: !spv.ptr>>, StorageBuffer> + %arg0: !spv.ptr>)>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 0)>}, - %arg1: !spv.ptr>>, StorageBuffer> + %arg1: !spv.ptr>)>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}, - %arg2: !spv.ptr>>, StorageBuffer> + %arg2: !spv.ptr>)>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 2)>}, %arg3: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 3), StorageBuffer>}, @@ -103,14 +103,14 @@ %37 = spv.IAdd %arg4, %11 : i32 // CHECK: spv.AccessChain [[ARG0]] %c0 = spv.constant 0 : i32 - %38 = spv.AccessChain %arg0[%c0, %36, %37] : !spv.ptr>>, StorageBuffer>, i32, i32, i32 + %38 = spv.AccessChain %arg0[%c0, %36, %37] : !spv.ptr>)>, StorageBuffer>, i32, i32, i32 %39 = spv.Load "StorageBuffer" %38 : f32 // CHECK: spv.AccessChain [[ARG1]] - %40 = spv.AccessChain %arg1[%c0, %36, %37] : !spv.ptr>>, StorageBuffer>, i32, i32, i32 + %40 = spv.AccessChain %arg1[%c0, %36, %37] : !spv.ptr>)>, StorageBuffer>, i32, i32, i32 %41 = spv.Load "StorageBuffer" %40 : f32 %42 = spv.FAdd %39, %41 : f32 // CHECK: spv.AccessChain [[ARG2]] - %43 = spv.AccessChain %arg2[%c0, %36, %37] : !spv.ptr>>, StorageBuffer>, i32, i32, i32 + %43 = spv.AccessChain %arg2[%c0, %36, %37] 
: !spv.ptr>)>, StorageBuffer>, i32, i32, i32 spv.Store "StorageBuffer" %43, %42 : f32 spv.Return } diff --git a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir @@ -33,11 +33,11 @@ // ----- spv.module Logical GLSL450 { - spv.globalVariable @data bind(0, 0) : !spv.ptr [0]>, StorageBuffer> + spv.globalVariable @data bind(0, 0) : !spv.ptr [0])>, StorageBuffer> spv.func @callee() "None" { - %0 = spv._address_of @data : !spv.ptr [0]>, StorageBuffer> + %0 = spv._address_of @data : !spv.ptr [0])>, StorageBuffer> %1 = spv.constant 0: i32 - %2 = spv.AccessChain %0[%1, %1] : !spv.ptr [0]>, StorageBuffer>, i32, i32 + %2 = spv.AccessChain %0[%1, %1] : !spv.ptr [0])>, StorageBuffer>, i32, i32 spv.Branch ^next ^next: @@ -184,8 +184,8 @@ // ----- spv.module Logical GLSL450 { - spv.globalVariable @arg_0 bind(0, 0) : !spv.ptr, StorageBuffer> - spv.globalVariable @arg_1 bind(0, 1) : !spv.ptr, StorageBuffer> + spv.globalVariable @arg_0 bind(0, 0) : !spv.ptr, StorageBuffer> + spv.globalVariable @arg_1 bind(0, 1) : !spv.ptr, StorageBuffer> // CHECK: @inline_into_selection_region spv.func @inline_into_selection_region() "None" { @@ -194,9 +194,9 @@ // CHECK-DAG: [[ADDRESS_ARG1:%.*]] = spv._address_of @arg_1 // CHECK-DAG: [[LOADPTR:%.*]] = spv.AccessChain [[ADDRESS_ARG0]] // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOADPTR]] - %2 = spv._address_of @arg_0 : !spv.ptr, StorageBuffer> - %3 = spv._address_of @arg_1 : !spv.ptr, StorageBuffer> - %4 = spv.AccessChain %2[%1] : !spv.ptr, StorageBuffer>, i32 + %2 = spv._address_of @arg_0 : !spv.ptr, StorageBuffer> + %3 = spv._address_of @arg_1 : !spv.ptr, StorageBuffer> + %4 = spv.AccessChain %2[%1] : !spv.ptr, StorageBuffer>, i32 %5 = spv.Load "StorageBuffer" %4 : i32 %6 = spv.SGreaterThan %5, %1 : i32 // CHECK: spv.selection @@ -204,7 +204,7 @@ spv.BranchConditional %6, ^bb1, ^bb2 
^bb1: // pred: ^bb0 // CHECK: [[STOREPTR:%.*]] = spv.AccessChain [[ADDRESS_ARG1]] - %7 = spv.AccessChain %3[%1] : !spv.ptr, StorageBuffer>, i32 + %7 = spv.AccessChain %3[%1] : !spv.ptr, StorageBuffer>, i32 // CHECK-NOT: spv.FunctionCall // CHECK: spv.AtomicIAdd "Device" "AcquireRelease" [[STOREPTR]], [[VAL]] // CHECK: spv.Branch diff --git a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir @@ -1,30 +1,30 @@ // RUN: mlir-opt -decorate-spirv-composite-type-layout -split-input-file -verify-diagnostics %s -o - | FileCheck %s spv.module Logical GLSL450 { - // CHECK: spv.globalVariable @var0 bind(0, 1) : !spv.ptr [4], f32 [12]>, Uniform> - spv.globalVariable @var0 bind(0,1) : !spv.ptr, f32>, Uniform> + // CHECK: spv.globalVariable @var0 bind(0, 1) : !spv.ptr [4], f32 [12])>, Uniform> + spv.globalVariable @var0 bind(0,1) : !spv.ptr, f32)>, Uniform> - // CHECK: spv.globalVariable @var1 bind(0, 2) : !spv.ptr [0], f32 [256]>, StorageBuffer> - spv.globalVariable @var1 bind(0,2) : !spv.ptr, f32>, StorageBuffer> + // CHECK: spv.globalVariable @var1 bind(0, 2) : !spv.ptr [0], f32 [256])>, StorageBuffer> + spv.globalVariable @var1 bind(0,2) : !spv.ptr, f32)>, StorageBuffer> - // CHECK: spv.globalVariable @var2 bind(1, 0) : !spv.ptr [0], f32 [256]> [0], i32 [260]>, StorageBuffer> - spv.globalVariable @var2 bind(1,0) : !spv.ptr, f32>, i32>, StorageBuffer> + // CHECK: spv.globalVariable @var2 bind(1, 0) : !spv.ptr [0], f32 [256])> [0], i32 [260])>, StorageBuffer> + spv.globalVariable @var2 bind(1,0) : !spv.ptr, f32)>, i32)>, StorageBuffer> - // CHECK: spv.globalVariable @var3 : !spv.ptr [8]>, stride=72> [0], f32 [1152]>, StorageBuffer> - spv.globalVariable @var3 : !spv.ptr>>, f32>, StorageBuffer> + // CHECK: spv.globalVariable @var3 : !spv.ptr [8])>, stride=72> [0], f32 [1152])>, 
StorageBuffer> + spv.globalVariable @var3 : !spv.ptr)>>, f32)>, StorageBuffer> - // CHECK: spv.globalVariable @var4 bind(1, 2) : !spv.ptr [0], f32 [16], i1 [20]> [0], i1 [24]>, StorageBuffer> - spv.globalVariable @var4 bind(1,2) : !spv.ptr, f32, i1>, i1>, StorageBuffer> + // CHECK: spv.globalVariable @var4 bind(1, 2) : !spv.ptr [0], f32 [16], i1 [20])> [0], i1 [24])>, StorageBuffer> + spv.globalVariable @var4 bind(1,2) : !spv.ptr, f32, i1)>, i1)>, StorageBuffer> - // CHECK: spv.globalVariable @var5 bind(1, 3) : !spv.ptr [0]>, StorageBuffer> - spv.globalVariable @var5 bind(1,3) : !spv.ptr>, StorageBuffer> + // CHECK: spv.globalVariable @var5 bind(1, 3) : !spv.ptr [0])>, StorageBuffer> + spv.globalVariable @var5 bind(1,3) : !spv.ptr)>, StorageBuffer> spv.func @kernel() -> () "None" { %c0 = spv.constant 0 : i32 - // CHECK: {{%.*}} = spv._address_of @var0 : !spv.ptr [4], f32 [12]>, Uniform> - %0 = spv._address_of @var0 : !spv.ptr, f32>, Uniform> - // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr [4], f32 [12]>, Uniform> - %1 = spv.AccessChain %0[%c0] : !spv.ptr, f32>, Uniform>, i32 + // CHECK: {{%.*}} = spv._address_of @var0 : !spv.ptr [4], f32 [12])>, Uniform> + %0 = spv._address_of @var0 : !spv.ptr, f32)>, Uniform> + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr [4], f32 [12])>, Uniform> + %1 = spv.AccessChain %0[%c0] : !spv.ptr, f32)>, Uniform>, i32 spv.Return } } @@ -32,68 +32,68 @@ // ----- spv.module Logical GLSL450 { - // CHECK: spv.globalVariable @var0 : !spv.ptr [0], i1 [16]> [0], i1 [24]> [0], i1 [32]> [0], i1 [40]>, Uniform> - spv.globalVariable @var0 : !spv.ptr, i1>, i1>, i1>, i1>, Uniform> + // CHECK: spv.globalVariable @var0 : !spv.ptr [0], i1 [16])> [0], i1 [24])> [0], i1 [32])> [0], i1 [40])>, Uniform> + spv.globalVariable @var0 : !spv.ptr, i1)>, i1)>, i1)>, i1)>, Uniform> - // CHECK: spv.globalVariable @var1 : !spv.ptr [8], f32 [24]> [0], f32 [32]>, Uniform> - spv.globalVariable @var1 : !spv.ptr, f32>, f32>, Uniform> + 
// CHECK: spv.globalVariable @var1 : !spv.ptr [8], f32 [24])> [0], f32 [32])>, Uniform> + spv.globalVariable @var1 : !spv.ptr, f32)>, f32)>, Uniform> - // CHECK: spv.globalVariable @var2 : !spv.ptr, stride=128> [8]> [8], f32 [2064]> [0], f32 [2072]>, Uniform> - spv.globalVariable @var2 : !spv.ptr>>, f32>, f32>, Uniform> + // CHECK: spv.globalVariable @var2 : !spv.ptr, stride=128> [8])> [8], f32 [2064])> [0], f32 [2072])>, Uniform> + spv.globalVariable @var2 : !spv.ptr>)>, f32)>, f32)>, Uniform> - // CHECK: spv.globalVariable @var3 : !spv.ptr [0], i1 [512]> [0], i1 [520]>, Uniform> - spv.globalVariable @var3 : !spv.ptr, i1>, i1>, Uniform> + // CHECK: spv.globalVariable @var3 : !spv.ptr [0], i1 [512])> [0], i1 [520])>, Uniform> + spv.globalVariable @var3 : !spv.ptr, i1)>, i1)>, Uniform> - // CHECK: spv.globalVariable @var4 : !spv.ptr [8], i1 [24]>, Uniform> - spv.globalVariable @var4 : !spv.ptr, i1>, Uniform> + // CHECK: spv.globalVariable @var4 : !spv.ptr [8], i1 [24])>, Uniform> + spv.globalVariable @var4 : !spv.ptr, i1)>, Uniform> - // CHECK: spv.globalVariable @var5 : !spv.ptr [8], i1 [24]>, Uniform> - spv.globalVariable @var5 : !spv.ptr, i1>, Uniform> + // CHECK: spv.globalVariable @var5 : !spv.ptr [8], i1 [24])>, Uniform> + spv.globalVariable @var5 : !spv.ptr, i1)>, Uniform> - // CHECK: spv.globalVariable @var6 : !spv.ptr [8], i1 [24]>, Uniform> - spv.globalVariable @var6 : !spv.ptr, i1>, Uniform> + // CHECK: spv.globalVariable @var6 : !spv.ptr [8], i1 [24])>, Uniform> + spv.globalVariable @var6 : !spv.ptr, i1)>, Uniform> - // CHECK: spv.globalVariable @var7 : !spv.ptr [0], i1 [16]> [8], i1 [32]>, Uniform> - spv.globalVariable @var7 : !spv.ptr, i1>, i1>, Uniform> + // CHECK: spv.globalVariable @var7 : !spv.ptr [0], i1 [16])> [8], i1 [32])>, Uniform> + spv.globalVariable @var7 : !spv.ptr, i1)>, i1)>, Uniform> } // ----- spv.module Logical GLSL450 { - // CHECK: spv.globalVariable @var0 : !spv.ptr [0], f32 [8]>, StorageBuffer> - spv.globalVariable @var0 : 
!spv.ptr, f32>, StorageBuffer> + // CHECK: spv.globalVariable @var0 : !spv.ptr [0], f32 [8])>, StorageBuffer> + spv.globalVariable @var0 : !spv.ptr, f32)>, StorageBuffer> - // CHECK: spv.globalVariable @var1 : !spv.ptr [0], f32 [12]>, StorageBuffer> - spv.globalVariable @var1 : !spv.ptr, f32>, StorageBuffer> + // CHECK: spv.globalVariable @var1 : !spv.ptr [0], f32 [12])>, StorageBuffer> + spv.globalVariable @var1 : !spv.ptr, f32)>, StorageBuffer> - // CHECK: spv.globalVariable @var2 : !spv.ptr [0], f32 [16]>, StorageBuffer> - spv.globalVariable @var2 : !spv.ptr, f32>, StorageBuffer> + // CHECK: spv.globalVariable @var2 : !spv.ptr [0], f32 [16])>, StorageBuffer> + spv.globalVariable @var2 : !spv.ptr, f32)>, StorageBuffer> } // ----- spv.module Logical GLSL450 { - // CHECK: spv.globalVariable @emptyStructAsMember : !spv.ptr [0]>, StorageBuffer> - spv.globalVariable @emptyStructAsMember : !spv.ptr>, StorageBuffer> + // CHECK: spv.globalVariable @emptyStructAsMember : !spv.ptr [0])>, StorageBuffer> + spv.globalVariable @emptyStructAsMember : !spv.ptr)>, StorageBuffer> // CHECK: spv.globalVariable @arrayType : !spv.ptr>, StorageBuffer> spv.globalVariable @arrayType : !spv.ptr>, StorageBuffer> - // CHECK: spv.globalVariable @InputStorage : !spv.ptr>, Input> - spv.globalVariable @InputStorage : !spv.ptr>, Input> + // CHECK: spv.globalVariable @InputStorage : !spv.ptr)>, Input> + spv.globalVariable @InputStorage : !spv.ptr)>, Input> - // CHECK: spv.globalVariable @customLayout : !spv.ptr, Uniform> - spv.globalVariable @customLayout : !spv.ptr, Uniform> + // CHECK: spv.globalVariable @customLayout : !spv.ptr, Uniform> + spv.globalVariable @customLayout : !spv.ptr, Uniform> - // CHECK: spv.globalVariable @emptyStruct : !spv.ptr, Uniform> - spv.globalVariable @emptyStruct : !spv.ptr, Uniform> + // CHECK: spv.globalVariable @emptyStruct : !spv.ptr, Uniform> + spv.globalVariable @emptyStruct : !spv.ptr, Uniform> } // ----- spv.module Logical GLSL450 { - // CHECK: 
spv.globalVariable @var0 : !spv.ptr, PushConstant> - spv.globalVariable @var0 : !spv.ptr, PushConstant> - // CHECK: spv.globalVariable @var1 : !spv.ptr, PhysicalStorageBuffer> - spv.globalVariable @var1 : !spv.ptr, PhysicalStorageBuffer> + // CHECK: spv.globalVariable @var0 : !spv.ptr, PushConstant> + spv.globalVariable @var0 : !spv.ptr, PushConstant> + // CHECK: spv.globalVariable @var1 : !spv.ptr, PhysicalStorageBuffer> + spv.globalVariable @var1 : !spv.ptr, PhysicalStorageBuffer> } diff --git a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir @@ -15,16 +15,16 @@ %7 = spv.CompositeInsert %value2, %6[2 : i32] : f32 into !spv.array<4xf32> %8 = spv.CompositeInsert %value0, %7[3 : i32] : f32 into !spv.array<4xf32> - %9 = spv.undef : !spv.struct - // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : !spv.struct - %10 = spv.CompositeInsert %value0, %9[0 : i32] : f32 into !spv.struct - %11 = spv.CompositeInsert %value3, %10[1 : i32] : i32 into !spv.struct - %12 = spv.CompositeInsert %value1, %11[2 : i32] : f32 into !spv.struct + %9 = spv.undef : !spv.struct<(f32, i32, f32)> + // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : !spv.struct<(f32, i32, f32)> + %10 = spv.CompositeInsert %value0, %9[0 : i32] : f32 into !spv.struct<(f32, i32, f32)> + %11 = spv.CompositeInsert %value3, %10[1 : i32] : i32 into !spv.struct<(f32, i32, f32)> + %12 = spv.CompositeInsert %value1, %11[2 : i32] : f32 into !spv.struct<(f32, i32, f32)> - %13 = spv.undef : !spv.struct> - // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}} : !spv.struct> - %14 = spv.CompositeInsert %value0, %13[0 : i32] : f32 into !spv.struct> - %15 = spv.CompositeInsert %value4, %14[1 : i32] : !spv.array<3xf32> into !spv.struct> + %13 = spv.undef : !spv.struct<(f32, !spv.array<3xf32>)> + // CHECK: spv.CompositeConstruct 
{{%.*}}, {{%.*}} : !spv.struct<(f32, !spv.array<3 x f32>)> + %14 = spv.CompositeInsert %value0, %13[0 : i32] : f32 into !spv.struct<(f32, !spv.array<3xf32>)> + %15 = spv.CompositeInsert %value4, %14[1 : i32] : !spv.array<3xf32> into !spv.struct<(f32, !spv.array<3xf32>)> spv.ReturnValue %3 : vector<3xf32> } diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -180,6 +180,6 @@ #spv.vce, {}> } { - spv.globalVariable @data : !spv.ptr, Uniform> + spv.globalVariable @data : !spv.ptr, Uniform> spv.globalVariable @img : !spv.ptr, UniformConstant> } diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir --- a/mlir/test/Dialect/SPIRV/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir @@ -10,8 +10,8 @@ // CHECK-NEXT: %[[PTR:.*]] = spv.AccessChain %[[VAR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] // CHECK-NEXT: spv.Load "Function" %[[PTR]] %c0 = spv.constant 0: i32 - %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> - %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 + %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>)>, Function> + %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>)>, Function>, i32 %2 = spv.AccessChain %1[%c0, %c0] : !spv.ptr>, Function>, i32, i32 %3 = spv.Load "Function" %2 : f32 spv.ReturnValue %3 : f32 @@ -27,8 +27,8 @@ // CHECK-NEXT: spv.Load "Function" %[[PTR_0]] // CHECK-NEXT: spv.Load "Function" %[[PTR_1]] %c0 = spv.constant 0: i32 - %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> - %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 + %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>)>, Function> + %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>)>, Function>, i32 %2 = spv.AccessChain %1[%c0] : !spv.ptr>, Function>, 
i32 %3 = spv.AccessChain %2[%c0] : !spv.ptr, Function>, i32 %4 = spv.Load "Function" %2 : !spv.array<4xf32> @@ -47,10 +47,10 @@ // CHECK-NEXT: spv.Load "Function" %[[VAR_0_PTR]] // CHECK-NEXT: spv.Load "Function" %[[VAR_1_PTR]] %c1 = spv.constant 1: i32 - %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> - %1 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> - %2 = spv.AccessChain %0[%c1] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 - %3 = spv.AccessChain %1[%c1] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 + %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>)>, Function> + %1 = spv.Variable : !spv.ptr>, !spv.array<4xi32>)>, Function> + %2 = spv.AccessChain %0[%c1] : !spv.ptr>, !spv.array<4xi32>)>, Function>, i32 + %3 = spv.AccessChain %1[%c1] : !spv.ptr>, !spv.array<4xi32>)>, Function>, i32 %4 = spv.Load "Function" %2 : !spv.array<4xi32> %5 = spv.Load "Function" %3 : !spv.array<4xi32> spv.ReturnValue %4 : !spv.array<4xi32> diff --git a/mlir/test/Dialect/SPIRV/composite-ops.mlir b/mlir/test/Dialect/SPIRV/composite-ops.mlir --- a/mlir/test/Dialect/SPIRV/composite-ops.mlir +++ b/mlir/test/Dialect/SPIRV/composite-ops.mlir @@ -12,10 +12,10 @@ // ----- -func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spv.array<4xf32>, %arg2 : !spv.struct) -> !spv.struct, !spv.array<4xf32>, !spv.struct> { - // CHECK: spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct, !spv.array<4 x f32>, !spv.struct> - %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct, !spv.array<4xf32>, !spv.struct> - return %0: !spv.struct, !spv.array<4xf32>, !spv.struct> +func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spv.array<4xf32>, %arg2 : !spv.struct<(f32)>) -> !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)> { + // CHECK: spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<(vector<3xf32>, !spv.array<4 x f32>, !spv.struct<(f32)>)> + %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<(vector<3xf32>, 
!spv.array<4xf32>, !spv.struct<(f32)>)> + return %0: !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)> } // ----- @@ -28,10 +28,10 @@ // ----- -func @composite_construct_empty_struct() -> !spv.struct<> { - // CHECK: spv.CompositeConstruct : !spv.struct<> - %0 = spv.CompositeConstruct : !spv.struct<> - return %0: !spv.struct<> +func @composite_construct_empty_struct() -> !spv.struct<()> { + // CHECK: spv.CompositeConstruct : !spv.struct<()> + %0 = spv.CompositeConstruct : !spv.struct<()> + return %0: !spv.struct<()> } // ----- @@ -80,9 +80,9 @@ // ----- -func @composite_extract_struct(%arg0 : !spv.struct>) -> f32 { - // CHECK: {{%.*}} = spv.CompositeExtract {{%.*}}[1 : i32, 2 : i32] : !spv.struct> - %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.struct> +func @composite_extract_struct(%arg0 : !spv.struct<(f32, !spv.array<4xf32>)>) -> f32 { + // CHECK: {{%.*}} = spv.CompositeExtract {{%.*}}[1 : i32, 2 : i32] : !spv.struct<(f32, !spv.array<4 x f32>)> + %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.struct<(f32, !spv.array<4xf32>)> return %0 : f32 } @@ -156,9 +156,9 @@ // ----- -func @composite_extract_struct_element_out_of_bounds_access(%arg0 : !spv.struct>) -> () { - // expected-error @+1 {{index 2 out of bounds for '!spv.struct>'}} - %0 = spv.CompositeExtract %arg0[2 : i32, 0 : i32] : !spv.struct> +func @composite_extract_struct_element_out_of_bounds_access(%arg0 : !spv.struct<(f32, !spv.array<4xf32>)>) -> () { + // expected-error @+1 {{index 2 out of bounds for '!spv.struct<(f32, !spv.array<4 x f32>)>'}} + %0 = spv.CompositeExtract %arg0[2 : i32, 0 : i32] : !spv.struct<(f32, !spv.array<4xf32>)> return } @@ -216,10 +216,10 @@ // ----- -func @composite_insert_struct(%arg0: !spv.struct, f32>, %arg1: !spv.array<4xf32>) -> !spv.struct, f32> { - // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : !spv.array<4 x f32> into !spv.struct, f32> - %0 = spv.CompositeInsert %arg1, %arg0[0 : i32] : !spv.array<4xf32> into 
!spv.struct, f32> - return %0: !spv.struct, f32> +func @composite_insert_struct(%arg0: !spv.struct<(!spv.array<4xf32>, f32)>, %arg1: !spv.array<4xf32>) -> !spv.struct<(!spv.array<4xf32>, f32)> { + // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : !spv.array<4 x f32> into !spv.struct<(!spv.array<4 x f32>, f32)> + %0 = spv.CompositeInsert %arg1, %arg0[0 : i32] : !spv.array<4xf32> into !spv.struct<(!spv.array<4xf32>, f32)> + return %0: !spv.struct<(!spv.array<4xf32>, f32)> } // ----- diff --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir --- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir +++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir @@ -143,9 +143,9 @@ // ----- -spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { +spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { // expected-error @+1 {{Pointer must point to a scalar or vector type}} - %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup> + %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup> spv.Return } diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -6,9 +6,9 @@ func @access_chain_struct() -> () { %0 = spv.constant 1: i32 - %1 = spv.Variable : !spv.ptr>, Function> - // CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr>, Function> - %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function>, i32, i32 + %1 = spv.Variable : !spv.ptr)>, Function> + // CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr)>, Function> + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr)>, Function>, i32, i32 return } @@ -111,9 +111,9 @@ // ----- func @access_chain_invalid_index_2(%index0 : i32) -> () { 
- %0 = spv.Variable : !spv.ptr>, Function> + %0 = spv.Variable : !spv.ptr)>, Function> // expected-error @+1 {{index must be an integer spv.constant to access element of spv.struct}} - %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr>, Function>, i32, i32 + %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr)>, Function>, i32, i32 return } @@ -121,9 +121,9 @@ func @access_chain_invalid_constant_type_1() -> () { %0 = std.constant 1: i32 - %1 = spv.Variable : !spv.ptr>, Function> + %1 = spv.Variable : !spv.ptr)>, Function> // expected-error @+1 {{index must be an integer spv.constant to access element of spv.struct, but provided std.constant}} - %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function>, i32, i32 + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr)>, Function>, i32, i32 return } @@ -131,9 +131,9 @@ func @access_chain_out_of_bounds() -> () { %index0 = "spv.constant"() { value = 12: i32} : () -> i32 - %0 = spv.Variable : !spv.ptr>, Function> - // expected-error @+1 {{'spv.AccessChain' op index 12 out of bounds for '!spv.struct>'}} - %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr>, Function>, i32, i32 + %0 = spv.Variable : !spv.ptr)>, Function> + // expected-error @+1 {{'spv.AccessChain' op index 12 out of bounds for '!spv.struct<(f32, !spv.array<4 x f32>)>'}} + %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr)>, Function>, i32, i32 return } diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -5,13 +5,13 @@ //===----------------------------------------------------------------------===// spv.module Logical GLSL450 { - spv.globalVariable @var1 : !spv.ptr>, Input> + spv.globalVariable @var1 : !spv.ptr)>, Input> spv.func @access_chain() -> () "None" { %0 = spv.constant 1: i32 - // CHECK: [[VAR1:%.*]] = spv._address_of @var1 : !spv.ptr>, Input> - // CHECK-NEXT: spv.AccessChain [[VAR1]][{{.*}}, {{.*}}] : 
!spv.ptr>, Input> - %1 = spv._address_of @var1 : !spv.ptr>, Input> - %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Input>, i32, i32 + // CHECK: [[VAR1:%.*]] = spv._address_of @var1 : !spv.ptr)>, Input> + // CHECK-NEXT: spv.AccessChain [[VAR1]][{{.*}}, {{.*}}] : !spv.ptr)>, Input> + %1 = spv._address_of @var1 : !spv.ptr)>, Input> + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr)>, Input>, i32, i32 spv.Return } } @@ -19,27 +19,27 @@ // ----- // Allow taking address of global variables in other module-like ops -spv.globalVariable @var : !spv.ptr>, Input> +spv.globalVariable @var : !spv.ptr)>, Input> func @address_of() -> () { // CHECK: spv._address_of @var - %1 = spv._address_of @var : !spv.ptr>, Input> + %1 = spv._address_of @var : !spv.ptr)>, Input> return } // ----- spv.module Logical GLSL450 { - spv.globalVariable @var1 : !spv.ptr>, Input> + spv.globalVariable @var1 : !spv.ptr)>, Input> spv.func @foo() -> () "None" { // expected-error @+1 {{expected spv.globalVariable symbol}} - %0 = spv._address_of @var2 : !spv.ptr>, Input> + %0 = spv._address_of @var2 : !spv.ptr)>, Input> } } // ----- spv.module Logical GLSL450 { - spv.globalVariable @var1 : !spv.ptr>, Input> + spv.globalVariable @var1 : !spv.ptr)>, Input> spv.func @foo() -> () "None" { // expected-error @+1 {{result type mismatch with the referenced global variable's type}} %0 = spv._address_of @var1 : !spv.ptr diff --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir --- a/mlir/test/Dialect/SPIRV/types.mlir +++ b/mlir/test/Dialect/SPIRV/types.mlir @@ -230,119 +230,194 @@ // StructType //===----------------------------------------------------------------------===// -// CHECK: func @struct_type(!spv.struct) -func @struct_type(!spv.struct) -> () +// CHECK: func @struct_type(!spv.struct<(f32)>) +func @struct_type(!spv.struct<(f32)>) -> () -// CHECK: func @struct_type2(!spv.struct) -func @struct_type2(!spv.struct) -> () +// CHECK: func @struct_type2(!spv.struct<(f32 [0])>) +func 
@struct_type2(!spv.struct<(f32 [0])>) -> () -// CHECK: func @struct_type_simple(!spv.struct>) -func @struct_type_simple(!spv.struct>) -> () +// CHECK: func @struct_type_simple(!spv.struct<(f32, !spv.image)>) +func @struct_type_simple(!spv.struct<(f32, !spv.image)>) -> () -// CHECK: func @struct_type_with_offset(!spv.struct) -func @struct_type_with_offset(!spv.struct) -> () +// CHECK: func @struct_type_with_offset(!spv.struct<(f32 [0], i32 [4])>) +func @struct_type_with_offset(!spv.struct<(f32 [0], i32 [4])>) -> () -// CHECK: func @nested_struct(!spv.struct>) -func @nested_struct(!spv.struct>) +// CHECK: func @nested_struct(!spv.struct<(f32, !spv.struct<(f32, i32)>)>) +func @nested_struct(!spv.struct<(f32, !spv.struct<(f32, i32)>)>) -// CHECK: func @nested_struct_with_offset(!spv.struct [4]>) -func @nested_struct_with_offset(!spv.struct [4]>) +// CHECK: func @nested_struct_with_offset(!spv.struct<(f32 [0], !spv.struct<(f32 [0], i32 [4])> [4])>) +func @nested_struct_with_offset(!spv.struct<(f32 [0], !spv.struct<(f32 [0], i32 [4])> [4])>) -// CHECK: func @struct_type_with_decoration(!spv.struct) -func @struct_type_with_decoration(!spv.struct) +// CHECK: func @struct_type_with_decoration(!spv.struct<(f32 [NonWritable])>) +func @struct_type_with_decoration(!spv.struct<(f32 [NonWritable])>) -// CHECK: func @struct_type_with_decoration_and_offset(!spv.struct) -func @struct_type_with_decoration_and_offset(!spv.struct) +// CHECK: func @struct_type_with_decoration_and_offset(!spv.struct<(f32 [0, NonWritable])>) +func @struct_type_with_decoration_and_offset(!spv.struct<(f32 [0, NonWritable])>) -// CHECK: func @struct_type_with_decoration2(!spv.struct) -func @struct_type_with_decoration2(!spv.struct) +// CHECK: func @struct_type_with_decoration2(!spv.struct<(f32 [NonWritable], i32 [NonReadable])>) +func @struct_type_with_decoration2(!spv.struct<(f32 [NonWritable], i32 [NonReadable])>) -// CHECK: func @struct_type_with_decoration3(!spv.struct) -func 
@struct_type_with_decoration3(!spv.struct) +// CHECK: func @struct_type_with_decoration3(!spv.struct<(f32, i32 [NonReadable])>) +func @struct_type_with_decoration3(!spv.struct<(f32, i32 [NonReadable])>) -// CHECK: func @struct_type_with_decoration4(!spv.struct) -func @struct_type_with_decoration4(!spv.struct) +// CHECK: func @struct_type_with_decoration4(!spv.struct<(f32 [0], i32 [4, NonReadable])>) +func @struct_type_with_decoration4(!spv.struct<(f32 [0], i32 [4, NonReadable])>) -// CHECK: func @struct_type_with_decoration5(!spv.struct) -func @struct_type_with_decoration5(!spv.struct) +// CHECK: func @struct_type_with_decoration5(!spv.struct<(f32 [NonWritable, NonReadable])>) +func @struct_type_with_decoration5(!spv.struct<(f32 [NonWritable, NonReadable])>) -// CHECK: func @struct_type_with_decoration6(!spv.struct>) -func @struct_type_with_decoration6(!spv.struct>) +// CHECK: func @struct_type_with_decoration6(!spv.struct<(f32, !spv.struct<(i32 [NonWritable, NonReadable])>)>) +func @struct_type_with_decoration6(!spv.struct<(f32, !spv.struct<(i32 [NonWritable, NonReadable])>)>) -// CHECK: func @struct_type_with_decoration7(!spv.struct [4]>) -func @struct_type_with_decoration7(!spv.struct [4]>) +// CHECK: func @struct_type_with_decoration7(!spv.struct<(f32 [0], !spv.struct<(i32, f32 [NonReadable])> [4])>) +func @struct_type_with_decoration7(!spv.struct<(f32 [0], !spv.struct<(i32, f32 [NonReadable])> [4])>) -// CHECK: func @struct_type_with_decoration8(!spv.struct>) -func @struct_type_with_decoration8(!spv.struct>) +// CHECK: func @struct_type_with_decoration8(!spv.struct<(f32, !spv.struct<(i32 [0], f32 [4, NonReadable])>)>) +func @struct_type_with_decoration8(!spv.struct<(f32, !spv.struct<(i32 [0], f32 [4, NonReadable])>)>) -// CHECK: func @struct_type_with_matrix_1(!spv.struct> [0, ColMajor, MatrixStride=16]>) -func @struct_type_with_matrix_1(!spv.struct> [0, ColMajor, MatrixStride=16]>) +// CHECK: func @struct_type_with_matrix_1(!spv.struct<(!spv.matrix<3 x 
vector<3xf32>> [0, ColMajor, MatrixStride=16])>) +func @struct_type_with_matrix_1(!spv.struct<(!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>) -// CHECK: func @struct_type_with_matrix_2(!spv.struct> [0, RowMajor, MatrixStride=16]>) -func @struct_type_with_matrix_2(!spv.struct> [0, RowMajor, MatrixStride=16]>) +// CHECK: func @struct_type_with_matrix_2(!spv.struct<(!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=16])>) +func @struct_type_with_matrix_2(!spv.struct<(!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=16])>) -// CHECK: func @struct_empty(!spv.struct<>) -func @struct_empty(!spv.struct<>) +// CHECK: func @struct_empty(!spv.struct<()>) +func @struct_empty(!spv.struct<()>) // ----- // expected-error @+1 {{offset specification must be given for all members}} -func @struct_type_missing_offset1((!spv.struct) -> () +func @struct_type_missing_offset1((!spv.struct<(f32, i32 [4])>) -> () // ----- // expected-error @+1 {{offset specification must be given for all members}} -func @struct_type_missing_offset2(!spv.struct) -> () +func @struct_type_missing_offset2(!spv.struct<(f32 [3], i32)>) -> () // ----- -// expected-error @+1 {{expected '>'}} -func @struct_type_missing_comma1(!spv.struct) -> () +// expected-error @+1 {{expected ')'}} +func @struct_type_missing_comma1(!spv.struct<(f32 i32)>) -> () // ----- -// expected-error @+1 {{expected '>'}} -func @struct_type_missing_comma2(!spv.struct) -> () +// expected-error @+1 {{expected ')'}} +func @struct_type_missing_comma2(!spv.struct<(f32 [0] i32)>) -> () // ----- -// expected-error @+1 {{unbalanced '>' character in pretty dialect name}} -func @struct_type_neg_offset(!spv.struct) -> () +// expected-error @+1 {{unbalanced ')' character in pretty dialect name}} +func @struct_type_neg_offset(!spv.struct<(f32 [0)>) -> () // ----- // expected-error @+1 {{unbalanced ']' character in pretty dialect name}} -func @struct_type_neg_offset(!spv.struct) -> () +func 
@struct_type_neg_offset(!spv.struct<(f32 0])>) -> () // ----- // expected-error @+1 {{expected ']'}} -func @struct_type_neg_offset(!spv.struct) -> () +func @struct_type_neg_offset(!spv.struct<(f32 [NonWritable 0])>) -> () // ----- // expected-error @+1 {{expected valid keyword}} -func @struct_type_neg_offset(!spv.struct) -> () +func @struct_type_neg_offset(!spv.struct<(f32 [NonWritable, 0])>) -> () // ----- // expected-error @+1 {{expected ','}} -func @struct_type_missing_comma(!spv.struct) +func @struct_type_missing_comma(!spv.struct<(f32 [0 NonWritable], i32 [4])>) // ----- // expected-error @+1 {{expected ']'}} -func @struct_type_missing_comma(!spv.struct) +func @struct_type_missing_comma(!spv.struct<(f32 [0, NonWritable NonReadable], i32 [4])>) // ----- // expected-error @+1 {{expected ']'}} -func @struct_type_missing_comma(!spv.struct> [0, RowMajor MatrixStride=16]>) +func @struct_type_missing_comma(!spv.struct<(!spv.matrix<3 x vector<3xf32>> [0, RowMajor MatrixStride=16])>) // ----- // expected-error @+1 {{expected integer value}} -func @struct_missing_member_decorator_value(!spv.struct> [0, RowMajor, MatrixStride=]>) +func @struct_missing_member_decorator_value(!spv.struct<(!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=])>) + +// ----- + +//===----------------------------------------------------------------------===// +// StructType (identified) +//===----------------------------------------------------------------------===// + +// CHECK: func @id_struct_empty(!spv.struct) +func @id_struct_empty(!spv.struct) -> () + +// ----- + +// CHECK: func @id_struct_simple(!spv.struct) +func @id_struct_simple(!spv.struct) -> () + +// ----- + +// CHECK: func @id_struct_multiple_elements(!spv.struct) +func @id_struct_multiple_elements(!spv.struct) -> () + +// ----- + +// CHECK: func @id_struct_nested_literal(!spv.struct)>) +func @id_struct_nested_literal(!spv.struct)>) -> () + +// ----- + +// CHECK: func @id_struct_nested_id(!spv.struct)>) +func 
@id_struct_nested_id(!spv.struct)>) -> () + +// ----- + +// CHECK: func @literal_struct_nested_id(!spv.struct<(!spv.struct)>) +func @literal_struct_nested_id(!spv.struct<(!spv.struct)>) -> () + +// ----- + +// CHECK: func @id_struct_self_recursive(!spv.struct, Uniform>)>) +func @id_struct_self_recursive(!spv.struct, Uniform>)>) -> () + +// ----- + +// CHECK: func @id_struct_self_recursive2(!spv.struct, Uniform>)>) +func @id_struct_self_recursive2(!spv.struct, Uniform>)>) -> () + +// ----- + +// expected-error @+1 {{recursive struct reference not nested in struct definition}} +func @id_wrong_recursive_reference(!spv.struct) -> () + +// ----- + +// expected-error @+1 {{recursive struct reference not nested in struct definition}} +func @id_struct_recursive_invalid(!spv.struct, Uniform>)>) -> () + +// ----- + +// expected-error @+1 {{identifier already used for an enclosing struct}} +func @id_struct_redefinition(!spv.struct, Uniform>)>, Uniform>)>) -> () + +// ----- + +// Equivalent to: +// struct a { struct b *bPtr; }; +// struct b { struct a *aPtr; }; +// CHECK: func @id_struct_recursive(!spv.struct, Uniform>)>, Uniform>)>) +func @id_struct_recursive(!spv.struct, Uniform>)>, Uniform>)>) -> () + +// ----- + +// Equivalent to: +// struct a { struct b *bPtr; }; +// struct b { struct a *aPtr, struct b *bPtr; }; +// CHECK: func @id_struct_recursive(!spv.struct, Uniform>, !spv.ptr, Uniform>)>, Uniform>)>) +func @id_struct_recursive(!spv.struct, Uniform>, !spv.ptr, Uniform>)>, Uniform>)>) -> () // ----- @@ -446,4 +521,4 @@ // expected-error @+1 {{expected single unsigned integer for number of columns}} func @matrix_size_type(!spv.matrix<2.0 x vector<3xi32>>) -> () -// ----- \ No newline at end of file +// ----- diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -620,10 +620,11 @@ if (op.getNumResults() == 1) { StringRef 
resultTypeID("resultTypeID"); os << tabs << formatv("uint32_t {0} = 0;\n", resultTypeID); + os << tabs << "llvm::SetVector serializationCtx;"; os << tabs - << formatv( - "if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n", - opVar, resultTypeID); + << formatv("if (failed(processType({0}.getLoc(), {0}.getType(), {1}, " + "serializationCtx))) {{\n", + opVar, resultTypeID); os << tabs << " return failure();\n"; os << tabs << "}\n"; os << tabs << formatv("{0}.push_back({1});\n", operands, resultTypeID);