diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -276,22 +276,30 @@ public: using Base::Base; - // Layout information used for members in a struct in SPIR-V - // - // TODO(ravishankarm) : For now this only supports the offset type, so uses - // uint64_t value to represent the offset, with - // std::numeric_limit::max indicating no offset. Change this to - // something that can hold all the information needed for different member - // types - using LayoutInfo = uint64_t; - - using MemberDecorationInfo = std::pair; + using OffsetInfo = uint32_t; + + struct MemberDecorationInfo { + uint32_t memberIndex; + Decoration decoration; + uint32_t decorationValue; + uint32_t hasValue; + bool operator==(const MemberDecorationInfo &other) const { + return (this->memberIndex == other.memberIndex) && + (this->decoration == other.decoration) && + (this->decorationValue == other.decorationValue); + } + bool operator<(const MemberDecorationInfo &other) const { + return this->memberIndex < other.memberIndex || + (this->memberIndex == other.memberIndex && + this->decoration < other.decoration); + } + }; static bool kindof(unsigned kind) { return kind == TypeKind::Struct; } /// Construct a StructType with at least one member. static StructType get(ArrayRef memberTypes, - ArrayRef layoutInfo = {}, + ArrayRef offsetInfo = {}, ArrayRef memberDecorations = {}); /// Construct a struct with no members. @@ -323,9 +331,9 @@ ElementTypeRange getElementTypes() const; - bool hasLayout() const; + bool hasOffset() const; - uint64_t getOffset(unsigned) const; + uint64_t getMemberOffset(unsigned) const; // Returns in `allMemberDecorations` the spirv::Decorations (apart from // Offset) associated with all members of the StructType. @@ -334,8 +342,9 @@ // Returns in `memberDecorations` all the spirv::Decorations (apart from // Offset) associated with the `i`-th member of the StructType. - void getMemberDecorations( - unsigned i, SmallVectorImpl &memberDecorations) const; + void getMemberDecorations(unsigned i, + SmallVectorImpl + &memberDecorations) const; void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional storage = llvm::None); @@ -343,6 +352,8 @@ Optional storage = llvm::None); }; +llvm::hash_code hash_value(const StructType::MemberDecorationInfo &f); + // SPIR-V cooperative matrix type class CooperativeMatrixNVType : public Type::TypeBase memberTypes; - SmallVector layoutInfo; + SmallVector offsetInfo; SmallVector memberDecorations; Size structMemberOffset = 0; @@ -46,7 +46,8 @@ decorateType(structType.getElementType(i), memberSize, memberAlignment); structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment); memberTypes.push_back(memberType); - layoutInfo.push_back(structMemberOffset); + offsetInfo.push_back( + static_cast(structMemberOffset)); // If the member's size is the max value, it must be the last member and it // must be a runtime array. assert(memberSize != std::numeric_limits().max() || @@ -66,7 +67,7 @@ size = llvm::alignTo(structMemberOffset, maxMemberAlignment); alignment = maxMemberAlignment; structType.getMemberDecorations(memberDecorations); - return spirv::StructType::get(memberTypes, layoutInfo, memberDecorations); + return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations); } Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size, @@ -168,7 +169,7 @@ case spirv::StorageClass::StorageBuffer: case spirv::StorageClass::PushConstant: case spirv::StorageClass::PhysicalStorageBuffer: - return structType.hasLayout() || !structType.getNumElements(); + return structType.hasOffset() || !structType.getNumElements(); default: return true; } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -535,30 +535,31 @@ static ParseResult parseStructMemberDecorations( SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef memberTypes, - SmallVectorImpl &layoutInfo, + SmallVectorImpl &offsetInfo, SmallVectorImpl &memberDecorationInfo) { // Check if the first element is offset. - llvm::SMLoc layoutLoc = parser.getCurrentLocation(); - StructType::LayoutInfo layout = 0; - OptionalParseResult layoutParseResult = parser.parseOptionalInteger(layout); - if (layoutParseResult.hasValue()) { - if (failed(*layoutParseResult)) + llvm::SMLoc offsetLoc = parser.getCurrentLocation(); + StructType::OffsetInfo offset = 0; + OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset); + if (offsetParseResult.hasValue()) { + if (failed(*offsetParseResult)) return failure(); - if (layoutInfo.size() != memberTypes.size() - 1) { - return parser.emitError( - layoutLoc, "layout specification must be given for all members"); + if (offsetInfo.size() != memberTypes.size() - 1) { + return parser.emitError(offsetLoc, + "offset specification must be given for " + "all members"); } - layoutInfo.push_back(layout); + offsetInfo.push_back(offset); } // Check for no spirv::Decorations. if (succeeded(parser.parseOptionalRSquare())) return success(); - // If there was a layout, make sure to parse the comma. - if (layoutParseResult.hasValue() && parser.parseComma()) + // If there was an offset, make sure to parse the comma. + if (offsetParseResult.hasValue() && parser.parseComma()) return failure(); // Check for spirv::Decorations. @@ -567,9 +568,23 @@ if (!memberDecoration) return failure(); - memberDecorationInfo.emplace_back( - static_cast(memberTypes.size() - 1), - memberDecoration.getValue()); + // Parse member decoration value if it exists + if (succeeded(parser.parseOptionalEqual())) { + auto memberDecorationValue = + parseAndVerifyInteger(dialect, parser); + + if (!memberDecorationValue) + return failure(); + + memberDecorationInfo.emplace_back(spirv::StructType::MemberDecorationInfo( + {static_cast(memberTypes.size() - 1), + memberDecoration.getValue(), memberDecorationValue.getValue(), 1})); + } else { + memberDecorationInfo.emplace_back(spirv::StructType::MemberDecorationInfo( + {static_cast(memberTypes.size() - 1), + memberDecoration.getValue(), 0, 0})); + } + } while (succeeded(parser.parseOptionalComma())); return parser.parseRSquare(); @@ -587,7 +602,7 @@ return StructType::getEmpty(dialect.getContext()); SmallVector memberTypes; - SmallVector layoutInfo; + SmallVector offsetInfo; SmallVector memberDecorationInfo; do { @@ -597,21 +612,21 @@ memberTypes.push_back(memberType); if (succeeded(parser.parseOptionalLSquare())) { - if (parseStructMemberDecorations(dialect, parser, memberTypes, layoutInfo, + if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo, memberDecorationInfo)) { return Type(); } } } while (succeeded(parser.parseOptionalComma())); - if (!layoutInfo.empty() && memberTypes.size() != layoutInfo.size()) { + if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) { parser.emitError(parser.getNameLoc(), - "layout specification must be given for all members"); + "offset specification must be given for all members"); return Type(); } if (parser.parseGreater()) return Type(); - return StructType::get(memberTypes, layoutInfo, memberDecorationInfo); + return StructType::get(memberTypes, offsetInfo, memberDecorationInfo); } // spirv-type ::= array-type @@ -679,17 +694,20 @@ os << "struct<"; auto printMember = [&](unsigned i) { os << type.getElementType(i); - SmallVector decorations; + SmallVector decorations; type.getMemberDecorations(i, decorations); - if (type.hasLayout() || !decorations.empty()) { + if (type.hasOffset() || !decorations.empty()) { os << " ["; - if (type.hasLayout()) { - os << type.getOffset(i); + if (type.hasOffset()) { + os << type.getMemberOffset(i); if (!decorations.empty()) os << ", "; } - auto eachFn = [&os](spirv::Decoration decoration) { - os << stringifyDecoration(decoration); + auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) { + os << stringifyDecoration(decoration.decoration); + if (decoration.hasValue) { + os << "=" << decoration.decorationValue; + } }; llvm::interleaveComma(decorations, os, eachFn); os << "]"; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -874,17 +874,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { StructTypeStorage( unsigned numMembers, Type const *memberTypes, - StructType::LayoutInfo const *layoutInfo, unsigned numMemberDecorations, + StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo) : TypeStorage(numMembers), memberTypes(memberTypes), - layoutInfo(layoutInfo), numMemberDecorations(numMemberDecorations), + offsetInfo(layoutInfo), numMemberDecorations(numMemberDecorations), memberDecorationsInfo(memberDecorationsInfo) {} - using KeyTy = std::tuple, ArrayRef, + using KeyTy = std::tuple, ArrayRef, ArrayRef>; bool operator==(const KeyTy &key) const { return key == - KeyTy(getMemberTypes(), getLayoutInfo(), getMemberDecorationsInfo()); + KeyTy(getMemberTypes(), getOffsetInfo(), getMemberDecorationsInfo()); } static StructTypeStorage *construct(TypeStorageAllocator &allocator, @@ -897,13 +897,13 @@ typesList = allocator.copyInto(keyTypes).data(); } - const StructType::LayoutInfo *layoutInfoList = nullptr; + const StructType::OffsetInfo *offsetInfoList = nullptr; if (!std::get<1>(key).empty()) { - ArrayRef keyLayoutInfo = std::get<1>(key); - assert(keyLayoutInfo.size() == keyTypes.size() && - "size of layout information must be same as the size of number of " + ArrayRef keyOffsetInfo = std::get<1>(key); + assert(keyOffsetInfo.size() == keyTypes.size() && + "size of offset information must be same as the size of number of " "elements"); - layoutInfoList = allocator.copyInto(keyLayoutInfo).data(); + offsetInfoList = allocator.copyInto(keyOffsetInfo).data(); } const StructType::MemberDecorationInfo *memberDecorationList = nullptr; @@ -914,7 +914,7 @@ memberDecorationList = allocator.copyInto(keyMemberDecorations).data(); } return new (allocator.allocate()) - StructTypeStorage(keyTypes.size(), typesList, layoutInfoList, + StructTypeStorage(keyTypes.size(), typesList, offsetInfoList, numMemberDecorations, memberDecorationList); } @@ -922,9 +922,9 @@ return ArrayRef(memberTypes, getSubclassData()); } - ArrayRef getLayoutInfo() const { - if (layoutInfo) { - return ArrayRef(layoutInfo, getSubclassData()); + ArrayRef getOffsetInfo() const { + if (offsetInfo) { + return ArrayRef(offsetInfo, getSubclassData()); } return {}; } @@ -938,14 +938,14 @@ } Type const *memberTypes; - StructType::LayoutInfo const *layoutInfo; + StructType::OffsetInfo const *offsetInfo; unsigned numMemberDecorations; StructType::MemberDecorationInfo const *memberDecorationsInfo; }; StructType StructType::get(ArrayRef memberTypes, - ArrayRef layoutInfo, + ArrayRef offsetInfo, ArrayRef memberDecorations) { assert(!memberTypes.empty() && "Struct needs at least one member type"); // Sort the decorations. @@ -953,12 +953,12 @@ memberDecorations.begin(), memberDecorations.end()); llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end()); return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct, - memberTypes, layoutInfo, sortedDecorations); + memberTypes, offsetInfo, sortedDecorations); } StructType StructType::getEmpty(MLIRContext *context) { return Base::get(context, TypeKind::Struct, ArrayRef(), - ArrayRef(), + ArrayRef(), ArrayRef()); } @@ -975,11 +975,11 @@ return ElementTypeRange(getImpl()->memberTypes, getNumElements()); } -bool StructType::hasLayout() const { return getImpl()->layoutInfo; } +bool StructType::hasOffset() const { return getImpl()->offsetInfo; } -uint64_t StructType::getOffset(unsigned index) const { +uint64_t StructType::getMemberOffset(unsigned index) const { assert(getNumElements() > index && "member index out of range"); - return getImpl()->layoutInfo[index]; + return getImpl()->offsetInfo[index]; } void StructType::getMemberDecorations( @@ -992,15 +992,16 @@ } void StructType::getMemberDecorations( - unsigned index, SmallVectorImpl &decorations) const { + unsigned index, + SmallVectorImpl &decorationsInfo) const { assert(getNumElements() > index && "member index out of range"); auto memberDecorations = getImpl()->getMemberDecorationsInfo(); - decorations.clear(); - for (auto &memberDecoration : memberDecorations) { - if (memberDecoration.first == index) { - decorations.push_back(memberDecoration.second); + decorationsInfo.clear(); + for (const auto &memberDecoration : memberDecorations) { + if (memberDecoration.memberIndex == index) { + decorationsInfo.push_back(memberDecoration); } - if (memberDecoration.first > index) { + if (memberDecoration.memberIndex > index) { // Early exit since the decorations are stored sorted. return; } @@ -1020,6 +1021,10 @@ elementType.cast().getCapabilities(capabilities, storage); } +llvm::hash_code spirv::hash_value(const StructType::MemberDecorationInfo &f) { + return llvm::hash_combine(f.memberIndex, f.decoration); +} + //===----------------------------------------------------------------------===// // MatrixType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -1305,7 +1305,7 @@ memberTypes.push_back(memberType); } - SmallVector layoutInfo; + SmallVector offsetInfo; SmallVector memberDecorationsInfo; if (memberDecorationMap.count(operands[0])) { auto &allMemberDecorations = memberDecorationMap[operands[0]]; @@ -1314,27 +1314,29 @@ for (auto &memberDecoration : allMemberDecorations[memberIndex]) { // Check for offset. if (memberDecoration.first == spirv::Decoration::Offset) { - // If layoutInfo is empty, resize to the number of members; - if (layoutInfo.empty()) { - layoutInfo.resize(memberTypes.size()); + // If offset info is empty, resize to the number of members; + if (offsetInfo.empty()) { + offsetInfo.resize(memberTypes.size()); } - layoutInfo[memberIndex] = memberDecoration.second[0]; + offsetInfo[memberIndex] = memberDecoration.second[0]; } else { if (!memberDecoration.second.empty()) { - return emitError(unknownLoc, - "unhandled OpMemberDecoration with decoration ") - << stringifyDecoration(memberDecoration.first) - << " which has additional operands"; + memberDecorationsInfo.emplace_back( + spirv::StructType::MemberDecorationInfo( + {memberIndex, memberDecoration.first, + memberDecoration.second[0], 1})); + } else { + memberDecorationsInfo.emplace_back( + spirv::StructType::MemberDecorationInfo( + {memberIndex, memberDecoration.first, 0, 0})); } - memberDecorationsInfo.emplace_back(memberIndex, - memberDecoration.first); } } } } } typeMap[operands[0]] = - spirv::StructType::get(memberTypes, layoutInfo, memberDecorationsInfo); + spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); // TODO(ravishankarm): Update StructType to have member name as attribute as // well. return success(); diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -227,9 +227,9 @@ } /// Process member decoration - LogicalResult processMemberDecoration(uint32_t structID, uint32_t memberIndex, - spirv::Decoration decorationType, - ArrayRef values = {}); + LogicalResult processMemberDecoration( + uint32_t structID, + spirv::StructType::MemberDecorationInfo memberDecorationInfo); //===--------------------------------------------------------------------===// // Types @@ -736,14 +736,14 @@ return success(); } -LogicalResult -Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex, - spirv::Decoration decorationType, - ArrayRef values) { +LogicalResult Serializer::processMemberDecoration( + uint32_t structID, + spirv::StructType::MemberDecorationInfo memberDecoration) { SmallVector args( - {structID, memberIndex, static_cast(decorationType)}); - if (!values.empty()) { - args.append(values.begin(), values.end()); + {structID, memberDecoration.memberIndex, + static_cast(memberDecoration.decoration)}); + if (memberDecoration.hasValue) { + args.push_back(memberDecoration.decorationValue); } return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args); @@ -1070,7 +1070,7 @@ } if (auto structType = type.dyn_cast()) { - bool hasLayout = structType.hasLayout(); + bool hasOffset = structType.hasOffset(); for (auto elementIndex : llvm::seq(0, structType.getNumElements())) { uint32_t elementTypeID = 0; @@ -1079,11 +1079,15 @@ return failure(); } operands.push_back(elementTypeID); - if (hasLayout) { + if (hasOffset) { // Decorate each struct member with an offset - if (failed(processMemberDecoration( - resultID, elementIndex, spirv::Decoration::Offset, - static_cast(structType.getOffset(elementIndex))))) { + spirv::StructType::MemberDecorationInfo offsetDecoration; + offsetDecoration.memberIndex = elementIndex; + offsetDecoration.decoration = spirv::Decoration::Offset; + offsetDecoration.decorationValue = + static_cast(structType.getMemberOffset(elementIndex)); + offsetDecoration.hasValue = 1; + if (failed(processMemberDecoration(resultID, offsetDecoration))) { return emitError(loc, "cannot decorate ") << elementIndex << "-th member of " << structType << " with its offset"; @@ -1093,11 +1097,11 @@ SmallVector memberDecorations; structType.getMemberDecorations(memberDecorations); for (auto &memberDecoration : memberDecorations) { - if (failed(processMemberDecoration(resultID, memberDecoration.first, - memberDecoration.second))) { + if (failed(processMemberDecoration(resultID, memberDecoration))) { return emitError(loc, "cannot decorate ") - << memberDecoration.first << "-th member of " << structType - << " with " << stringifyDecoration(memberDecoration.second); + << static_cast(memberDecoration.memberIndex) + << "-th member of " << structType << " with " + << stringifyDecoration(memberDecoration.decoration); } } typeEnum = spirv::Opcode::OpTypeStruct; diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -544,6 +544,11 @@ return parser.parseToken(Token::equal, "expected '='"); } + /// Parse a `=` token if present. + ParseResult parseOptionalEqual() override { + return success(parser.consumeIf(Token::equal)); + } + /// Parse a '<' token. ParseResult parseLess() override { return parser.parseToken(Token::less, "expected '<'"); diff --git a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir @@ -22,6 +22,9 @@ // CHECK: !spv.ptr, StorageBuffer> spv.globalVariable @var6 : !spv.ptr, StorageBuffer> + // CHECK: !spv.ptr> [0, ColMajor, MatrixStride=16]>, StorageBuffer> + spv.globalVariable @var7 : !spv.ptr> [0, ColMajor, MatrixStride=16]>, StorageBuffer> + // CHECK: !spv.ptr, StorageBuffer> spv.globalVariable @empty : !spv.ptr, StorageBuffer> diff --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir --- a/mlir/test/Dialect/SPIRV/types.mlir +++ b/mlir/test/Dialect/SPIRV/types.mlir @@ -275,17 +275,23 @@ // CHECK: func @struct_type_with_decoration8(!spv.struct>) func @struct_type_with_decoration8(!spv.struct>) +// CHECK: func @struct_type_with_matrix_1(!spv.struct> [0, ColMajor, MatrixStride=16]>) +func @struct_type_with_matrix_1(!spv.struct> [0, ColMajor, MatrixStride=16]>) + +// CHECK: func @struct_type_with_matrix_2(!spv.struct> [0, RowMajor, MatrixStride=16]>) +func @struct_type_with_matrix_2(!spv.struct> [0, RowMajor, MatrixStride=16]>) + // CHECK: func @struct_empty(!spv.struct<>) func @struct_empty(!spv.struct<>) // ----- -// expected-error @+1 {{layout specification must be given for all members}} +// expected-error @+1 {{offset specification must be given for all members}} func @struct_type_missing_offset1((!spv.struct) -> () // ----- -// expected-error @+1 {{layout specification must be given for all members}} +// expected-error @+1 {{offset specification must be given for all members}} func @struct_type_missing_offset2(!spv.struct) -> () // ----- @@ -330,6 +336,16 @@ // ----- +// expected-error @+1 {{expected ']'}} +func @struct_type_missing_comma(!spv.struct> [0, RowMajor MatrixStride=16]>) + +// ----- + +// expected-error @+1 {{expected integer value}} +func @struct_missing_member_decorator_value(!spv.struct> [0, RowMajor, MatrixStride=]>) + +// ----- + //===----------------------------------------------------------------------===// // CooperativeMatrix //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp --- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp @@ -58,8 +58,8 @@ Type getFloatStructType() { OpBuilder opBuilder(module.body()); llvm::SmallVector elementTypes{opBuilder.getF32Type()}; - llvm::SmallVector layoutInfo{0}; - auto structType = spirv::StructType::get(elementTypes, layoutInfo); + llvm::SmallVector offsetInfo{0}; + auto structType = spirv::StructType::get(elementTypes, offsetInfo); return structType; }