diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -487,8 +487,8 @@ ArrayRef loc, StringRef tabs, StringRef opVar, StringRef operandList, StringRef attrName, raw_ostream &os) { - os << tabs << formatv("auto attr = {0}.getAttr(\"{1}\");\n", opVar, attrName); - os << tabs << "if (attr) {\n"; + os << tabs + << formatv("if (auto attr = {0}.getAttr(\"{1}\")) {{\n", opVar, attrName); if (attr.getAttrDefName() == "SPV_ScopeAttr" || attr.getAttrDefName() == "SPV_MemorySemanticsAttr") { os << tabs @@ -522,10 +522,57 @@ /// generated queries the SSA-ID if operand is a SSA-Value, or serializes the /// attributes. The `operands` vector is updated appropriately. `elidedAttrs` /// updated as well to include the serialized attributes. -static void emitOperandSerialization(const Operator &op, ArrayRef loc, - StringRef tabs, StringRef opVar, - StringRef operands, StringRef elidedAttrs, - raw_ostream &os) { +static void emitArgumentSerialization(const Operator &op, ArrayRef loc, + StringRef tabs, StringRef opVar, + StringRef operands, StringRef elidedAttrs, + raw_ostream &os) { + using mlir::tblgen::Argument; + + // SPIR-V ops can mix operands and attributes in the definition. These + // operands and attributes are serialized in the exact order of the definition + // to match SPIR-V binary format requirements. It can cause excessive + // generated code bloat because we are emitting code to handle each + // operand/attribute separately. So here we probe first to check whether all + // the operands are ahead of attributes. Then we can serialize all operands + // together. + + // Whether all operands are ahead of all attributes in the op's spec. + bool areOperandsAheadOfAttrs = true; + // Find the first attribute. + const Argument *it = llvm::find_if(op.getArgs(), [](const Argument &arg) { + return arg.is(); + }); + // Check whether all following arguments are attributes. + for (const Argument *ie = op.arg_end(); it != ie; ++it) { + if (!it->is()) { + areOperandsAheadOfAttrs = false; + break; + } + } + + // Serialize all operands together. + if (areOperandsAheadOfAttrs) { + if (op.getNumOperands() != 0) { + os << tabs + << formatv( + "for (Value operand : {0}.getOperation()->getOperands()) {{\n", + opVar); + os << tabs << " auto id = getValueID(operand);\n"; + os << tabs << " assert(id && \"use before def!\");\n"; + os << tabs << formatv(" {0}.push_back(id);\n", operands); + os << tabs << "}\n"; + } + for (const NamedAttribute &attr : op.getAttributes()) { + emitAttributeSerialization( + (attr.attr.isOptional() ? attr.attr.getBaseAttr() : attr.attr), loc, + tabs, opVar, operands, attr.name, os); + os << tabs + << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr.name); + } + return; + } + + // Serialize operands separately. auto operandNum = 0; for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { auto argument = op.getArg(i); @@ -545,7 +592,7 @@ os << " }\n"; operandNum++; } else { - auto attr = argument.get(); + NamedAttribute *attr = argument.get(); auto newtabs = tabs.str() + " "; emitAttributeSerialization( (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), @@ -632,8 +679,8 @@ } // Process arguments. - emitOperandSerialization(op, record->getLoc(), " ", opVar, operands, - elidedAttrs, os); + emitArgumentSerialization(op, record->getLoc(), " ", opVar, operands, + elidedAttrs, os); if (record->isSubClassOf("SPV_ExtInstOp")) { os << formatv(" encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n",