diff --git a/mlir/lib/Target/SPIRV/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization.cpp
--- a/mlir/lib/Target/SPIRV/Deserialization.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization.cpp
@@ -34,6 +34,10 @@
 
 #define DEBUG_TYPE "spirv-deserialization"
 
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
 /// Decodes a string literal in `words` starting at `wordIndex`. Update the
 /// latter to point to the position in words after the string literal.
 static inline StringRef decodeStringLiteral(ArrayRef<uint32_t> words,
@@ -55,6 +59,10 @@
 }
 
 namespace {
+//===----------------------------------------------------------------------===//
+// Utility Definitions
+//===----------------------------------------------------------------------===//
+
 /// A struct for containing a header block's merge and continue targets.
 ///
 /// This struct is used to track original structured control flow info from
@@ -124,6 +132,10 @@
   SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
 };
 
+//===----------------------------------------------------------------------===//
+// Deserializer Declaration
+//===----------------------------------------------------------------------===//
+
 /// A SPIR-V module serializer.
 ///
 /// A SPIR-V binary module is a single linear stream of instructions; each
@@ -423,6 +435,14 @@
                                    ArrayRef<uint32_t> operands,
                                    bool deferInstructions = true);
 
+  /// Processes a SPIR-V instruction encoded in `words` into an `opName` op.
+  /// `words` holds the result type <id> and result <id> (only if `hasResult`)
+  /// followed by exactly `numOperands` operand <id>s. Used for ops with neither
+  /// attributes nor variadic operands in their TableGen definitions.
+  LogicalResult processOpWithoutGrammarAttr(ArrayRef<uint32_t> words,
+                                            StringRef opName, bool hasResult,
+                                            unsigned numOperands);
+
   /// Processes a OpUndef instruction. Adds a spv.Undef operation at the current
   /// insertion point.
   LogicalResult processUndef(ArrayRef<uint32_t> operands);
@@ -580,6 +600,10 @@
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Deserializer Method Definitions
+//===----------------------------------------------------------------------===//
+
 Deserializer::Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context)
     : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
       module(createModuleOp()), opBuilder(module->body()) {}
@@ -2497,6 +2521,88 @@
   return dispatchToAutogenDeserialization(opcode, operands);
 }
 
+LogicalResult
+Deserializer::processOpWithoutGrammarAttr(ArrayRef<uint32_t> words,
+                                          StringRef opName, bool hasResult,
+                                          unsigned numOperands) {
+  SmallVector<Type, 1> resultTypes;
+  uint32_t valueID = 0;
+
+  size_t wordIndex = 0;
+  if (hasResult) {
+    if (wordIndex >= words.size())
+      return emitError(unknownLoc,
+                       "expected result type <id> while deserializing for ")
+             << opName;
+
+    // Decode the type <id>
+    auto type = getType(words[wordIndex]);
+    if (!type)
+      return emitError(unknownLoc, "unknown type result <id>: ")
+             << words[wordIndex];
+    resultTypes.push_back(type);
+    ++wordIndex;
+
+    // Decode the result <id>
+    if (wordIndex >= words.size())
+      return emitError(unknownLoc,
+                       "expected result <id> while deserializing for ")
+             << opName;
+    valueID = words[wordIndex];
+    ++wordIndex;
+  }
+
+  SmallVector<Value, 4> operands;
+  SmallVector<NamedAttribute, 4> attributes;
+
+  // Decode operands
+  size_t operandIndex = 0;
+  for (; operandIndex < numOperands && wordIndex < words.size();
+       ++operandIndex, ++wordIndex) {
+    auto arg = getValue(words[wordIndex]);
+    if (!arg)
+      return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
+    operands.push_back(arg);
+  }
+  if (operandIndex != numOperands) {
+    return emitError(
+               unknownLoc,
+               "found less operands than expected when deserializing for ")
+           << opName << "; only " << operandIndex << " of " << numOperands
+           << " processed";
+  }
+  if (wordIndex != words.size()) {
+    return emitError(
+               unknownLoc,
+               "found more operands than expected when deserializing for ")
+           << opName << "; only " << wordIndex << " of " << words.size()
+           << " processed";
+  }
+
+  // Attach attributes from decorations on the result <id>, if any
+  auto decorationIt = decorations.find(valueID);
+  if (decorationIt != decorations.end()) {
+    auto attrs = decorationIt->second.getAttrs();
+    attributes.append(attrs.begin(), attrs.end());
+  }
+
+  // Create the op and update bookkeeping maps
+  Location loc = createFileLineColLoc(opBuilder);
+  OperationState opState(loc, opName);
+  opState.addOperands(operands);
+  if (hasResult)
+    opState.addTypes(resultTypes);
+  opState.addAttributes(attributes);
+  Operation *op = opBuilder.createOperation(opState);
+  if (hasResult)
+    valueMap[valueID] = op->getResult(0);
+
+  if (op->hasTrait<OpTrait::IsTerminator>())
+    clearDebugLine();
+
+  return success();
+}
+
 LogicalResult Deserializer::processUndef(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2) {
     return emitError(unknownLoc, "OpUndef instruction must have two operands");
@@ -2779,6 +2885,7 @@
 // various Deserializer::processOp<...>() specializations.
 #define GET_DESERIALIZATION_FNS
 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
+
 } // namespace
 
 namespace mlir {
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -914,16 +914,29 @@
                                         const Record *record,
                                         const Operator &op, raw_ostream &os) {
   // If the record has 'autogenSerialization' set to 0, nothing to do
-  if (!record->getValueAsBit("autogenSerialization")) {
+  if (!record->getValueAsBit("autogenSerialization"))
     return;
-  }
+
   StringRef resultTypes("resultTypes"), valueID("valueID"), words("words"),
       wordIndex("wordIndex"), opVar("op"), operands("operands"),
       attributes("attributes");
+
+  // Method declaration
   os << formatv("template <> "
                 "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
                 "uint32_t> {1}) {{\n",
                 op.getQualCppClassName(), words);
+
+  // Special case for ops without attributes or variadic operands in their
+  // TableGen definitions: dispatch to the generic deserialization helper.
+  if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) {
+    os << formatv("  return processOpWithoutGrammarAttr("
+                  "{0}, \"{1}\", {2}, {3});\n}\n\n",
+                  words, op.getOperationName(),
+                  op.getNumResults() ? "true" : "false", op.getNumOperands());
+    return;
+  }
+
   os << formatv("  SmallVector<Type, 1> {0};\n", resultTypes);
   os << formatv("  size_t {0} = 0; (void){0};\n", wordIndex);
   os << formatv("  uint32_t {0} = 0; (void){0};\n", valueID);
@@ -938,6 +951,9 @@
   emitOperandDeserialization(op, record->getLoc(), "  ", words, wordIndex,
                              operands, attributes, os);
 
+  // Decorations
+  emitDecorationDeserialization(op, "  ", valueID, attributes, os);
+
   os << formatv("  Location loc = createFileLineColLoc(opBuilder);\n");
   os << formatv("  auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); "
                 "(void){1};\n",
@@ -953,9 +969,6 @@
   // next end of block.
   os << formatv("  if ({0}.hasTrait<OpTrait::IsTerminator>())\n", opVar);
   os << formatv("    clearDebugLine();\n");
-
-  // Decorations
-  emitDecorationDeserialization(op, "  ", valueID, attributes, os);
   os << "  return success();\n";
   os << "}\n\n";
 }