diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -316,18 +316,6 @@ return failure(); } - if (isInterfaceStructPtrType(varOp.type())) { - auto structType = varOp.type() - .cast() - .getPointeeType() - .cast(); - if (failed( - emitDecoration(getTypeID(structType), spirv::Decoration::Block))) { - return varOp.emitError("cannot decorate ") - << structType << " with Block decoration"; - } - } - elidedAttrs.push_back("type"); SmallVector operands; operands.push_back(resultTypeID); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -331,9 +331,9 @@ Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, SetVector &serializationCtx) { typeID = getTypeID(type); - if (typeID) { + if (typeID) return success(); - } + typeID = getNextID(); SmallVector operands; @@ -499,6 +499,14 @@ typeEnum = spirv::Opcode::OpTypePointer; operands.push_back(static_cast(ptrType.getStorageClass())); operands.push_back(pointeeTypeID); + + if (isInterfaceStructPtrType(ptrType)) { + if (failed(emitDecoration(getTypeID(pointeeStruct), + spirv::Decoration::Block))) + return emitError(loc, "cannot decorate ") + << pointeeStruct << " with Block decoration"; + } + return success(); } diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp --- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp @@ -76,27 +76,29 @@ builder.getStringAttr(name), nullptr); } + /// Handles a SPIR-V instruction with the given `opcode` and `operand`. + /// Returns true to interrupt. + using HandleFn = llvm::function_ref operands)>; + /// Returns true if we can find a matching instruction in the SPIR-V blob. - bool findInstruction(llvm::function_ref operands)> - matchFn) { + bool scanInstruction(HandleFn handleFn) { auto binarySize = binary.size(); auto *begin = binary.begin(); auto currOffset = spirv::kHeaderWordCount; while (currOffset < binarySize) { auto wordCount = binary[currOffset] >> 16; - if (!wordCount || (currOffset + wordCount > binarySize)) { + if (!wordCount || (currOffset + wordCount > binarySize)) return false; - } + spirv::Opcode opcode = static_cast(binary[currOffset] & 0xffff); - - if (matchFn(opcode, - llvm::ArrayRef(begin + currOffset + 1, - begin + currOffset + wordCount))) { + llvm::ArrayRef operands(begin + currOffset + 1, + begin + currOffset + wordCount); + if (handleFn(opcode, operands)) return true; - } + currOffset += wordCount; } return false; @@ -119,12 +121,32 @@ ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); auto hasBlockDecoration = [](spirv::Opcode opcode, - ArrayRef operands) -> bool { - if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2) - return false; - return operands[1] == static_cast(spirv::Decoration::Block); + ArrayRef operands) { + return opcode == spirv::Opcode::OpDecorate && operands.size() == 2 && + operands[1] == static_cast(spirv::Decoration::Block); + }; + EXPECT_TRUE(scanInstruction(hasBlockDecoration)); +} + +TEST_F(SerializationTest, ContainsNoDuplicatedBlockDecoration) { + auto structType = getFloatStructType(); + // Two global variables using the same type should not decorate the type with + // duplicated `Block` decorations. + addGlobalVar(structType, "var0"); + addGlobalVar(structType, "var1"); + + ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); + + unsigned count = 0; + auto countBlockDecoration = [&count](spirv::Opcode opcode, + ArrayRef operands) { + if (opcode == spirv::Opcode::OpDecorate && operands.size() == 2 && + operands[1] == static_cast(spirv::Decoration::Block)) + ++count; + return false; }; - EXPECT_TRUE(findInstruction(hasBlockDecoration)); + ASSERT_FALSE(scanInstruction(countBlockDecoration)); + EXPECT_EQ(count, 1u); } TEST_F(SerializationTest, ContainsSymbolName) { @@ -140,7 +162,7 @@ return opcode == spirv::Opcode::OpName && spirv::decodeStringLiteral(operands, index) == "var0"; }; - EXPECT_TRUE(findInstruction(hasVarName)); + EXPECT_TRUE(scanInstruction(hasVarName)); } TEST_F(SerializationTest, DoesNotContainSymbolName) { @@ -156,5 +178,5 @@ return opcode == spirv::Opcode::OpName && spirv::decodeStringLiteral(operands, index) == "var0"; }; - EXPECT_FALSE(findInstruction(hasVarName)); + EXPECT_FALSE(scanInstruction(hasVarName)); }