diff --git a/mlir/lib/Target/SPIRV/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization.cpp --- a/mlir/lib/Target/SPIRV/Serialization.cpp +++ b/mlir/lib/Target/SPIRV/Serialization.cpp @@ -364,15 +364,23 @@ /// Main dispatch method for serializing an operation. LogicalResult processOperation(Operation *op); - /// Method to dispatch to the serialization function for an operation in - /// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec. - /// This is auto-generated from ODS. Dispatch is handled for all operations - /// in SPIR-V dialect that have hasOpcode == 1. + /// Serializes an operation `op` as a core instruction with `opcode` if + /// `extInstSet` is empty. Otherwise serializes it as an extended instruction + /// with `opcode` from `extInstSet`. + /// This method is a generic one for dispatching any SPIR-V ops that have no + /// variadic operands or attributes in TableGen definitions. + LogicalResult processOpWithoutGrammarAttr(Operation *op, StringRef extInstSet, + uint32_t opcode); + + /// Dispatches to the serialization function for an operation in SPIR-V + /// dialect that is a mirror of an instruction in the SPIR-V spec. This is + /// auto-generated from ODS. Dispatch is handled for all operations in SPIR-V + /// dialect that have hasOpcode == 1. LogicalResult dispatchToAutogenSerialization(Operation *op); - /// Method to serialize an operation in the SPIR-V dialect that is a mirror of - /// an instruction in the SPIR-V spec. This is auto generated if hasOpcode == - /// 1 and autogenSerialization == 1 in ODS. + /// Serializes an operation in the SPIR-V dialect that is a mirror of an + /// instruction in the SPIR-V spec. This is auto-generated if hasOpcode == 1 + /// and autogenSerialization == 1 in ODS. 
template LogicalResult processOp(OpTy op) { return op.emitError("unsupported op serialization"); @@ -1930,6 +1938,46 @@ [&](Operation *op) { return dispatchToAutogenSerialization(op); }); } +LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op, + StringRef extInstSet, + uint32_t opcode) { + SmallVector operands; + Location loc = op->getLoc(); + + uint32_t resultID = 0; + if (op->getNumResults() != 0) { + uint32_t resultTypeID = 0; + if (failed(processType(loc, op->getResult(0).getType(), resultTypeID))) + return failure(); + operands.push_back(resultTypeID); + + resultID = getNextID(); + operands.push_back(resultID); + valueIDMap[op->getResult(0)] = resultID; + } + + for (Value operand : op->getOperands()) + operands.push_back(getValueID(operand)); + + emitDebugLine(functionBody, loc); + + if (extInstSet.empty()) { + encodeInstructionInto(functionBody, static_cast(opcode), + operands); + } else { + encodeExtensionInstruction(op, extInstSet, opcode, operands); + } + + if (op->getNumResults() != 0) { + for (auto attr : op->getAttrs()) { + if (failed(processDecoration(loc, resultID, attr))) + return failure(); + } + } + + return success(); +} + namespace { template <> LogicalResult diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -647,8 +647,7 @@ // All non-argument attributes translated into OpDecorate instruction os << tabs << formatv("for (auto attr : {0}->getAttrs()) {{\n", opVar); os << tabs - << formatv(" if (llvm::any_of({0}, [&](StringRef elided)", elidedAttrs); - os << " {return attr.first == elided;})) {\n"; + << formatv(" if (llvm::is_contained({0}, attr.first)) {{\n", elidedAttrs); os << tabs << " continue;\n"; os << tabs << " }\n"; os << tabs @@ -666,14 +665,35 @@ const Record *record, const Operator &op, raw_ostream &os) { // If the record has 'autogenSerialization' set to 0, nothing to do - if 
(!record->getValueAsBit("autogenSerialization")) { + if (!record->getValueAsBit("autogenSerialization")) return; - } + StringRef opVar("op"), operands("operands"), elidedAttrs("elidedAttrs"), resultID("resultID"); + os << formatv( "template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n", op.getQualCppClassName(), opVar); + + // Special case for ops with no attributes or variadic operands in TableGen + if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) { + std::string extInstSet; + std::string opcode; + if (record->isSubClassOf("SPV_ExtInstOp")) { + extInstSet = + formatv("\"{0}\"", record->getValueAsString("extendedInstSetName")); + opcode = std::to_string(record->getValueAsInt("extendedInstOpcode")); + } else { + extInstSet = "\"\""; + opcode = formatv("static_cast(spirv::Opcode::{0})", + record->getValueAsString("spirvOpName")); + } + + os << formatv(" return processOpWithoutGrammarAttr({0}, {1}, {2});\n}\n\n", + opVar, extInstSet, opcode); + return; + } + os << formatv(" SmallVector {0};\n", operands); os << formatv(" SmallVector {0};\n", elidedAttrs);