diff --git a/mlir/lib/Target/SPIRV/Serialization/CMakeLists.txt b/mlir/lib/Target/SPIRV/Serialization/CMakeLists.txt --- a/mlir/lib/Target/SPIRV/Serialization/CMakeLists.txt +++ b/mlir/lib/Target/SPIRV/Serialization/CMakeLists.txt @@ -1,5 +1,7 @@ add_mlir_translation_library(MLIRSPIRVSerialization Serialization.cpp + Serializer.cpp + SerializeOps.cpp DEPENDS MLIRSPIRVSerializationGen diff --git a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp @@ -1,4 +1,4 @@ -//===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===// +//===- Serialization.cpp - MLIR SPIR-V Serialization ----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,2252 +6,20 @@ // //===----------------------------------------------------------------------===// // -// This file defines the MLIR SPIR-V module to SPIR-V binary serialization. +// This file defines the MLIR SPIR-V module to SPIR-V binary serialization entry +// point. // //===----------------------------------------------------------------------===// +#include "Serializer.h" + #include "mlir/Target/SPIRV/Serialization.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/RegionGraphTraits.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" -#include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/Sequence.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/ADT/bit.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "spirv-serialization" -using namespace mlir; - -/// Encodes an SPIR-V instruction with the given `opcode` and `operands` into -/// the given `binary` vector. -static LogicalResult encodeInstructionInto(SmallVectorImpl &binary, - spirv::Opcode op, - ArrayRef operands) { - uint32_t wordCount = 1 + operands.size(); - binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); - binary.append(operands.begin(), operands.end()); - return success(); -} - -/// A pre-order depth-first visitor function for processing basic blocks. -/// -/// Visits the basic blocks starting from the given `headerBlock` in pre-order -/// depth-first manner and calls `blockHandler` on each block. Skips handling -/// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler` -/// will not be invoked in `headerBlock` but still handles all `headerBlock`'s -/// successors. -/// -/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order -/// of blocks in a function must satisfy the rule that blocks appear before -/// all blocks they dominate." This can be achieved by a pre-order CFG -/// traversal algorithm. To make the serialization output more logical and -/// readable to human, we perform depth-first CFG traversal and delay the -/// serialization of the merge block and the continue block, if exists, until -/// after all other blocks have been processed. -static LogicalResult -visitInPrettyBlockOrder(Block *headerBlock, - function_ref blockHandler, - bool skipHeader = false, BlockRange skipBlocks = {}) { - llvm::df_iterator_default_set doneBlocks; - doneBlocks.insert(skipBlocks.begin(), skipBlocks.end()); - - for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) { - if (skipHeader && block == headerBlock) - continue; - if (failed(blockHandler(block))) - return failure(); - } - return success(); -} - -/// Returns the merge block if the given `op` is a structured control flow op. -/// Otherwise returns nullptr. -static Block *getStructuredControlFlowOpMergeBlock(Operation *op) { - if (auto selectionOp = dyn_cast(op)) - return selectionOp.getMergeBlock(); - if (auto loopOp = dyn_cast(op)) - return loopOp.getMergeBlock(); - return nullptr; -} - -/// Given a predecessor `block` for a block with arguments, returns the block -/// that should be used as the parent block for SPIR-V OpPhi instructions -/// corresponding to the block arguments. -static Block *getPhiIncomingBlock(Block *block) { - // If the predecessor block in question is the entry block for a spv.loop, - // we jump to this spv.loop from its enclosing block. - if (block->isEntryBlock()) { - if (auto loopOp = dyn_cast(block->getParentOp())) { - // Then the incoming parent block for OpPhi should be the merge block of - // the structured control flow op before this loop. - Operation *op = loopOp.getOperation(); - while ((op = op->getPrevNode()) != nullptr) - if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op)) - return incomingBlock; - // Or the enclosing block itself if no structured control flow ops - // exists before this loop. - return loopOp->getBlock(); - } - } - - // Otherwise, we jump from the given predecessor block. Try to see if there is - // a structured control flow op inside it. - for (Operation &op : llvm::reverse(block->getOperations())) { - if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op)) - return incomingBlock; - } - return block; -} - -namespace { - -/// A SPIR-V module serializer. -/// -/// A SPIR-V binary module is a single linear stream of instructions; each -/// instruction is composed of 32-bit words with the layout: -/// -/// | | | | | ... | -/// | <------ word -------> | <-- word --> | <-- word --> | ... | -/// -/// For the first word, the 16 high-order bits are the word count of the -/// instruction, the 16 low-order bits are the opcode enumerant. The -/// instructions then belong to different sections, which must be laid out in -/// the particular order as specified in "2.4 Logical Layout of a Module" of -/// the SPIR-V spec. -class Serializer { -public: - /// Creates a serializer for the given SPIR-V `module`. - explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false); - - /// Serializes the remembered SPIR-V module. - LogicalResult serialize(); - - /// Collects the final SPIR-V `binary`. - void collect(SmallVectorImpl &binary); - -#ifndef NDEBUG - /// (For debugging) prints each value and its corresponding result . - void printValueIDMap(raw_ostream &os); -#endif - -private: - // Note that there are two main categories of methods in this class: - // * process*() methods are meant to fully serialize a SPIR-V module entity - // (header, type, op, etc.). They update internal vectors containing - // different binary sections. They are not meant to be called except the - // top-level serialization loop. - // * prepare*() methods are meant to be helpers that prepare for serializing - // certain entity. They may or may not update internal vectors containing - // different binary sections. They are meant to be called among themselves - // or by other process*() methods for subtasks. - - //===--------------------------------------------------------------------===// - // - //===--------------------------------------------------------------------===// - - // Note that it is illegal to use id <0> in SPIR-V binary module. Various - // methods in this class, if using SPIR-V word (uint32_t) as interface, - // check or return id <0> to indicate error in processing. - - /// Consumes the next unused . This method will never return 0. - uint32_t getNextID() { return nextID++; } - - //===--------------------------------------------------------------------===// - // Module structure - //===--------------------------------------------------------------------===// - - uint32_t getSpecConstID(StringRef constName) const { - return specConstIDMap.lookup(constName); - } - - uint32_t getVariableID(StringRef varName) const { - return globalVarIDMap.lookup(varName); - } - - uint32_t getFunctionID(StringRef fnName) const { - return funcIDMap.lookup(fnName); - } - - /// Gets the for the function with the given name. Assigns the next - /// available if the function haven't been deserialized. - uint32_t getOrCreateFunctionID(StringRef fnName); - - void processCapability(); - - void processDebugInfo(); - - void processExtension(); - - void processMemoryModel(); - - LogicalResult processConstantOp(spirv::ConstantOp op); - - LogicalResult processSpecConstantOp(spirv::SpecConstantOp op); - - LogicalResult - processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op); - - LogicalResult - processSpecConstantOperationOp(spirv::SpecConstantOperationOp op); - - /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA - /// value to use with other operations. The SPIR-V spec recommends that - /// OpUndef be generated at module level. The serialization generates an - /// OpUndef for each type needed at module level. - LogicalResult processUndefOp(spirv::UndefOp op); - - /// Emit OpName for the given `resultID`. - LogicalResult processName(uint32_t resultID, StringRef name); - - /// Processes a SPIR-V function op. - LogicalResult processFuncOp(spirv::FuncOp op); - - LogicalResult processVariableOp(spirv::VariableOp op); - - /// Process a SPIR-V GlobalVariableOp - LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp); - - /// Process attributes that translate to decorations on the result - LogicalResult processDecoration(Location loc, uint32_t resultID, - NamedAttribute attr); - - template - LogicalResult processTypeDecoration(Location loc, DType type, - uint32_t resultId) { - return emitError(loc, "unhandled decoration for type:") << type; - } - - /// Process member decoration - LogicalResult processMemberDecoration( - uint32_t structID, - const spirv::StructType::MemberDecorationInfo &memberDecorationInfo); - - //===--------------------------------------------------------------------===// - // Types - //===--------------------------------------------------------------------===// - - uint32_t getTypeID(Type type) const { return typeIDMap.lookup(type); } - - Type getVoidType() { return mlirBuilder.getNoneType(); } - - bool isVoidType(Type type) const { return type.isa(); } - - /// Returns true if the given type is a pointer type to a struct in some - /// interface storage class. - bool isInterfaceStructPtrType(Type type) const; - - /// Main dispatch method for serializing a type. The result of the - /// serialized type will be returned as `typeID`. - LogicalResult processType(Location loc, Type type, uint32_t &typeID); - LogicalResult processTypeImpl(Location loc, Type type, uint32_t &typeID, - llvm::SetVector &serializationCtx); - - /// Method for preparing basic SPIR-V type serialization. Returns the type's - /// opcode and operands for the instruction via `typeEnum` and `operands`. - LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID, - spirv::Opcode &typeEnum, - SmallVectorImpl &operands, - bool &deferSerialization, - llvm::SetVector &serializationCtx); - - LogicalResult prepareFunctionType(Location loc, FunctionType type, - spirv::Opcode &typeEnum, - SmallVectorImpl &operands); - - //===--------------------------------------------------------------------===// - // Constant - //===--------------------------------------------------------------------===// - - uint32_t getConstantID(Attribute value) const { - return constIDMap.lookup(value); - } - - /// Main dispatch method for processing a constant with the given `constType` - /// and `valueAttr`. `constType` is needed here because we can interpret the - /// `valueAttr` as a different type than the type of `valueAttr` itself; for - /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType - /// constants. - uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr); - - /// Prepares array attribute serialization. This method emits corresponding - /// OpConstant* and returns the result associated with it. Returns 0 if - /// failed. - uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr); - - /// Prepares bool/int/float DenseElementsAttr serialization. This method - /// iterates the DenseElementsAttr to construct the constant array, and - /// returns the result associated with it. Returns 0 if failed. Note - /// that the size of `index` must match the rank. - /// TODO: Consider to enhance splat elements cases. For splat cases, - /// we don't need to loop over all elements, especially when the splat value - /// is zero. We can use OpConstantNull when the value is zero. - uint32_t prepareDenseElementsConstant(Location loc, Type constType, - DenseElementsAttr valueAttr, int dim, - MutableArrayRef index); - - /// Prepares scalar attribute serialization. This method emits corresponding - /// OpConstant* and returns the result associated with it. Returns 0 if - /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is - /// true, then the constant will be serialized as a specialization constant. - uint32_t prepareConstantScalar(Location loc, Attribute valueAttr, - bool isSpec = false); - - uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, - bool isSpec = false); - - uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, - bool isSpec = false); - - uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, - bool isSpec = false); - - //===--------------------------------------------------------------------===// - // Control flow - //===--------------------------------------------------------------------===// - - /// Returns the result for the given block. - uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); } - - /// Returns the result for the given block. If no has been assigned, - /// assigns the next available - uint32_t getOrCreateBlockID(Block *block); - - /// Processes the given `block` and emits SPIR-V instructions for all ops - /// inside. Does not emit OpLabel for this block if `omitLabel` is true. - /// `actionBeforeTerminator` is a callback that will be invoked before - /// handling the terminator op. It can be used to inject the Op*Merge - /// instruction if this is a SPIR-V selection/loop header block. - LogicalResult - processBlock(Block *block, bool omitLabel = false, - function_ref actionBeforeTerminator = nullptr); - - /// Emits OpPhi instructions for the given block if it has block arguments. - LogicalResult emitPhiForBlockArguments(Block *block); - - LogicalResult processSelectionOp(spirv::SelectionOp selectionOp); - - LogicalResult processLoopOp(spirv::LoopOp loopOp); - - LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp); - - LogicalResult processBranchOp(spirv::BranchOp branchOp); - - //===--------------------------------------------------------------------===// - // Operations - //===--------------------------------------------------------------------===// - - LogicalResult encodeExtensionInstruction(Operation *op, - StringRef extensionSetName, - uint32_t opcode, - ArrayRef operands); - - uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); } - - LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp); - - LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp); - - /// Main dispatch method for serializing an operation. - LogicalResult processOperation(Operation *op); - - /// Serializes an operation `op` as core instruction with `opcode` if - /// `extInstSet` is empty. Otherwise serializes it as an extended instruction - /// with `opcode` from `extInstSet`. - /// This method is a generic one for dispatching any SPIR-V ops that has no - /// variadic operands and attributes in TableGen definitions. - LogicalResult processOpWithoutGrammarAttr(Operation *op, StringRef extInstSet, - uint32_t opcode); - - /// Dispatches to the serialization function for an operation in SPIR-V - /// dialect that is a mirror of an instruction in the SPIR-V spec. This is - /// auto-generated from ODS. Dispatch is handled for all operations in SPIR-V - /// dialect that have hasOpcode == 1. - LogicalResult dispatchToAutogenSerialization(Operation *op); - - /// Serializes an operation in the SPIR-V dialect that is a mirror of an - /// instruction in the SPIR-V spec. This is auto generated if hasOpcode == 1 - /// and autogenSerialization == 1 in ODS. - template - LogicalResult processOp(OpTy op) { - return op.emitError("unsupported op serialization"); - } - - //===--------------------------------------------------------------------===// - // Utilities - //===--------------------------------------------------------------------===// - - /// Emits an OpDecorate instruction to decorate the given `target` with the - /// given `decoration`. - LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration, - ArrayRef params = {}); - - /// Emits an OpLine instruction with the given `loc` location information into - /// the given `binary` vector. - LogicalResult emitDebugLine(SmallVectorImpl &binary, Location loc); - -private: - /// The SPIR-V module to be serialized. - spirv::ModuleOp module; - - /// An MLIR builder for getting MLIR constructs. - mlir::Builder mlirBuilder; - - /// A flag which indicates if the debuginfo should be emitted. - bool emitDebugInfo = false; - - /// A flag which indicates if the last processed instruction was a merge - /// instruction. - /// According to SPIR-V spec: "If a branch merge instruction is used, the last - /// OpLine in the block must be before its merge instruction". - bool lastProcessedWasMergeInst = false; - - /// The of the OpString instruction, which specifies a file name, for - /// use by other debug instructions. - uint32_t fileID = 0; - - /// The next available result . - uint32_t nextID = 1; - - // The following are for different SPIR-V instruction sections. They follow - // the logical layout of a SPIR-V module. - - SmallVector capabilities; - SmallVector extensions; - SmallVector extendedSets; - SmallVector memoryModel; - SmallVector entryPoints; - SmallVector executionModes; - SmallVector debug; - SmallVector names; - SmallVector decorations; - SmallVector typesGlobalValues; - SmallVector functions; - - /// Recursive struct references are serialized as OpTypePointer instructions - /// to the recursive struct type. However, the OpTypePointer instruction - /// cannot be emitted before the recursive struct's OpTypeStruct. - /// RecursiveStructPointerInfo stores the data needed to emit such - /// OpTypePointer instructions after forward references to such types. - struct RecursiveStructPointerInfo { - uint32_t pointerTypeID; - spirv::StorageClass storageClass; - }; - - // Maps spirv::StructType to its recursive reference member info. - DenseMap> - recursiveStructInfos; - - /// `functionHeader` contains all the instructions that must be in the first - /// block in the function, and `functionBody` contains the rest. After - /// processing FuncOp, the encoded instructions of a function are appended to - /// `functions`. An example of instructions in `functionHeader` in order: - /// OpFunction ... - /// OpFunctionParameter ... - /// OpFunctionParameter ... - /// OpLabel ... - /// OpVariable ... - /// OpVariable ... - SmallVector functionHeader; - SmallVector functionBody; - - /// Map from type used in SPIR-V module to their s. - DenseMap typeIDMap; - - /// Map from constant values to their s. - DenseMap constIDMap; - - /// Map from specialization constant names to their s. - llvm::StringMap specConstIDMap; - - /// Map from GlobalVariableOps name to s. - llvm::StringMap globalVarIDMap; - - /// Map from FuncOps name to s. - llvm::StringMap funcIDMap; - - /// Map from blocks to their s. - DenseMap blockIDMap; - - /// Map from the Type to the that represents undef value of that type. - DenseMap undefValIDMap; - - /// Map from results of normal operations to their s. - DenseMap valueIDMap; - - /// Map from extended instruction set name to s. - llvm::StringMap extendedInstSetIDMap; - - /// Map from values used in OpPhi instructions to their offset in the - /// `functions` section. - /// - /// When processing a block with arguments, we need to emit OpPhi - /// instructions to record the predecessor block s and the values they - /// send to the block in question. But it's not guaranteed all values are - /// visited and thus assigned result s. So we need this list to capture - /// the offsets into `functions` where a value is used so that we can fix it - /// up later after processing all the blocks in a function. - /// - /// More concretely, say if we are visiting the following blocks: - /// - /// ```mlir - /// ^phi(%arg0: i32): - /// ... - /// ^parent1: - /// ... - /// spv.Branch ^phi(%val0: i32) - /// ^parent2: - /// ... - /// spv.Branch ^phi(%val1: i32) - /// ``` - /// - /// When we are serializing the `^phi` block, we need to emit at the beginning - /// of the block OpPhi instructions which has the following parameters: - /// - /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1 - /// id-for-%val1 id-for-^parent2 - /// - /// But we don't know the for %val0 and %val1 yet. One way is to visit - /// all the blocks twice and use the first visit to assign an to each - /// value. But it's paying the overheads just for OpPhi emission. Instead, - /// we still visit the blocks once for emission. When we emit the OpPhi - /// instructions, we use 0 as a placeholder for the s for %val0 and %val1. - /// At the same time, we record their offsets in the emitted binary (which is - /// placed inside `functions`) here. And then after emitting all blocks, we - /// replace the dummy 0 with the real result by overwriting - /// `functions[offset]`. - DenseMap> deferredPhiValues; -}; -} // namespace - -Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo) - : module(module), mlirBuilder(module.getContext()), - emitDebugInfo(emitDebugInfo) {} - -LogicalResult Serializer::serialize() { - LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); - - if (failed(module.verify())) - return failure(); - - // TODO: handle the other sections - processCapability(); - processExtension(); - processMemoryModel(); - processDebugInfo(); - - // Iterate over the module body to serialize it. Assumptions are that there is - // only one basic block in the moduleOp - for (auto &op : module.getBlock()) { - if (failed(processOperation(&op))) { - return failure(); - } - } - - LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); - return success(); -} - -void Serializer::collect(SmallVectorImpl &binary) { - auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + - extensions.size() + extendedSets.size() + - memoryModel.size() + entryPoints.size() + - executionModes.size() + decorations.size() + - typesGlobalValues.size() + functions.size(); - - binary.clear(); - binary.reserve(moduleSize); - - spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID); - binary.append(capabilities.begin(), capabilities.end()); - binary.append(extensions.begin(), extensions.end()); - binary.append(extendedSets.begin(), extendedSets.end()); - binary.append(memoryModel.begin(), memoryModel.end()); - binary.append(entryPoints.begin(), entryPoints.end()); - binary.append(executionModes.begin(), executionModes.end()); - binary.append(debug.begin(), debug.end()); - binary.append(names.begin(), names.end()); - binary.append(decorations.begin(), decorations.end()); - binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); - binary.append(functions.begin(), functions.end()); -} - -#ifndef NDEBUG -void Serializer::printValueIDMap(raw_ostream &os) { - os << "\n= Value Map =\n\n"; - for (auto valueIDPair : valueIDMap) { - Value val = valueIDPair.first; - os << " " << val << " " - << "id = " << valueIDPair.second << ' '; - if (auto *op = val.getDefiningOp()) { - os << "from op '" << op->getName() << "'"; - } else if (auto arg = val.dyn_cast()) { - Block *block = arg.getOwner(); - os << "from argument of block " << block << ' '; - os << " in op '" << block->getParentOp()->getName() << "'"; - } - os << '\n'; - } -} -#endif - -//===----------------------------------------------------------------------===// -// Module structure -//===----------------------------------------------------------------------===// - -uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { - auto funcID = funcIDMap.lookup(fnName); - if (!funcID) { - funcID = getNextID(); - funcIDMap[fnName] = funcID; - } - return funcID; -} - -void Serializer::processCapability() { - for (auto cap : module.vce_triple()->getCapabilities()) - encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, - {static_cast(cap)}); -} - -void Serializer::processDebugInfo() { - if (!emitDebugInfo) - return; - auto fileLoc = module.getLoc().dyn_cast(); - auto fileName = fileLoc ? fileLoc.getFilename() : ""; - fileID = getNextID(); - SmallVector operands; - operands.push_back(fileID); - spirv::encodeStringLiteralInto(operands, fileName); - encodeInstructionInto(debug, spirv::Opcode::OpString, operands); - // TODO: Encode more debug instructions. -} - -void Serializer::processExtension() { - llvm::SmallVector extName; - for (spirv::Extension ext : module.vce_triple()->getExtensions()) { - extName.clear(); - spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext)); - encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName); - } -} - -void Serializer::processMemoryModel() { - uint32_t mm = module->getAttrOfType("memory_model").getInt(); - uint32_t am = module->getAttrOfType("addressing_model").getInt(); - - encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm}); -} - -LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { - if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) { - valueIDMap[op.getResult()] = resultID; - return success(); - } - return failure(); -} - -LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { - if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(), - /*isSpec=*/true)) { - // Emit the OpDecorate instruction for SpecId. - if (auto specID = op->getAttrOfType("spec_id")) { - auto val = static_cast(specID.getInt()); - emitDecoration(resultID, spirv::Decoration::SpecId, {val}); - } - - specConstIDMap[op.sym_name()] = resultID; - return processName(resultID, op.sym_name()); - } - return failure(); -} - -LogicalResult -Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { - uint32_t typeID = 0; - if (failed(processType(op.getLoc(), op.type(), typeID))) { - return failure(); - } - - auto resultID = getNextID(); - - SmallVector operands; - operands.push_back(typeID); - operands.push_back(resultID); - - auto constituents = op.constituents(); - - for (auto index : llvm::seq(0, constituents.size())) { - auto constituent = constituents[index].dyn_cast(); - - auto constituentName = constituent.getValue(); - auto constituentID = getSpecConstID(constituentName); - - if (!constituentID) { - return op.emitError("unknown result for specialization constant ") - << constituentName; - } - - operands.push_back(constituentID); - } - - encodeInstructionInto(typesGlobalValues, - spirv::Opcode::OpSpecConstantComposite, operands); - specConstIDMap[op.sym_name()] = resultID; - - return processName(resultID, op.sym_name()); -} - -LogicalResult -Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { - uint32_t typeID = 0; - if (failed(processType(op.getLoc(), op.getType(), typeID))) { - return failure(); - } - - auto resultID = getNextID(); - - SmallVector operands; - operands.push_back(typeID); - operands.push_back(resultID); - - Block &block = op.getRegion().getBlocks().front(); - Operation &enclosedOp = block.getOperations().front(); - - std::string enclosedOpName; - llvm::raw_string_ostream rss(enclosedOpName); - rss << "Op" << enclosedOp.getName().stripDialect(); - auto enclosedOpcode = spirv::symbolizeOpcode(rss.str()); - - if (!enclosedOpcode) { - op.emitError("Couldn't find op code for op ") - << enclosedOp.getName().getStringRef(); - return failure(); - } - - operands.push_back(static_cast(enclosedOpcode.getValue())); - - // Append operands to the enclosed op to the list of operands. - for (Value operand : enclosedOp.getOperands()) { - uint32_t id = getValueID(operand); - assert(id && "use before def!"); - operands.push_back(id); - } - - encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp, - operands); - valueIDMap[op.getResult()] = resultID; - - return success(); -} - -LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { - auto undefType = op.getType(); - auto &id = undefValIDMap[undefType]; - if (!id) { - id = getNextID(); - uint32_t typeID = 0; - if (failed(processType(op.getLoc(), undefType, typeID)) || - failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef, - {typeID, id}))) { - return failure(); - } - } - valueIDMap[op.getResult()] = id; - return success(); -} - -LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, - NamedAttribute attr) { - auto attrName = attr.first.strref(); - auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); - auto decoration = spirv::symbolizeDecoration(decorationName); - if (!decoration) { - return emitError( - loc, "non-argument attributes expected to have snake-case-ified " - "decoration name, unhandled attribute with name : ") - << attrName; - } - SmallVector args; - switch (decoration.getValue()) { - case spirv::Decoration::Binding: - case spirv::Decoration::DescriptorSet: - case spirv::Decoration::Location: - if (auto intAttr = attr.second.dyn_cast()) { - args.push_back(intAttr.getValue().getZExtValue()); - break; - } - return emitError(loc, "expected integer attribute for ") << attrName; - case spirv::Decoration::BuiltIn: - if (auto strAttr = attr.second.dyn_cast()) { - auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); - if (enumVal) { - args.push_back(static_cast(enumVal.getValue())); - break; - } - return emitError(loc, "invalid ") - << attrName << " attribute " << strAttr.getValue(); - } - return emitError(loc, "expected string attribute for ") << attrName; - case spirv::Decoration::Aliased: - case spirv::Decoration::Flat: - case spirv::Decoration::NonReadable: - case spirv::Decoration::NonWritable: - case spirv::Decoration::NoPerspective: - case spirv::Decoration::Restrict: - // For unit attributes, the args list has no values so we do nothing - if (auto unitAttr = attr.second.dyn_cast()) - break; - return emitError(loc, "expected unit attribute for ") << attrName; - default: - return emitError(loc, "unhandled decoration ") << decorationName; - } - return emitDecoration(resultID, decoration.getValue(), args); -} - -LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { - assert(!name.empty() && "unexpected empty string for OpName"); - - SmallVector nameOperands; - nameOperands.push_back(resultID); - if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) { - return failure(); - } - return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); -} - -namespace { -template <> -LogicalResult Serializer::processTypeDecoration( - Location loc, spirv::ArrayType type, uint32_t resultID) { - if (unsigned stride = type.getArrayStride()) { - // OpDecorate %arrayTypeSSA ArrayStride strideLiteral - return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); - } - return success(); -} - -template <> -LogicalResult Serializer::processTypeDecoration( - Location Loc, spirv::RuntimeArrayType type, uint32_t resultID) { - if (unsigned stride = type.getArrayStride()) { - // OpDecorate %arrayTypeSSA ArrayStride strideLiteral - return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); - } - return success(); -} - -LogicalResult Serializer::processMemberDecoration( - uint32_t structID, - const spirv::StructType::MemberDecorationInfo &memberDecoration) { - SmallVector args( - {structID, memberDecoration.memberIndex, - static_cast(memberDecoration.decoration)}); - if (memberDecoration.hasValue) { - args.push_back(memberDecoration.decorationValue); - } - return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, - args); -} -} // namespace - -LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { - LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n"); - assert(functionHeader.empty() && functionBody.empty()); - - uint32_t fnTypeID = 0; - // Generate type of the function. - processType(op.getLoc(), op.getType(), fnTypeID); - - // Add the function definition. - SmallVector operands; - uint32_t resTypeID = 0; - auto resultTypes = op.getType().getResults(); - if (resultTypes.size() > 1) { - return op.emitError("cannot serialize function with multiple return types"); - } - if (failed(processType(op.getLoc(), - (resultTypes.empty() ? getVoidType() : resultTypes[0]), - resTypeID))) { - return failure(); - } - operands.push_back(resTypeID); - auto funcID = getOrCreateFunctionID(op.getName()); - operands.push_back(funcID); - operands.push_back(static_cast(op.function_control())); - operands.push_back(fnTypeID); - encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands); - - // Add function name. - if (failed(processName(funcID, op.getName()))) { - return failure(); - } - - // Declare the parameters. - for (auto arg : op.getArguments()) { - uint32_t argTypeID = 0; - if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { - return failure(); - } - auto argValueID = getNextID(); - valueIDMap[arg] = argValueID; - encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, - {argTypeID, argValueID}); - } - - // Process the body. - if (op.isExternal()) { - return op.emitError("external function is unhandled"); - } - - // Some instructions (e.g., OpVariable) in a function must be in the first - // block in the function. These instructions will be put in functionHeader. - // Thus, we put the label in functionHeader first, and omit it from the first - // block. - encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, - {getOrCreateBlockID(&op.front())}); - processBlock(&op.front(), /*omitLabel=*/true); - if (failed(visitInPrettyBlockOrder( - &op.front(), [&](Block *block) { return processBlock(block); }, - /*skipHeader=*/true))) { - return failure(); - } - - // There might be OpPhi instructions who have value references needing to fix. - for (auto deferredValue : deferredPhiValues) { - Value value = deferredValue.first; - uint32_t id = getValueID(value); - LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value - << " to id = " << id << '\n'); - assert(id && "OpPhi references undefined value!"); - for (size_t offset : deferredValue.second) - functionBody[offset] = id; - } - deferredPhiValues.clear(); - - LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() - << "' --\n"); - // Insert OpFunctionEnd. - if (failed(encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, - {}))) { - return failure(); - } - - functions.append(functionHeader.begin(), functionHeader.end()); - functions.append(functionBody.begin(), functionBody.end()); - functionHeader.clear(); - functionBody.clear(); - - return success(); -} - -LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { - SmallVector operands; - SmallVector elidedAttrs; - uint32_t resultID = 0; - uint32_t resultTypeID = 0; - if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) { - return failure(); - } - operands.push_back(resultTypeID); - resultID = getNextID(); - valueIDMap[op.getResult()] = resultID; - operands.push_back(resultID); - auto attr = op->getAttr(spirv::attributeName()); - if (attr) { - operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); - } - elidedAttrs.push_back(spirv::attributeName()); - for (auto arg : op.getODSOperands(0)) { - auto argID = getValueID(arg); - if (!argID) { - return emitError(op.getLoc(), "operand 0 has a use before def"); - } - operands.push_back(argID); - } - emitDebugLine(functionHeader, op.getLoc()); - encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands); - for (auto attr : op->getAttrs()) { - if (llvm::any_of(elidedAttrs, - [&](StringRef elided) { return attr.first == elided; })) { - continue; - } - if (failed(processDecoration(op.getLoc(), resultID, attr))) { - return failure(); - } - } - return success(); -} - -LogicalResult -Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { - // Get TypeID. - uint32_t resultTypeID = 0; - SmallVector elidedAttrs; - if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) { - return failure(); - } - - if (isInterfaceStructPtrType(varOp.type())) { - auto structType = varOp.type() - .cast() - .getPointeeType() - .cast(); - if (failed( - emitDecoration(getTypeID(structType), spirv::Decoration::Block))) { - return varOp.emitError("cannot decorate ") - << structType << " with Block decoration"; - } - } - - elidedAttrs.push_back("type"); - SmallVector operands; - operands.push_back(resultTypeID); - auto resultID = getNextID(); - - // Encode the name. - auto varName = varOp.sym_name(); - elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); - if (failed(processName(resultID, varName))) { - return failure(); - } - globalVarIDMap[varName] = resultID; - operands.push_back(resultID); - - // Encode StorageClass. - operands.push_back(static_cast(varOp.storageClass())); - - // Encode initialization. - if (auto initializer = varOp.initializer()) { - auto initializerID = getVariableID(initializer.getValue()); - if (!initializerID) { - return emitError(varOp.getLoc(), - "invalid usage of undefined variable as initializer"); - } - operands.push_back(initializerID); - elidedAttrs.push_back("initializer"); - } - - emitDebugLine(typesGlobalValues, varOp.getLoc()); - if (failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, - operands))) { - elidedAttrs.push_back("initializer"); - return failure(); - } - - // Encode decorations. - for (auto attr : varOp->getAttrs()) { - if (llvm::any_of(elidedAttrs, - [&](StringRef elided) { return attr.first == elided; })) { - continue; - } - if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { - return failure(); - } - } - return success(); -} - -//===----------------------------------------------------------------------===// -// Type -//===----------------------------------------------------------------------===// - -// According to the SPIR-V spec "Validation Rules for Shader Capabilities": -// "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and -// PushConstant Storage Classes must be explicitly laid out." -bool Serializer::isInterfaceStructPtrType(Type type) const { - if (auto ptrType = type.dyn_cast()) { - switch (ptrType.getStorageClass()) { - case spirv::StorageClass::PhysicalStorageBuffer: - case spirv::StorageClass::PushConstant: - case spirv::StorageClass::StorageBuffer: - case spirv::StorageClass::Uniform: - return ptrType.getPointeeType().isa(); - default: - break; - } - } - return false; -} - -LogicalResult Serializer::processType(Location loc, Type type, - uint32_t &typeID) { - // Maintains a set of names for nested identified struct types. This is used - // to properly serialize recursive references. - llvm::SetVector serializationCtx; - return processTypeImpl(loc, type, typeID, serializationCtx); -} - -LogicalResult -Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, - llvm::SetVector &serializationCtx) { - typeID = getTypeID(type); - if (typeID) { - return success(); - } - typeID = getNextID(); - SmallVector operands; - - operands.push_back(typeID); - auto typeEnum = spirv::Opcode::OpTypeVoid; - bool deferSerialization = false; - - if ((type.isa() && - succeeded(prepareFunctionType(loc, type.cast(), typeEnum, - operands))) || - succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, - deferSerialization, serializationCtx))) { - if (deferSerialization) - return success(); - - typeIDMap[type] = typeID; - - if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands))) - return failure(); - - if (recursiveStructInfos.count(type) != 0) { - // This recursive struct type is emitted already, now the OpTypePointer - // instructions referring to recursive references are emitted as well. - for (auto &ptrInfo : recursiveStructInfos[type]) { - // TODO: This might not work if more than 1 recursive reference is - // present in the struct. - SmallVector ptrOperands; - ptrOperands.push_back(ptrInfo.pointerTypeID); - ptrOperands.push_back(static_cast(ptrInfo.storageClass)); - ptrOperands.push_back(typeIDMap[type]); - - if (failed(encodeInstructionInto( - typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands))) - return failure(); - } - - recursiveStructInfos[type].clear(); - } - - return success(); - } - - return failure(); -} - -LogicalResult Serializer::prepareBasicType( - Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, - SmallVectorImpl &operands, bool &deferSerialization, - llvm::SetVector &serializationCtx) { - deferSerialization = false; - - if (isVoidType(type)) { - typeEnum = spirv::Opcode::OpTypeVoid; - return success(); - } - - if (auto intType = type.dyn_cast()) { - if (intType.getWidth() == 1) { - typeEnum = spirv::Opcode::OpTypeBool; - return success(); - } - - typeEnum = spirv::Opcode::OpTypeInt; - operands.push_back(intType.getWidth()); - // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics - // to preserve or validate. - // 0 indicates unsigned, or no signedness semantics - // 1 indicates signed semantics." - operands.push_back(intType.isSigned() ? 1 : 0); - return success(); - } - - if (auto floatType = type.dyn_cast()) { - typeEnum = spirv::Opcode::OpTypeFloat; - operands.push_back(floatType.getWidth()); - return success(); - } - - if (auto vectorType = type.dyn_cast()) { - uint32_t elementTypeID = 0; - if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID, - serializationCtx))) { - return failure(); - } - typeEnum = spirv::Opcode::OpTypeVector; - operands.push_back(elementTypeID); - operands.push_back(vectorType.getNumElements()); - return success(); - } - - if (auto imageType = type.dyn_cast()) { - typeEnum = spirv::Opcode::OpTypeImage; - uint32_t sampledTypeID = 0; - if (failed(processType(loc, imageType.getElementType(), sampledTypeID))) - return failure(); - - operands.push_back(sampledTypeID); - operands.push_back(static_cast(imageType.getDim())); - operands.push_back(static_cast(imageType.getDepthInfo())); - operands.push_back(static_cast(imageType.getArrayedInfo())); - operands.push_back(static_cast(imageType.getSamplingInfo())); - operands.push_back(static_cast(imageType.getSamplerUseInfo())); - operands.push_back(static_cast(imageType.getImageFormat())); - return success(); - } - - if (auto arrayType = type.dyn_cast()) { - typeEnum = spirv::Opcode::OpTypeArray; - uint32_t elementTypeID = 0; - if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID, - serializationCtx))) { - return failure(); - } - operands.push_back(elementTypeID); - if (auto elementCountID = prepareConstantInt( - loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { - operands.push_back(elementCountID); - } - return processTypeDecoration(loc, arrayType, resultID); - } - - if (auto ptrType = type.dyn_cast()) { - uint32_t pointeeTypeID = 0; - spirv::StructType pointeeStruct = - ptrType.getPointeeType().dyn_cast(); - - if (pointeeStruct && pointeeStruct.isIdentified() && - serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { - // A recursive reference to an enclosing struct is found. - // - // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage - // class as operands. - SmallVector forwardPtrOperands; - forwardPtrOperands.push_back(resultID); - forwardPtrOperands.push_back( - static_cast(ptrType.getStorageClass())); - - encodeInstructionInto(typesGlobalValues, - spirv::Opcode::OpTypeForwardPointer, - forwardPtrOperands); - - // 2. Find the pointee (enclosing) struct. - auto structType = spirv::StructType::getIdentified( - module.getContext(), pointeeStruct.getIdentifier()); - - if (!structType) - return failure(); - - // 3. Mark the OpTypePointer that is supposed to be emitted by this call - // as deferred. - deferSerialization = true; - - // 4. Record the info needed to emit the deferred OpTypePointer - // instruction when the enclosing struct is completely serialized. - recursiveStructInfos[structType].push_back( - {resultID, ptrType.getStorageClass()}); - } else { - if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID, - serializationCtx))) - return failure(); - } - - typeEnum = spirv::Opcode::OpTypePointer; - operands.push_back(static_cast(ptrType.getStorageClass())); - operands.push_back(pointeeTypeID); - return success(); - } - - if (auto runtimeArrayType = type.dyn_cast()) { - uint32_t elementTypeID = 0; - if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(), - elementTypeID, serializationCtx))) { - return failure(); - } - typeEnum = spirv::Opcode::OpTypeRuntimeArray; - operands.push_back(elementTypeID); - return processTypeDecoration(loc, runtimeArrayType, resultID); - } - - if (auto structType = type.dyn_cast()) { - if (structType.isIdentified()) { - processName(resultID, structType.getIdentifier()); - serializationCtx.insert(structType.getIdentifier()); - } - - bool hasOffset = structType.hasOffset(); - for (auto elementIndex : - llvm::seq(0, structType.getNumElements())) { - uint32_t elementTypeID = 0; - if (failed(processTypeImpl(loc, structType.getElementType(elementIndex), - elementTypeID, serializationCtx))) { - return failure(); - } - operands.push_back(elementTypeID); - if (hasOffset) { - // Decorate each struct member with an offset - spirv::StructType::MemberDecorationInfo offsetDecoration{ - elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, - static_cast(structType.getMemberOffset(elementIndex))}; - if (failed(processMemberDecoration(resultID, offsetDecoration))) { - return emitError(loc, "cannot decorate ") - << elementIndex << "-th member of " << structType - << " with its offset"; - } - } - } - SmallVector memberDecorations; - structType.getMemberDecorations(memberDecorations); - - for (auto &memberDecoration : memberDecorations) { - if (failed(processMemberDecoration(resultID, memberDecoration))) { - return emitError(loc, "cannot decorate ") - << static_cast(memberDecoration.memberIndex) - << "-th member of " << structType << " with " - << stringifyDecoration(memberDecoration.decoration); - } - } - - typeEnum = spirv::Opcode::OpTypeStruct; - - if (structType.isIdentified()) - serializationCtx.remove(structType.getIdentifier()); - - return success(); - } - - if (auto cooperativeMatrixType = - type.dyn_cast()) { - uint32_t elementTypeID = 0; - if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), - elementTypeID, serializationCtx))) { - return failure(); - } - typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; - auto getConstantOp = [&](uint32_t id) { - auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); - return prepareConstantInt(loc, attr); - }; - operands.push_back(elementTypeID); - operands.push_back( - getConstantOp(static_cast(cooperativeMatrixType.getScope()))); - operands.push_back(getConstantOp(cooperativeMatrixType.getRows())); - operands.push_back(getConstantOp(cooperativeMatrixType.getColumns())); - return success(); - } - - if (auto matrixType = type.dyn_cast()) { - uint32_t elementTypeID = 0; - if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, - serializationCtx))) { - return failure(); - } - typeEnum = spirv::Opcode::OpTypeMatrix; - operands.push_back(elementTypeID); - operands.push_back(matrixType.getNumColumns()); - return success(); - } - - // TODO: Handle other types. - return emitError(loc, "unhandled type in serialization: ") << type; -} - -LogicalResult -Serializer::prepareFunctionType(Location loc, FunctionType type, - spirv::Opcode &typeEnum, - SmallVectorImpl &operands) { - typeEnum = spirv::Opcode::OpTypeFunction; - assert(type.getNumResults() <= 1 && - "serialization supports only a single return value"); - uint32_t resultID = 0; - if (failed(processType( - loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), - resultID))) { - return failure(); - } - operands.push_back(resultID); - for (auto &res : type.getInputs()) { - uint32_t argTypeID = 0; - if (failed(processType(loc, res, argTypeID))) { - return failure(); - } - operands.push_back(argTypeID); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// Constant -//===----------------------------------------------------------------------===// - -uint32_t Serializer::prepareConstant(Location loc, Type constType, - Attribute valueAttr) { - if (auto id = prepareConstantScalar(loc, valueAttr)) { - return id; - } - - // This is a composite literal. We need to handle each component separately - // and then emit an OpConstantComposite for the whole. - - if (auto id = getConstantID(valueAttr)) { - return id; - } - - uint32_t typeID = 0; - if (failed(processType(loc, constType, typeID))) { - return 0; - } - - uint32_t resultID = 0; - if (auto attr = valueAttr.dyn_cast()) { - int rank = attr.getType().dyn_cast().getRank(); - SmallVector index(rank); - resultID = prepareDenseElementsConstant(loc, constType, attr, - /*dim=*/0, index); - } else if (auto arrayAttr = valueAttr.dyn_cast()) { - resultID = prepareArrayConstant(loc, constType, arrayAttr); - } - - if (resultID == 0) { - emitError(loc, "cannot serialize attribute: ") << valueAttr; - return 0; - } - - constIDMap[valueAttr] = resultID; - return resultID; -} - -uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, - ArrayAttr attr) { - uint32_t typeID = 0; - if (failed(processType(loc, constType, typeID))) { - return 0; - } - - uint32_t resultID = getNextID(); - SmallVector operands = {typeID, resultID}; - operands.reserve(attr.size() + 2); - auto elementType = constType.cast().getElementType(); - for (Attribute elementAttr : attr) { - if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { - operands.push_back(elementID); - } else { - return 0; - } - } - spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; - encodeInstructionInto(typesGlobalValues, opcode, operands); - - return resultID; -} - -// TODO: Turn the below function into iterative function, instead of -// recursive function. -uint32_t -Serializer::prepareDenseElementsConstant(Location loc, Type constType, - DenseElementsAttr valueAttr, int dim, - MutableArrayRef index) { - auto shapedType = valueAttr.getType().dyn_cast(); - assert(dim <= shapedType.getRank()); - if (shapedType.getRank() == dim) { - if (auto attr = valueAttr.dyn_cast()) { - return attr.getType().getElementType().isInteger(1) - ? prepareConstantBool(loc, attr.getValue(index)) - : prepareConstantInt(loc, attr.getValue(index)); - } - if (auto attr = valueAttr.dyn_cast()) { - return prepareConstantFp(loc, attr.getValue(index)); - } - return 0; - } - - uint32_t typeID = 0; - if (failed(processType(loc, constType, typeID))) { - return 0; - } - - uint32_t resultID = getNextID(); - SmallVector operands = {typeID, resultID}; - operands.reserve(shapedType.getDimSize(dim) + 2); - auto elementType = constType.cast().getElementType(0); - for (int i = 0; i < shapedType.getDimSize(dim); ++i) { - index[dim] = i; - if (auto elementID = prepareDenseElementsConstant( - loc, elementType, valueAttr, dim + 1, index)) { - operands.push_back(elementID); - } else { - return 0; - } - } - spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; - encodeInstructionInto(typesGlobalValues, opcode, operands); - - return resultID; -} - -uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, - bool isSpec) { - if (auto floatAttr = valueAttr.dyn_cast()) { - return prepareConstantFp(loc, floatAttr, isSpec); - } - if (auto boolAttr = valueAttr.dyn_cast()) { - return prepareConstantBool(loc, boolAttr, isSpec); - } - if (auto intAttr = valueAttr.dyn_cast()) { - return prepareConstantInt(loc, intAttr, isSpec); - } - - return 0; -} - -uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, - bool isSpec) { - if (!isSpec) { - // We can de-duplicate normal constants, but not specialization constants. - if (auto id = getConstantID(boolAttr)) { - return id; - } - } - - // Process the type for this bool literal - uint32_t typeID = 0; - if (failed(processType(loc, boolAttr.getType(), typeID))) { - return 0; - } - - auto resultID = getNextID(); - auto opcode = boolAttr.getValue() - ? (isSpec ? spirv::Opcode::OpSpecConstantTrue - : spirv::Opcode::OpConstantTrue) - : (isSpec ? spirv::Opcode::OpSpecConstantFalse - : spirv::Opcode::OpConstantFalse); - encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); - - if (!isSpec) { - constIDMap[boolAttr] = resultID; - } - return resultID; -} - -uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, - bool isSpec) { - if (!isSpec) { - // We can de-duplicate normal constants, but not specialization constants. - if (auto id = getConstantID(intAttr)) { - return id; - } - } - - // Process the type for this integer literal - uint32_t typeID = 0; - if (failed(processType(loc, intAttr.getType(), typeID))) { - return 0; - } - - auto resultID = getNextID(); - APInt value = intAttr.getValue(); - unsigned bitwidth = value.getBitWidth(); - bool isSigned = value.isSignedIntN(bitwidth); - - auto opcode = - isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; - - // According to SPIR-V spec, "When the type's bit width is less than 32-bits, - // the literal's value appears in the low-order bits of the word, and the - // high-order bits must be 0 for a floating-point type, or 0 for an integer - // type with Signedness of 0, or sign extended when Signedness is 1." - if (bitwidth == 32 || bitwidth == 16) { - uint32_t word = 0; - if (isSigned) { - word = static_cast(value.getSExtValue()); - } else { - word = static_cast(value.getZExtValue()); - } - encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); - } - // According to SPIR-V spec: "When the type's bit width is larger than one - // word, the literal’s low-order words appear first." - else if (bitwidth == 64) { - struct DoubleWord { - uint32_t word1; - uint32_t word2; - } words; - if (isSigned) { - words = llvm::bit_cast(value.getSExtValue()); - } else { - words = llvm::bit_cast(value.getZExtValue()); - } - encodeInstructionInto(typesGlobalValues, opcode, - {typeID, resultID, words.word1, words.word2}); - } else { - std::string valueStr; - llvm::raw_string_ostream rss(valueStr); - value.print(rss, /*isSigned=*/false); - - emitError(loc, "cannot serialize ") - << bitwidth << "-bit integer literal: " << rss.str(); - return 0; - } - - if (!isSpec) { - constIDMap[intAttr] = resultID; - } - return resultID; -} - -uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, - bool isSpec) { - if (!isSpec) { - // We can de-duplicate normal constants, but not specialization constants. - if (auto id = getConstantID(floatAttr)) { - return id; - } - } - - // Process the type for this float literal - uint32_t typeID = 0; - if (failed(processType(loc, floatAttr.getType(), typeID))) { - return 0; - } - - auto resultID = getNextID(); - APFloat value = floatAttr.getValue(); - APInt intValue = value.bitcastToAPInt(); - - auto opcode = - isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; - - if (&value.getSemantics() == &APFloat::IEEEsingle()) { - uint32_t word = llvm::bit_cast(value.convertToFloat()); - encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); - } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { - struct DoubleWord { - uint32_t word1; - uint32_t word2; - } words = llvm::bit_cast(value.convertToDouble()); - encodeInstructionInto(typesGlobalValues, opcode, - {typeID, resultID, words.word1, words.word2}); - } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { - uint32_t word = - static_cast(value.bitcastToAPInt().getZExtValue()); - encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); - } else { - std::string valueStr; - llvm::raw_string_ostream rss(valueStr); - value.print(rss); - - emitError(loc, "cannot serialize ") - << floatAttr.getType() << "-typed float literal: " << rss.str(); - return 0; - } - - if (!isSpec) { - constIDMap[floatAttr] = resultID; - } - return resultID; -} - -//===----------------------------------------------------------------------===// -// Control flow -//===----------------------------------------------------------------------===// - -uint32_t Serializer::getOrCreateBlockID(Block *block) { - if (uint32_t id = getBlockID(block)) - return id; - return blockIDMap[block] = getNextID(); -} - -LogicalResult -Serializer::processBlock(Block *block, bool omitLabel, - function_ref actionBeforeTerminator) { - LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); - LLVM_DEBUG(block->print(llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << '\n'); - if (!omitLabel) { - uint32_t blockID = getOrCreateBlockID(block); - LLVM_DEBUG(llvm::dbgs() - << "[block] " << block << " (id = " << blockID << ")\n"); - - // Emit OpLabel for this block. - encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID}); - } - - // Emit OpPhi instructions for block arguments, if any. - if (failed(emitPhiForBlockArguments(block))) - return failure(); - - // Process each op in this block except the terminator. - for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { - if (failed(processOperation(&op))) - return failure(); - } - - // Process the terminator. - if (actionBeforeTerminator) - actionBeforeTerminator(); - if (failed(processOperation(&block->back()))) - return failure(); - - return success(); -} - -LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { - // Nothing to do if this block has no arguments or it's the entry block, which - // always has the same arguments as the function signature. - if (block->args_empty() || block->isEntryBlock()) - return success(); - - // If the block has arguments, we need to create SPIR-V OpPhi instructions. - // A SPIR-V OpPhi instruction is of the syntax: - // OpPhi | result type | result | (value , parent block ) pair - // So we need to collect all predecessor blocks and the arguments they send - // to this block. - SmallVector, 4> predecessors; - for (Block *predecessor : block->getPredecessors()) { - auto *terminator = predecessor->getTerminator(); - // The predecessor here is the immediate one according to MLIR's IR - // structure. It does not directly map to the incoming parent block for the - // OpPhi instructions at SPIR-V binary level. This is because structured - // control flow ops are serialized to multiple SPIR-V blocks. If there is a - // spv.selection/spv.loop op in the MLIR predecessor block, the branch op - // jumping to the OpPhi's block then resides in the previous structured - // control flow op's merge block. - predecessor = getPhiIncomingBlock(predecessor); - if (auto branchOp = dyn_cast(terminator)) { - predecessors.emplace_back(predecessor, branchOp.operand_begin()); - } else { - return terminator->emitError("unimplemented terminator for Phi creation"); - } - } - - // Then create OpPhi instruction for each of the block argument. - for (auto argIndex : llvm::seq(0, block->getNumArguments())) { - BlockArgument arg = block->getArgument(argIndex); - - // Get the type and result for this OpPhi instruction. - uint32_t phiTypeID = 0; - if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID))) - return failure(); - uint32_t phiID = getNextID(); - - LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' - << arg << " (id = " << phiID << ")\n"); - - // Prepare the (value , parent block ) pairs. - SmallVector phiArgs; - phiArgs.push_back(phiTypeID); - phiArgs.push_back(phiID); - - for (auto predIndex : llvm::seq(0, predecessors.size())) { - Value value = *(predecessors[predIndex].second + argIndex); - uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); - LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId - << ") value " << value << ' '); - // Each pair is a value ... - uint32_t valueId = getValueID(value); - if (valueId == 0) { - // The op generating this value hasn't been visited yet so we don't have - // an assigned yet. Record this to fix up later. - LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); - deferredPhiValues[value].push_back(functionBody.size() + 1 + - phiArgs.size()); - } else { - LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); - } - phiArgs.push_back(valueId); - // ... and a parent block . - phiArgs.push_back(predBlockId); - } - - encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); - valueIDMap[arg] = phiID; - } - - return success(); -} - -LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { - // Assign s to all blocks so that branches inside the SelectionOp can - // resolve properly. - auto &body = selectionOp.body(); - for (Block &block : body) - getOrCreateBlockID(&block); - - auto *headerBlock = selectionOp.getHeaderBlock(); - auto *mergeBlock = selectionOp.getMergeBlock(); - auto mergeID = getBlockID(mergeBlock); - auto loc = selectionOp.getLoc(); - - // Emit the selection header block, which dominates all other blocks, first. - // We need to emit an OpSelectionMerge instruction before the selection header - // block's terminator. - auto emitSelectionMerge = [&]() { - emitDebugLine(functionBody, loc); - lastProcessedWasMergeInst = true; - encodeInstructionInto( - functionBody, spirv::Opcode::OpSelectionMerge, - {mergeID, static_cast(selectionOp.selection_control())}); - }; - // For structured selection, we cannot have blocks in the selection construct - // branching to the selection header block. Entering the selection (and - // reaching the selection header) must be from the block containing the - // spv.selection op. If there are ops ahead of the spv.selection op in the - // block, we can "merge" them into the selection header. So here we don't need - // to emit a separate block; just continue with the existing block. - if (failed(processBlock(headerBlock, /*omitLabel=*/true, emitSelectionMerge))) - return failure(); - - // Process all blocks with a depth-first visitor starting from the header - // block. The selection header block and merge block are skipped by this - // visitor. - if (failed(visitInPrettyBlockOrder( - headerBlock, [&](Block *block) { return processBlock(block); }, - /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock}))) - return failure(); - - // There is nothing to do for the merge block in the selection, which just - // contains a spv.mlir.merge op, itself. But we need to have an OpLabel - // instruction to start a new SPIR-V block for ops following this SelectionOp. - // The block should use the for the merge block. - return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); -} - -LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { - // Assign s to all blocks so that branches inside the LoopOp can resolve - // properly. We don't need to assign for the entry block, which is just for - // satisfying MLIR region's structural requirement. - auto &body = loopOp.body(); - for (Block &block : - llvm::make_range(std::next(body.begin(), 1), body.end())) { - getOrCreateBlockID(&block); - } - auto *headerBlock = loopOp.getHeaderBlock(); - auto *continueBlock = loopOp.getContinueBlock(); - auto *mergeBlock = loopOp.getMergeBlock(); - auto headerID = getBlockID(headerBlock); - auto continueID = getBlockID(continueBlock); - auto mergeID = getBlockID(mergeBlock); - auto loc = loopOp.getLoc(); - - // This LoopOp is in some MLIR block with preceding and following ops. In the - // binary format, it should reside in separate SPIR-V blocks from its - // preceding and following ops. So we need to emit unconditional branches to - // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow - // afterwards. - encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); - - // LoopOp's entry block is just there for satisfying MLIR's structural - // requirements so we omit it and start serialization from the loop header - // block. - - // Emit the loop header block, which dominates all other blocks, first. We - // need to emit an OpLoopMerge instruction before the loop header block's - // terminator. - auto emitLoopMerge = [&]() { - emitDebugLine(functionBody, loc); - lastProcessedWasMergeInst = true; - encodeInstructionInto( - functionBody, spirv::Opcode::OpLoopMerge, - {mergeID, continueID, static_cast(loopOp.loop_control())}); - }; - if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge))) - return failure(); - - // Process all blocks with a depth-first visitor starting from the header - // block. The loop header block, loop continue block, and loop merge block are - // skipped by this visitor and handled later in this function. - if (failed(visitInPrettyBlockOrder( - headerBlock, [&](Block *block) { return processBlock(block); }, - /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock}))) - return failure(); - - // We have handled all other blocks. Now get to the loop continue block. - if (failed(processBlock(continueBlock))) - return failure(); - - // There is nothing to do for the merge block in the loop, which just contains - // a spv.mlir.merge op, itself. But we need to have an OpLabel instruction to - // start a new SPIR-V block for ops following this LoopOp. The block should - // use the for the merge block. - return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); -} - -LogicalResult Serializer::processBranchConditionalOp( - spirv::BranchConditionalOp condBranchOp) { - auto conditionID = getValueID(condBranchOp.condition()); - auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock()); - auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock()); - SmallVector arguments{conditionID, trueLabelID, falseLabelID}; - - if (auto weights = condBranchOp.branch_weights()) { - for (auto val : weights->getValue()) - arguments.push_back(val.cast().getInt()); - } - - emitDebugLine(functionBody, condBranchOp.getLoc()); - return encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional, - arguments); -} - -LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { - emitDebugLine(functionBody, branchOp.getLoc()); - return encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, - {getOrCreateBlockID(branchOp.getTarget())}); -} - -//===----------------------------------------------------------------------===// -// Operation -//===----------------------------------------------------------------------===// - -LogicalResult Serializer::encodeExtensionInstruction( - Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, - ArrayRef operands) { - // Check if the extension has been imported. - auto &setID = extendedInstSetIDMap[extensionSetName]; - if (!setID) { - setID = getNextID(); - SmallVector importOperands; - importOperands.push_back(setID); - if (failed( - spirv::encodeStringLiteralInto(importOperands, extensionSetName)) || - failed(encodeInstructionInto( - extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) { - return failure(); - } - } - - // The first two operands are the result type and result . The set - // and the opcode need to be insert after this. - if (operands.size() < 2) { - return op->emitError("extended instructions must have a result encoding"); - } - SmallVector extInstOperands; - extInstOperands.reserve(operands.size() + 2); - extInstOperands.append(operands.begin(), std::next(operands.begin(), 2)); - extInstOperands.push_back(setID); - extInstOperands.push_back(extensionOpcode); - extInstOperands.append(std::next(operands.begin(), 2), operands.end()); - return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, - extInstOperands); -} - -LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { - auto varName = addressOfOp.variable(); - auto variableID = getVariableID(varName); - if (!variableID) { - return addressOfOp.emitError("unknown result for variable ") - << varName; - } - valueIDMap[addressOfOp.pointer()] = variableID; - return success(); -} - -LogicalResult -Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { - auto constName = referenceOfOp.spec_const(); - auto constID = getSpecConstID(constName); - if (!constID) { - return referenceOfOp.emitError( - "unknown result for specialization constant ") - << constName; - } - valueIDMap[referenceOfOp.reference()] = constID; - return success(); -} - -LogicalResult Serializer::processOperation(Operation *opInst) { - LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); - - // First dispatch the ops that do not directly mirror an instruction from - // the SPIR-V spec. - return TypeSwitch(opInst) - .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) - .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) - .Case([&](spirv::BranchConditionalOp op) { - return processBranchConditionalOp(op); - }) - .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) - .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) - .Case([&](spirv::GlobalVariableOp op) { - return processGlobalVariableOp(op); - }) - .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) - .Case([&](spirv::ModuleEndOp) { return success(); }) - .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) - .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) - .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) - .Case([&](spirv::SpecConstantCompositeOp op) { - return processSpecConstantCompositeOp(op); - }) - .Case([&](spirv::SpecConstantOperationOp op) { - return processSpecConstantOperationOp(op); - }) - .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) - .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) - - // Then handle all the ops that directly mirror SPIR-V instructions with - // auto-generated methods. - .Default( - [&](Operation *op) { return dispatchToAutogenSerialization(op); }); -} - -LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op, - StringRef extInstSet, - uint32_t opcode) { - SmallVector operands; - Location loc = op->getLoc(); - - uint32_t resultID = 0; - if (op->getNumResults() != 0) { - uint32_t resultTypeID = 0; - if (failed(processType(loc, op->getResult(0).getType(), resultTypeID))) - return failure(); - operands.push_back(resultTypeID); - - resultID = getNextID(); - operands.push_back(resultID); - valueIDMap[op->getResult(0)] = resultID; - }; - - for (Value operand : op->getOperands()) - operands.push_back(getValueID(operand)); - - emitDebugLine(functionBody, loc); - - if (extInstSet.empty()) { - encodeInstructionInto(functionBody, static_cast(opcode), - operands); - } else { - encodeExtensionInstruction(op, extInstSet, opcode, operands); - } - - if (op->getNumResults() != 0) { - for (auto attr : op->getAttrs()) { - if (failed(processDecoration(loc, resultID, attr))) - return failure(); - } - } - - return success(); -} - -namespace { -template <> -LogicalResult -Serializer::processOp(spirv::EntryPointOp op) { - SmallVector operands; - // Add the ExecutionModel. - operands.push_back(static_cast(op.execution_model())); - // Add the function . - auto funcID = getFunctionID(op.fn()); - if (!funcID) { - return op.emitError("missing for function ") - << op.fn() - << "; function needs to be defined before spv.EntryPoint is " - "serialized"; - } - operands.push_back(funcID); - // Add the name of the function. - spirv::encodeStringLiteralInto(operands, op.fn()); - - // Add the interface values. - if (auto interface = op.interface()) { - for (auto var : interface.getValue()) { - auto id = getVariableID(var.cast().getValue()); - if (!id) { - return op.emitError("referencing undefined global variable." - "spv.EntryPoint is at the end of spv.module. All " - "referenced variables should already be defined"); - } - operands.push_back(id); - } - } - return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, - operands); -} - -template <> -LogicalResult -Serializer::processOp(spirv::ControlBarrierOp op) { - StringRef argNames[] = {"execution_scope", "memory_scope", - "memory_semantics"}; - SmallVector operands; - - for (auto argName : argNames) { - auto argIntAttr = op->getAttrOfType(argName); - auto operand = prepareConstantInt(op.getLoc(), argIntAttr); - if (!operand) { - return failure(); - } - operands.push_back(operand); - } - - return encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier, - operands); -} - -template <> -LogicalResult -Serializer::processOp(spirv::ExecutionModeOp op) { - SmallVector operands; - // Add the function . - auto funcID = getFunctionID(op.fn()); - if (!funcID) { - return op.emitError("missing for function ") - << op.fn() - << "; function needs to be serialized before ExecutionModeOp is " - "serialized"; - } - operands.push_back(funcID); - // Add the ExecutionMode. - operands.push_back(static_cast(op.execution_mode())); - - // Serialize values if any. - auto values = op.values(); - if (values) { - for (auto &intVal : values.getValue()) { - operands.push_back(static_cast( - intVal.cast().getValue().getZExtValue())); - } - } - return encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode, - operands); -} - -template <> -LogicalResult -Serializer::processOp(spirv::MemoryBarrierOp op) { - StringRef argNames[] = {"memory_scope", "memory_semantics"}; - SmallVector operands; - - for (auto argName : argNames) { - auto argIntAttr = op->getAttrOfType(argName); - auto operand = prepareConstantInt(op.getLoc(), argIntAttr); - if (!operand) { - return failure(); - } - operands.push_back(operand); - } - - return encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier, - operands); -} - -template <> -LogicalResult -Serializer::processOp(spirv::FunctionCallOp op) { - auto funcName = op.callee(); - uint32_t resTypeID = 0; - - Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType(); - if (failed(processType(op.getLoc(), resultTy, resTypeID))) - return failure(); - - auto funcID = getOrCreateFunctionID(funcName); - auto funcCallID = getNextID(); - SmallVector operands{resTypeID, funcCallID, funcID}; - - for (auto value : op.arguments()) { - auto valueID = getValueID(value); - assert(valueID && "cannot find a value for spv.FunctionCall"); - operands.push_back(valueID); - } - - if (!resultTy.isa()) - valueIDMap[op.getResult(0)] = funcCallID; - - return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, - operands); -} - -template <> -LogicalResult -Serializer::processOp(spirv::CopyMemoryOp op) { - SmallVector operands; - SmallVector elidedAttrs; - - for (Value operand : op->getOperands()) { - auto id = getValueID(operand); - assert(id && "use before def!"); - operands.push_back(id); - } - - if (auto attr = op->getAttr("memory_access")) { - operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); - } - - elidedAttrs.push_back("memory_access"); - - if (auto attr = op->getAttr("alignment")) { - operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); - } - - elidedAttrs.push_back("alignment"); - - if (auto attr = op->getAttr("source_memory_access")) { - operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); - } - - elidedAttrs.push_back("source_memory_access"); - - if (auto attr = op->getAttr("source_alignment")) { - operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); - } - - elidedAttrs.push_back("source_alignment"); - emitDebugLine(functionBody, op.getLoc()); - encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); - - return success(); -} - -// Pull in auto-generated Serializer::dispatchToAutogenSerialization() and -// various Serializer::processOp<...>() specializations. -#define GET_SERIALIZATION_FNS -#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" -} // namespace - -LogicalResult Serializer::emitDecoration(uint32_t target, - spirv::Decoration decoration, - ArrayRef params) { - uint32_t wordCount = 3 + params.size(); - decorations.push_back( - spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate)); - decorations.push_back(target); - decorations.push_back(static_cast(decoration)); - decorations.append(params.begin(), params.end()); - return success(); -} - -LogicalResult Serializer::emitDebugLine(SmallVectorImpl &binary, - Location loc) { - if (!emitDebugInfo) - return success(); - - if (lastProcessedWasMergeInst) { - lastProcessedWasMergeInst = false; - return success(); - } - - auto fileLoc = loc.dyn_cast(); - if (fileLoc) - encodeInstructionInto(binary, spirv::Opcode::OpLine, - {fileID, fileLoc.getLine(), fileLoc.getColumn()}); - return success(); -} - namespace mlir { LogicalResult spirv::serialize(spirv::ModuleOp module, SmallVectorImpl &binary, diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -0,0 +1,707 @@ +//===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the serialization methods for MLIR SPIR-V module ops. +// +//===----------------------------------------------------------------------===// + +#include "Serializer.h" + +#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" +#include "mlir/IR/RegionGraphTraits.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "spirv-serialization" + +using namespace mlir; + +/// A pre-order depth-first visitor function for processing basic blocks. +/// +/// Visits the basic blocks starting from the given `headerBlock` in pre-order +/// depth-first manner and calls `blockHandler` on each block. Skips handling +/// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler` +/// will not be invoked in `headerBlock` but still handles all `headerBlock`'s +/// successors. +/// +/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order +/// of blocks in a function must satisfy the rule that blocks appear before +/// all blocks they dominate." This can be achieved by a pre-order CFG +/// traversal algorithm. To make the serialization output more logical and +/// readable to human, we perform depth-first CFG traversal and delay the +/// serialization of the merge block and the continue block, if exists, until +/// after all other blocks have been processed. +static LogicalResult +visitInPrettyBlockOrder(Block *headerBlock, + function_ref blockHandler, + bool skipHeader = false, BlockRange skipBlocks = {}) { + llvm::df_iterator_default_set doneBlocks; + doneBlocks.insert(skipBlocks.begin(), skipBlocks.end()); + + for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) { + if (skipHeader && block == headerBlock) + continue; + if (failed(blockHandler(block))) + return failure(); + } + return success(); +} + +namespace mlir { +namespace spirv { +LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { + if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) { + valueIDMap[op.getResult()] = resultID; + return success(); + } + return failure(); +} + +LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { + if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(), + /*isSpec=*/true)) { + // Emit the OpDecorate instruction for SpecId. + if (auto specID = op->getAttrOfType("spec_id")) { + auto val = static_cast(specID.getInt()); + emitDecoration(resultID, spirv::Decoration::SpecId, {val}); + } + + specConstIDMap[op.sym_name()] = resultID; + return processName(resultID, op.sym_name()); + } + return failure(); +} + +LogicalResult +Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { + uint32_t typeID = 0; + if (failed(processType(op.getLoc(), op.type(), typeID))) { + return failure(); + } + + auto resultID = getNextID(); + + SmallVector operands; + operands.push_back(typeID); + operands.push_back(resultID); + + auto constituents = op.constituents(); + + for (auto index : llvm::seq(0, constituents.size())) { + auto constituent = constituents[index].dyn_cast(); + + auto constituentName = constituent.getValue(); + auto constituentID = getSpecConstID(constituentName); + + if (!constituentID) { + return op.emitError("unknown result for specialization constant ") + << constituentName; + } + + operands.push_back(constituentID); + } + + encodeInstructionInto(typesGlobalValues, + spirv::Opcode::OpSpecConstantComposite, operands); + specConstIDMap[op.sym_name()] = resultID; + + return processName(resultID, op.sym_name()); +} + +LogicalResult +Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { + uint32_t typeID = 0; + if (failed(processType(op.getLoc(), op.getType(), typeID))) { + return failure(); + } + + auto resultID = getNextID(); + + SmallVector operands; + operands.push_back(typeID); + operands.push_back(resultID); + + Block &block = op.getRegion().getBlocks().front(); + Operation &enclosedOp = block.getOperations().front(); + + std::string enclosedOpName; + llvm::raw_string_ostream rss(enclosedOpName); + rss << "Op" << enclosedOp.getName().stripDialect(); + auto enclosedOpcode = spirv::symbolizeOpcode(rss.str()); + + if (!enclosedOpcode) { + op.emitError("Couldn't find op code for op ") + << enclosedOp.getName().getStringRef(); + return failure(); + } + + operands.push_back(static_cast(enclosedOpcode.getValue())); + + // Append operands to the enclosed op to the list of operands. + for (Value operand : enclosedOp.getOperands()) { + uint32_t id = getValueID(operand); + assert(id && "use before def!"); + operands.push_back(id); + } + + encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp, + operands); + valueIDMap[op.getResult()] = resultID; + + return success(); +} + +LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { + auto undefType = op.getType(); + auto &id = undefValIDMap[undefType]; + if (!id) { + id = getNextID(); + uint32_t typeID = 0; + if (failed(processType(op.getLoc(), undefType, typeID)) || + failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef, + {typeID, id}))) { + return failure(); + } + } + valueIDMap[op.getResult()] = id; + return success(); +} + +LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { + LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n"); + assert(functionHeader.empty() && functionBody.empty()); + + uint32_t fnTypeID = 0; + // Generate type of the function. + processType(op.getLoc(), op.getType(), fnTypeID); + + // Add the function definition. + SmallVector operands; + uint32_t resTypeID = 0; + auto resultTypes = op.getType().getResults(); + if (resultTypes.size() > 1) { + return op.emitError("cannot serialize function with multiple return types"); + } + if (failed(processType(op.getLoc(), + (resultTypes.empty() ? getVoidType() : resultTypes[0]), + resTypeID))) { + return failure(); + } + operands.push_back(resTypeID); + auto funcID = getOrCreateFunctionID(op.getName()); + operands.push_back(funcID); + operands.push_back(static_cast(op.function_control())); + operands.push_back(fnTypeID); + encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands); + + // Add function name. + if (failed(processName(funcID, op.getName()))) { + return failure(); + } + + // Declare the parameters. + for (auto arg : op.getArguments()) { + uint32_t argTypeID = 0; + if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { + return failure(); + } + auto argValueID = getNextID(); + valueIDMap[arg] = argValueID; + encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, + {argTypeID, argValueID}); + } + + // Process the body. + if (op.isExternal()) { + return op.emitError("external function is unhandled"); + } + + // Some instructions (e.g., OpVariable) in a function must be in the first + // block in the function. These instructions will be put in functionHeader. + // Thus, we put the label in functionHeader first, and omit it from the first + // block. + encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, + {getOrCreateBlockID(&op.front())}); + processBlock(&op.front(), /*omitLabel=*/true); + if (failed(visitInPrettyBlockOrder( + &op.front(), [&](Block *block) { return processBlock(block); }, + /*skipHeader=*/true))) { + return failure(); + } + + // There might be OpPhi instructions who have value references needing to fix. + for (auto deferredValue : deferredPhiValues) { + Value value = deferredValue.first; + uint32_t id = getValueID(value); + LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value + << " to id = " << id << '\n'); + assert(id && "OpPhi references undefined value!"); + for (size_t offset : deferredValue.second) + functionBody[offset] = id; + } + deferredPhiValues.clear(); + + LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() + << "' --\n"); + // Insert OpFunctionEnd. + if (failed(encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, + {}))) { + return failure(); + } + + functions.append(functionHeader.begin(), functionHeader.end()); + functions.append(functionBody.begin(), functionBody.end()); + functionHeader.clear(); + functionBody.clear(); + + return success(); +} + +LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { + SmallVector operands; + SmallVector elidedAttrs; + uint32_t resultID = 0; + uint32_t resultTypeID = 0; + if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) { + return failure(); + } + operands.push_back(resultTypeID); + resultID = getNextID(); + valueIDMap[op.getResult()] = resultID; + operands.push_back(resultID); + auto attr = op->getAttr(spirv::attributeName()); + if (attr) { + operands.push_back(static_cast( + attr.cast().getValue().getZExtValue())); + } + elidedAttrs.push_back(spirv::attributeName()); + for (auto arg : op.getODSOperands(0)) { + auto argID = getValueID(arg); + if (!argID) { + return emitError(op.getLoc(), "operand 0 has a use before def"); + } + operands.push_back(argID); + } + emitDebugLine(functionHeader, op.getLoc()); + encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands); + for (auto attr : op->getAttrs()) { + if (llvm::any_of(elidedAttrs, + [&](StringRef elided) { return attr.first == elided; })) { + continue; + } + if (failed(processDecoration(op.getLoc(), resultID, attr))) { + return failure(); + } + } + return success(); +} + +LogicalResult +Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { + // Get TypeID. + uint32_t resultTypeID = 0; + SmallVector elidedAttrs; + if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) { + return failure(); + } + + if (isInterfaceStructPtrType(varOp.type())) { + auto structType = varOp.type() + .cast() + .getPointeeType() + .cast(); + if (failed( + emitDecoration(getTypeID(structType), spirv::Decoration::Block))) { + return varOp.emitError("cannot decorate ") + << structType << " with Block decoration"; + } + } + + elidedAttrs.push_back("type"); + SmallVector operands; + operands.push_back(resultTypeID); + auto resultID = getNextID(); + + // Encode the name. + auto varName = varOp.sym_name(); + elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); + if (failed(processName(resultID, varName))) { + return failure(); + } + globalVarIDMap[varName] = resultID; + operands.push_back(resultID); + + // Encode StorageClass. + operands.push_back(static_cast(varOp.storageClass())); + + // Encode initialization. + if (auto initializer = varOp.initializer()) { + auto initializerID = getVariableID(initializer.getValue()); + if (!initializerID) { + return emitError(varOp.getLoc(), + "invalid usage of undefined variable as initializer"); + } + operands.push_back(initializerID); + elidedAttrs.push_back("initializer"); + } + + emitDebugLine(typesGlobalValues, varOp.getLoc()); + if (failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, + operands))) { + elidedAttrs.push_back("initializer"); + return failure(); + } + + // Encode decorations. + for (auto attr : varOp->getAttrs()) { + if (llvm::any_of(elidedAttrs, + [&](StringRef elided) { return attr.first == elided; })) { + continue; + } + if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { + return failure(); + } + } + return success(); +} + +LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { + // Assign s to all blocks so that branches inside the SelectionOp can + // resolve properly. + auto &body = selectionOp.body(); + for (Block &block : body) + getOrCreateBlockID(&block); + + auto *headerBlock = selectionOp.getHeaderBlock(); + auto *mergeBlock = selectionOp.getMergeBlock(); + auto mergeID = getBlockID(mergeBlock); + auto loc = selectionOp.getLoc(); + + // Emit the selection header block, which dominates all other blocks, first. + // We need to emit an OpSelectionMerge instruction before the selection header + // block's terminator. + auto emitSelectionMerge = [&]() { + emitDebugLine(functionBody, loc); + lastProcessedWasMergeInst = true; + encodeInstructionInto( + functionBody, spirv::Opcode::OpSelectionMerge, + {mergeID, static_cast(selectionOp.selection_control())}); + }; + // For structured selection, we cannot have blocks in the selection construct + // branching to the selection header block. Entering the selection (and + // reaching the selection header) must be from the block containing the + // spv.selection op. If there are ops ahead of the spv.selection op in the + // block, we can "merge" them into the selection header. So here we don't need + // to emit a separate block; just continue with the existing block. + if (failed(processBlock(headerBlock, /*omitLabel=*/true, emitSelectionMerge))) + return failure(); + + // Process all blocks with a depth-first visitor starting from the header + // block. The selection header block and merge block are skipped by this + // visitor. + if (failed(visitInPrettyBlockOrder( + headerBlock, [&](Block *block) { return processBlock(block); }, + /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock}))) + return failure(); + + // There is nothing to do for the merge block in the selection, which just + // contains a spv.mlir.merge op, itself. But we need to have an OpLabel + // instruction to start a new SPIR-V block for ops following this SelectionOp. + // The block should use the for the merge block. + return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); +} + +LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { + // Assign s to all blocks so that branches inside the LoopOp can resolve + // properly. We don't need to assign for the entry block, which is just for + // satisfying MLIR region's structural requirement. + auto &body = loopOp.body(); + for (Block &block : + llvm::make_range(std::next(body.begin(), 1), body.end())) { + getOrCreateBlockID(&block); + } + auto *headerBlock = loopOp.getHeaderBlock(); + auto *continueBlock = loopOp.getContinueBlock(); + auto *mergeBlock = loopOp.getMergeBlock(); + auto headerID = getBlockID(headerBlock); + auto continueID = getBlockID(continueBlock); + auto mergeID = getBlockID(mergeBlock); + auto loc = loopOp.getLoc(); + + // This LoopOp is in some MLIR block with preceding and following ops. In the + // binary format, it should reside in separate SPIR-V blocks from its + // preceding and following ops. So we need to emit unconditional branches to + // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow + // afterwards. + encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); + + // LoopOp's entry block is just there for satisfying MLIR's structural + // requirements so we omit it and start serialization from the loop header + // block. + + // Emit the loop header block, which dominates all other blocks, first. We + // need to emit an OpLoopMerge instruction before the loop header block's + // terminator. + auto emitLoopMerge = [&]() { + emitDebugLine(functionBody, loc); + lastProcessedWasMergeInst = true; + encodeInstructionInto( + functionBody, spirv::Opcode::OpLoopMerge, + {mergeID, continueID, static_cast(loopOp.loop_control())}); + }; + if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge))) + return failure(); + + // Process all blocks with a depth-first visitor starting from the header + // block. The loop header block, loop continue block, and loop merge block are + // skipped by this visitor and handled later in this function. + if (failed(visitInPrettyBlockOrder( + headerBlock, [&](Block *block) { return processBlock(block); }, + /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock}))) + return failure(); + + // We have handled all other blocks. Now get to the loop continue block. + if (failed(processBlock(continueBlock))) + return failure(); + + // There is nothing to do for the merge block in the loop, which just contains + // a spv.mlir.merge op, itself. But we need to have an OpLabel instruction to + // start a new SPIR-V block for ops following this LoopOp. The block should + // use the for the merge block. + return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); +} + +LogicalResult Serializer::processBranchConditionalOp( + spirv::BranchConditionalOp condBranchOp) { + auto conditionID = getValueID(condBranchOp.condition()); + auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock()); + auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock()); + SmallVector arguments{conditionID, trueLabelID, falseLabelID}; + + if (auto weights = condBranchOp.branch_weights()) { + for (auto val : weights->getValue()) + arguments.push_back(val.cast().getInt()); + } + + emitDebugLine(functionBody, condBranchOp.getLoc()); + return encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional, + arguments); +} + +LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { + emitDebugLine(functionBody, branchOp.getLoc()); + return encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, + {getOrCreateBlockID(branchOp.getTarget())}); +} + +LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { + auto varName = addressOfOp.variable(); + auto variableID = getVariableID(varName); + if (!variableID) { + return addressOfOp.emitError("unknown result for variable ") + << varName; + } + valueIDMap[addressOfOp.pointer()] = variableID; + return success(); +} + +LogicalResult +Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { + auto constName = referenceOfOp.spec_const(); + auto constID = getSpecConstID(constName); + if (!constID) { + return referenceOfOp.emitError( + "unknown result for specialization constant ") + << constName; + } + valueIDMap[referenceOfOp.reference()] = constID; + return success(); +} + +template <> +LogicalResult +Serializer::processOp(spirv::EntryPointOp op) { + SmallVector operands; + // Add the ExecutionModel. + operands.push_back(static_cast(op.execution_model())); + // Add the function . + auto funcID = getFunctionID(op.fn()); + if (!funcID) { + return op.emitError("missing for function ") + << op.fn() + << "; function needs to be defined before spv.EntryPoint is " + "serialized"; + } + operands.push_back(funcID); + // Add the name of the function. + spirv::encodeStringLiteralInto(operands, op.fn()); + + // Add the interface values. + if (auto interface = op.interface()) { + for (auto var : interface.getValue()) { + auto id = getVariableID(var.cast().getValue()); + if (!id) { + return op.emitError("referencing undefined global variable." + "spv.EntryPoint is at the end of spv.module. All " + "referenced variables should already be defined"); + } + operands.push_back(id); + } + } + return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, + operands); +} + +template <> +LogicalResult +Serializer::processOp(spirv::ControlBarrierOp op) { + StringRef argNames[] = {"execution_scope", "memory_scope", + "memory_semantics"}; + SmallVector operands; + + for (auto argName : argNames) { + auto argIntAttr = op->getAttrOfType(argName); + auto operand = prepareConstantInt(op.getLoc(), argIntAttr); + if (!operand) { + return failure(); + } + operands.push_back(operand); + } + + return encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier, + operands); +} + +template <> +LogicalResult +Serializer::processOp(spirv::ExecutionModeOp op) { + SmallVector operands; + // Add the function . + auto funcID = getFunctionID(op.fn()); + if (!funcID) { + return op.emitError("missing for function ") + << op.fn() + << "; function needs to be serialized before ExecutionModeOp is " + "serialized"; + } + operands.push_back(funcID); + // Add the ExecutionMode. + operands.push_back(static_cast(op.execution_mode())); + + // Serialize values if any. + auto values = op.values(); + if (values) { + for (auto &intVal : values.getValue()) { + operands.push_back(static_cast( + intVal.cast().getValue().getZExtValue())); + } + } + return encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode, + operands); +} + +template <> +LogicalResult +Serializer::processOp(spirv::MemoryBarrierOp op) { + StringRef argNames[] = {"memory_scope", "memory_semantics"}; + SmallVector operands; + + for (auto argName : argNames) { + auto argIntAttr = op->getAttrOfType(argName); + auto operand = prepareConstantInt(op.getLoc(), argIntAttr); + if (!operand) { + return failure(); + } + operands.push_back(operand); + } + + return encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier, + operands); +} + +template <> +LogicalResult +Serializer::processOp(spirv::FunctionCallOp op) { + auto funcName = op.callee(); + uint32_t resTypeID = 0; + + Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType(); + if (failed(processType(op.getLoc(), resultTy, resTypeID))) + return failure(); + + auto funcID = getOrCreateFunctionID(funcName); + auto funcCallID = getNextID(); + SmallVector operands{resTypeID, funcCallID, funcID}; + + for (auto value : op.arguments()) { + auto valueID = getValueID(value); + assert(valueID && "cannot find a value for spv.FunctionCall"); + operands.push_back(valueID); + } + + if (!resultTy.isa()) + valueIDMap[op.getResult(0)] = funcCallID; + + return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, + operands); +} + +template <> +LogicalResult +Serializer::processOp(spirv::CopyMemoryOp op) { + SmallVector operands; + SmallVector elidedAttrs; + + for (Value operand : op->getOperands()) { + auto id = getValueID(operand); + assert(id && "use before def!"); + operands.push_back(id); + } + + if (auto attr = op->getAttr("memory_access")) { + operands.push_back(static_cast( + attr.cast().getValue().getZExtValue())); + } + + elidedAttrs.push_back("memory_access"); + + if (auto attr = op->getAttr("alignment")) { + operands.push_back(static_cast( + attr.cast().getValue().getZExtValue())); + } + + elidedAttrs.push_back("alignment"); + + if (auto attr = op->getAttr("source_memory_access")) { + operands.push_back(static_cast( + attr.cast().getValue().getZExtValue())); + } + + elidedAttrs.push_back("source_memory_access"); + + if (auto attr = op->getAttr("source_alignment")) { + operands.push_back(static_cast( + attr.cast().getValue().getZExtValue())); + } + + elidedAttrs.push_back("source_alignment"); + emitDebugLine(functionBody, op.getLoc()); + encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); + + return success(); +} + +// Pull in auto-generated Serializer::dispatchToAutogenSerialization() and +// various Serializer::processOp<...>() specializations. +#define GET_SERIALIZATION_FNS +#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" + +} // namespace spirv +} // namespace mlir diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h @@ -0,0 +1,448 @@ +//===- Serializer.h - MLIR SPIR-V Serializer ------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the MLIR SPIR-V module to SPIR-V binary serializer. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_SPIRV_SERIALIZER_H +#define MLIR_TARGET_SPIRV_SERIALIZER_H + +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/IR/Builders.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace spirv { + +LogicalResult encodeInstructionInto(SmallVectorImpl &binary, + spirv::Opcode op, + ArrayRef operands); + +/// A SPIR-V module serializer. +/// +/// A SPIR-V binary module is a single linear stream of instructions; each +/// instruction is composed of 32-bit words with the layout: +/// +/// | | | | | ... | +/// | <------ word -------> | <-- word --> | <-- word --> | ... | +/// +/// For the first word, the 16 high-order bits are the word count of the +/// instruction, the 16 low-order bits are the opcode enumerant. The +/// instructions then belong to different sections, which must be laid out in +/// the particular order as specified in "2.4 Logical Layout of a Module" of +/// the SPIR-V spec. +class Serializer { +public: + /// Creates a serializer for the given SPIR-V `module`. + explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false); + + /// Serializes the remembered SPIR-V module. + LogicalResult serialize(); + + /// Collects the final SPIR-V `binary`. + void collect(SmallVectorImpl &binary); + +#ifndef NDEBUG + /// (For debugging) prints each value and its corresponding result . + void printValueIDMap(raw_ostream &os); +#endif + +private: + // Note that there are two main categories of methods in this class: + // * process*() methods are meant to fully serialize a SPIR-V module entity + // (header, type, op, etc.). They update internal vectors containing + // different binary sections. They are not meant to be called except the + // top-level serialization loop. + // * prepare*() methods are meant to be helpers that prepare for serializing + // certain entity. They may or may not update internal vectors containing + // different binary sections. They are meant to be called among themselves + // or by other process*() methods for subtasks. + + //===--------------------------------------------------------------------===// + // + //===--------------------------------------------------------------------===// + + // Note that it is illegal to use id <0> in SPIR-V binary module. Various + // methods in this class, if using SPIR-V word (uint32_t) as interface, + // check or return id <0> to indicate error in processing. + + /// Consumes the next unused . This method will never return 0. + uint32_t getNextID() { return nextID++; } + + //===--------------------------------------------------------------------===// + // Module structure + //===--------------------------------------------------------------------===// + + uint32_t getSpecConstID(StringRef constName) const { + return specConstIDMap.lookup(constName); + } + + uint32_t getVariableID(StringRef varName) const { + return globalVarIDMap.lookup(varName); + } + + uint32_t getFunctionID(StringRef fnName) const { + return funcIDMap.lookup(fnName); + } + + /// Gets the for the function with the given name. Assigns the next + /// available if the function haven't been deserialized. + uint32_t getOrCreateFunctionID(StringRef fnName); + + void processCapability(); + + void processDebugInfo(); + + void processExtension(); + + void processMemoryModel(); + + LogicalResult processConstantOp(spirv::ConstantOp op); + + LogicalResult processSpecConstantOp(spirv::SpecConstantOp op); + + LogicalResult + processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op); + + LogicalResult + processSpecConstantOperationOp(spirv::SpecConstantOperationOp op); + + /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA + /// value to use with other operations. The SPIR-V spec recommends that + /// OpUndef be generated at module level. The serialization generates an + /// OpUndef for each type needed at module level. + LogicalResult processUndefOp(spirv::UndefOp op); + + /// Emit OpName for the given `resultID`. + LogicalResult processName(uint32_t resultID, StringRef name); + + /// Processes a SPIR-V function op. + LogicalResult processFuncOp(spirv::FuncOp op); + + LogicalResult processVariableOp(spirv::VariableOp op); + + /// Process a SPIR-V GlobalVariableOp + LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp); + + /// Process attributes that translate to decorations on the result + LogicalResult processDecoration(Location loc, uint32_t resultID, + NamedAttribute attr); + + template + LogicalResult processTypeDecoration(Location loc, DType type, + uint32_t resultId) { + return emitError(loc, "unhandled decoration for type:") << type; + } + + /// Process member decoration + LogicalResult processMemberDecoration( + uint32_t structID, + const spirv::StructType::MemberDecorationInfo &memberDecorationInfo); + + //===--------------------------------------------------------------------===// + // Types + //===--------------------------------------------------------------------===// + + uint32_t getTypeID(Type type) const { return typeIDMap.lookup(type); } + + Type getVoidType() { return mlirBuilder.getNoneType(); } + + bool isVoidType(Type type) const { return type.isa(); } + + /// Returns true if the given type is a pointer type to a struct in some + /// interface storage class. + bool isInterfaceStructPtrType(Type type) const; + + /// Main dispatch method for serializing a type. The result of the + /// serialized type will be returned as `typeID`. + LogicalResult processType(Location loc, Type type, uint32_t &typeID); + LogicalResult processTypeImpl(Location loc, Type type, uint32_t &typeID, + llvm::SetVector &serializationCtx); + + /// Method for preparing basic SPIR-V type serialization. Returns the type's + /// opcode and operands for the instruction via `typeEnum` and `operands`. + LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID, + spirv::Opcode &typeEnum, + SmallVectorImpl &operands, + bool &deferSerialization, + llvm::SetVector &serializationCtx); + + LogicalResult prepareFunctionType(Location loc, FunctionType type, + spirv::Opcode &typeEnum, + SmallVectorImpl &operands); + + //===--------------------------------------------------------------------===// + // Constant + //===--------------------------------------------------------------------===// + + uint32_t getConstantID(Attribute value) const { + return constIDMap.lookup(value); + } + + /// Main dispatch method for processing a constant with the given `constType` + /// and `valueAttr`. `constType` is needed here because we can interpret the + /// `valueAttr` as a different type than the type of `valueAttr` itself; for + /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType + /// constants. + uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr); + + /// Prepares array attribute serialization. This method emits corresponding + /// OpConstant* and returns the result associated with it. Returns 0 if + /// failed. + uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr); + + /// Prepares bool/int/float DenseElementsAttr serialization. This method + /// iterates the DenseElementsAttr to construct the constant array, and + /// returns the result associated with it. Returns 0 if failed. Note + /// that the size of `index` must match the rank. + /// TODO: Consider to enhance splat elements cases. For splat cases, + /// we don't need to loop over all elements, especially when the splat value + /// is zero. We can use OpConstantNull when the value is zero. + uint32_t prepareDenseElementsConstant(Location loc, Type constType, + DenseElementsAttr valueAttr, int dim, + MutableArrayRef index); + + /// Prepares scalar attribute serialization. This method emits corresponding + /// OpConstant* and returns the result associated with it. Returns 0 if + /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is + /// true, then the constant will be serialized as a specialization constant. + uint32_t prepareConstantScalar(Location loc, Attribute valueAttr, + bool isSpec = false); + + uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, + bool isSpec = false); + + uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, + bool isSpec = false); + + uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, + bool isSpec = false); + + //===--------------------------------------------------------------------===// + // Control flow + //===--------------------------------------------------------------------===// + + /// Returns the result for the given block. + uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); } + + /// Returns the result for the given block. If no has been assigned, + /// assigns the next available + uint32_t getOrCreateBlockID(Block *block); + + /// Processes the given `block` and emits SPIR-V instructions for all ops + /// inside. Does not emit OpLabel for this block if `omitLabel` is true. + /// `actionBeforeTerminator` is a callback that will be invoked before + /// handling the terminator op. It can be used to inject the Op*Merge + /// instruction if this is a SPIR-V selection/loop header block. + LogicalResult + processBlock(Block *block, bool omitLabel = false, + function_ref actionBeforeTerminator = nullptr); + + /// Emits OpPhi instructions for the given block if it has block arguments. + LogicalResult emitPhiForBlockArguments(Block *block); + + LogicalResult processSelectionOp(spirv::SelectionOp selectionOp); + + LogicalResult processLoopOp(spirv::LoopOp loopOp); + + LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp); + + LogicalResult processBranchOp(spirv::BranchOp branchOp); + + //===--------------------------------------------------------------------===// + // Operations + //===--------------------------------------------------------------------===// + + LogicalResult encodeExtensionInstruction(Operation *op, + StringRef extensionSetName, + uint32_t opcode, + ArrayRef operands); + + uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); } + + LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp); + + LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp); + + /// Main dispatch method for serializing an operation. + LogicalResult processOperation(Operation *op); + + /// Serializes an operation `op` as core instruction with `opcode` if + /// `extInstSet` is empty. Otherwise serializes it as an extended instruction + /// with `opcode` from `extInstSet`. + /// This method is a generic one for dispatching any SPIR-V ops that has no + /// variadic operands and attributes in TableGen definitions. + LogicalResult processOpWithoutGrammarAttr(Operation *op, StringRef extInstSet, + uint32_t opcode); + + /// Dispatches to the serialization function for an operation in SPIR-V + /// dialect that is a mirror of an instruction in the SPIR-V spec. This is + /// auto-generated from ODS. Dispatch is handled for all operations in SPIR-V + /// dialect that have hasOpcode == 1. + LogicalResult dispatchToAutogenSerialization(Operation *op); + + /// Serializes an operation in the SPIR-V dialect that is a mirror of an + /// instruction in the SPIR-V spec. This is auto generated if hasOpcode == 1 + /// and autogenSerialization == 1 in ODS. + template + LogicalResult processOp(OpTy op) { + return op.emitError("unsupported op serialization"); + } + + //===--------------------------------------------------------------------===// + // Utilities + //===--------------------------------------------------------------------===// + + /// Emits an OpDecorate instruction to decorate the given `target` with the + /// given `decoration`. + LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration, + ArrayRef params = {}); + + /// Emits an OpLine instruction with the given `loc` location information into + /// the given `binary` vector. + LogicalResult emitDebugLine(SmallVectorImpl &binary, Location loc); + +private: + /// The SPIR-V module to be serialized. + spirv::ModuleOp module; + + /// An MLIR builder for getting MLIR constructs. + mlir::Builder mlirBuilder; + + /// A flag which indicates if the debuginfo should be emitted. + bool emitDebugInfo = false; + + /// A flag which indicates if the last processed instruction was a merge + /// instruction. + /// According to SPIR-V spec: "If a branch merge instruction is used, the last + /// OpLine in the block must be before its merge instruction". + bool lastProcessedWasMergeInst = false; + + /// The of the OpString instruction, which specifies a file name, for + /// use by other debug instructions. + uint32_t fileID = 0; + + /// The next available result . + uint32_t nextID = 1; + + // The following are for different SPIR-V instruction sections. They follow + // the logical layout of a SPIR-V module. + + SmallVector capabilities; + SmallVector extensions; + SmallVector extendedSets; + SmallVector memoryModel; + SmallVector entryPoints; + SmallVector executionModes; + SmallVector debug; + SmallVector names; + SmallVector decorations; + SmallVector typesGlobalValues; + SmallVector functions; + + /// Recursive struct references are serialized as OpTypePointer instructions + /// to the recursive struct type. However, the OpTypePointer instruction + /// cannot be emitted before the recursive struct's OpTypeStruct. + /// RecursiveStructPointerInfo stores the data needed to emit such + /// OpTypePointer instructions after forward references to such types. + struct RecursiveStructPointerInfo { + uint32_t pointerTypeID; + spirv::StorageClass storageClass; + }; + + // Maps spirv::StructType to its recursive reference member info. + DenseMap> + recursiveStructInfos; + + /// `functionHeader` contains all the instructions that must be in the first + /// block in the function, and `functionBody` contains the rest. After + /// processing FuncOp, the encoded instructions of a function are appended to + /// `functions`. An example of instructions in `functionHeader` in order: + /// OpFunction ... + /// OpFunctionParameter ... + /// OpFunctionParameter ... + /// OpLabel ... + /// OpVariable ... + /// OpVariable ... + SmallVector functionHeader; + SmallVector functionBody; + + /// Map from type used in SPIR-V module to their s. + DenseMap typeIDMap; + + /// Map from constant values to their s. + DenseMap constIDMap; + + /// Map from specialization constant names to their s. + llvm::StringMap specConstIDMap; + + /// Map from GlobalVariableOps name to s. + llvm::StringMap globalVarIDMap; + + /// Map from FuncOps name to s. + llvm::StringMap funcIDMap; + + /// Map from blocks to their s. + DenseMap blockIDMap; + + /// Map from the Type to the that represents undef value of that type. + DenseMap undefValIDMap; + + /// Map from results of normal operations to their s. + DenseMap valueIDMap; + + /// Map from extended instruction set name to s. + llvm::StringMap extendedInstSetIDMap; + + /// Map from values used in OpPhi instructions to their offset in the + /// `functions` section. + /// + /// When processing a block with arguments, we need to emit OpPhi + /// instructions to record the predecessor block s and the values they + /// send to the block in question. But it's not guaranteed all values are + /// visited and thus assigned result s. So we need this list to capture + /// the offsets into `functions` where a value is used so that we can fix it + /// up later after processing all the blocks in a function. + /// + /// More concretely, say if we are visiting the following blocks: + /// + /// ```mlir + /// ^phi(%arg0: i32): + /// ... + /// ^parent1: + /// ... + /// spv.Branch ^phi(%val0: i32) + /// ^parent2: + /// ... + /// spv.Branch ^phi(%val1: i32) + /// ``` + /// + /// When we are serializing the `^phi` block, we need to emit at the beginning + /// of the block OpPhi instructions which has the following parameters: + /// + /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1 + /// id-for-%val1 id-for-^parent2 + /// + /// But we don't know the for %val0 and %val1 yet. One way is to visit + /// all the blocks twice and use the first visit to assign an to each + /// value. But it's paying the overheads just for OpPhi emission. Instead, + /// we still visit the blocks once for emission. When we emit the OpPhi + /// instructions, we use 0 as a placeholder for the s for %val0 and %val1. + /// At the same time, we record their offsets in the emitted binary (which is + /// placed inside `functions`) here. And then after emitting all blocks, we + /// replace the dummy 0 with the real result by overwriting + /// `functions[offset]`. + DenseMap> deferredPhiValues; +}; +} // namespace spirv +} // namespace mlir + +#endif // MLIR_TARGET_SPIRV_SERIALIZER_H diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -0,0 +1,1149 @@ +//===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the MLIR SPIR-V module to SPIR-V binary serializer. +// +//===----------------------------------------------------------------------===// + +#include "Serializer.h" + +#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/bit.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "spirv-serialization" + +using namespace mlir; + +/// Returns the merge block if the given `op` is a structured control flow op. +/// Otherwise returns nullptr. +static Block *getStructuredControlFlowOpMergeBlock(Operation *op) { + if (auto selectionOp = dyn_cast(op)) + return selectionOp.getMergeBlock(); + if (auto loopOp = dyn_cast(op)) + return loopOp.getMergeBlock(); + return nullptr; +} + +/// Given a predecessor `block` for a block with arguments, returns the block +/// that should be used as the parent block for SPIR-V OpPhi instructions +/// corresponding to the block arguments. +static Block *getPhiIncomingBlock(Block *block) { + // If the predecessor block in question is the entry block for a spv.loop, + // we jump to this spv.loop from its enclosing block. + if (block->isEntryBlock()) { + if (auto loopOp = dyn_cast(block->getParentOp())) { + // Then the incoming parent block for OpPhi should be the merge block of + // the structured control flow op before this loop. + Operation *op = loopOp.getOperation(); + while ((op = op->getPrevNode()) != nullptr) + if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op)) + return incomingBlock; + // Or the enclosing block itself if no structured control flow ops + // exists before this loop. + return loopOp->getBlock(); + } + } + + // Otherwise, we jump from the given predecessor block. Try to see if there is + // a structured control flow op inside it. + for (Operation &op : llvm::reverse(block->getOperations())) { + if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op)) + return incomingBlock; + } + return block; +} + +namespace mlir { +namespace spirv { + +/// Encodes an SPIR-V instruction with the given `opcode` and `operands` into +/// the given `binary` vector. +LogicalResult encodeInstructionInto(SmallVectorImpl &binary, + spirv::Opcode op, + ArrayRef operands) { + uint32_t wordCount = 1 + operands.size(); + binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); + binary.append(operands.begin(), operands.end()); + return success(); +} + +Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo) + : module(module), mlirBuilder(module.getContext()), + emitDebugInfo(emitDebugInfo) {} + +LogicalResult Serializer::serialize() { + LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); + + if (failed(module.verify())) + return failure(); + + // TODO: handle the other sections + processCapability(); + processExtension(); + processMemoryModel(); + processDebugInfo(); + + // Iterate over the module body to serialize it. Assumptions are that there is + // only one basic block in the moduleOp + for (auto &op : module.getBlock()) { + if (failed(processOperation(&op))) { + return failure(); + } + } + + LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); + return success(); +} + +void Serializer::collect(SmallVectorImpl &binary) { + auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + + extensions.size() + extendedSets.size() + + memoryModel.size() + entryPoints.size() + + executionModes.size() + decorations.size() + + typesGlobalValues.size() + functions.size(); + + binary.clear(); + binary.reserve(moduleSize); + + spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID); + binary.append(capabilities.begin(), capabilities.end()); + binary.append(extensions.begin(), extensions.end()); + binary.append(extendedSets.begin(), extendedSets.end()); + binary.append(memoryModel.begin(), memoryModel.end()); + binary.append(entryPoints.begin(), entryPoints.end()); + binary.append(executionModes.begin(), executionModes.end()); + binary.append(debug.begin(), debug.end()); + binary.append(names.begin(), names.end()); + binary.append(decorations.begin(), decorations.end()); + binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); + binary.append(functions.begin(), functions.end()); +} + +#ifndef NDEBUG +void Serializer::printValueIDMap(raw_ostream &os) { + os << "\n= Value Map =\n\n"; + for (auto valueIDPair : valueIDMap) { + Value val = valueIDPair.first; + os << " " << val << " " + << "id = " << valueIDPair.second << ' '; + if (auto *op = val.getDefiningOp()) { + os << "from op '" << op->getName() << "'"; + } else if (auto arg = val.dyn_cast()) { + Block *block = arg.getOwner(); + os << "from argument of block " << block << ' '; + os << " in op '" << block->getParentOp()->getName() << "'"; + } + os << '\n'; + } +} +#endif + +//===----------------------------------------------------------------------===// +// Module structure +//===----------------------------------------------------------------------===// + +uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { + auto funcID = funcIDMap.lookup(fnName); + if (!funcID) { + funcID = getNextID(); + funcIDMap[fnName] = funcID; + } + return funcID; +} + +void Serializer::processCapability() { + for (auto cap : module.vce_triple()->getCapabilities()) + encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, + {static_cast(cap)}); +} + +void Serializer::processDebugInfo() { + if (!emitDebugInfo) + return; + auto fileLoc = module.getLoc().dyn_cast(); + auto fileName = fileLoc ? fileLoc.getFilename() : ""; + fileID = getNextID(); + SmallVector operands; + operands.push_back(fileID); + spirv::encodeStringLiteralInto(operands, fileName); + encodeInstructionInto(debug, spirv::Opcode::OpString, operands); + // TODO: Encode more debug instructions. +} + +void Serializer::processExtension() { + llvm::SmallVector extName; + for (spirv::Extension ext : module.vce_triple()->getExtensions()) { + extName.clear(); + spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext)); + encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName); + } +} + +void Serializer::processMemoryModel() { + uint32_t mm = module->getAttrOfType("memory_model").getInt(); + uint32_t am = module->getAttrOfType("addressing_model").getInt(); + + encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm}); +} + +LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, + NamedAttribute attr) { + auto attrName = attr.first.strref(); + auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); + auto decoration = spirv::symbolizeDecoration(decorationName); + if (!decoration) { + return emitError( + loc, "non-argument attributes expected to have snake-case-ified " + "decoration name, unhandled attribute with name : ") + << attrName; + } + SmallVector args; + switch (decoration.getValue()) { + case spirv::Decoration::Binding: + case spirv::Decoration::DescriptorSet: + case spirv::Decoration::Location: + if (auto intAttr = attr.second.dyn_cast()) { + args.push_back(intAttr.getValue().getZExtValue()); + break; + } + return emitError(loc, "expected integer attribute for ") << attrName; + case spirv::Decoration::BuiltIn: + if (auto strAttr = attr.second.dyn_cast()) { + auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); + if (enumVal) { + args.push_back(static_cast(enumVal.getValue())); + break; + } + return emitError(loc, "invalid ") + << attrName << " attribute " << strAttr.getValue(); + } + return emitError(loc, "expected string attribute for ") << attrName; + case spirv::Decoration::Aliased: + case spirv::Decoration::Flat: + case spirv::Decoration::NonReadable: + case spirv::Decoration::NonWritable: + case spirv::Decoration::NoPerspective: + case spirv::Decoration::Restrict: + // For unit attributes, the args list has no values so we do nothing + if (auto unitAttr = attr.second.dyn_cast()) + break; + return emitError(loc, "expected unit attribute for ") << attrName; + default: + return emitError(loc, "unhandled decoration ") << decorationName; + } + return emitDecoration(resultID, decoration.getValue(), args); +} + +LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { + assert(!name.empty() && "unexpected empty string for OpName"); + + SmallVector nameOperands; + nameOperands.push_back(resultID); + if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) { + return failure(); + } + return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); +} + +template <> +LogicalResult Serializer::processTypeDecoration( + Location loc, spirv::ArrayType type, uint32_t resultID) { + if (unsigned stride = type.getArrayStride()) { + // OpDecorate %arrayTypeSSA ArrayStride strideLiteral + return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); + } + return success(); +} + +template <> +LogicalResult Serializer::processTypeDecoration( + Location loc, spirv::RuntimeArrayType type, uint32_t resultID) { + if (unsigned stride = type.getArrayStride()) { + // OpDecorate %arrayTypeSSA ArrayStride strideLiteral + return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); + } + return success(); +} + +LogicalResult Serializer::processMemberDecoration( + uint32_t structID, + const spirv::StructType::MemberDecorationInfo &memberDecoration) { + SmallVector args( + {structID, memberDecoration.memberIndex, + static_cast(memberDecoration.decoration)}); + if (memberDecoration.hasValue) { + args.push_back(memberDecoration.decorationValue); + } + return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, + args); +} + +//===----------------------------------------------------------------------===// +// Type +//===----------------------------------------------------------------------===// + +// According to the SPIR-V spec "Validation Rules for Shader Capabilities": +// "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and +// PushConstant Storage Classes must be explicitly laid out." +bool Serializer::isInterfaceStructPtrType(Type type) const { + if (auto ptrType = type.dyn_cast()) { + switch (ptrType.getStorageClass()) { + case spirv::StorageClass::PhysicalStorageBuffer: + case spirv::StorageClass::PushConstant: + case spirv::StorageClass::StorageBuffer: + case spirv::StorageClass::Uniform: + return ptrType.getPointeeType().isa(); + default: + break; + } + } + return false; +} + +LogicalResult Serializer::processType(Location loc, Type type, + uint32_t &typeID) { + // Maintains a set of names for nested identified struct types. This is used + // to properly serialize recursive references. + llvm::SetVector serializationCtx; + return processTypeImpl(loc, type, typeID, serializationCtx); +} + +LogicalResult +Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, + llvm::SetVector &serializationCtx) { + typeID = getTypeID(type); + if (typeID) { + return success(); + } + typeID = getNextID(); + SmallVector operands; + + operands.push_back(typeID); + auto typeEnum = spirv::Opcode::OpTypeVoid; + bool deferSerialization = false; + + if ((type.isa() && + succeeded(prepareFunctionType(loc, type.cast(), typeEnum, + operands))) || + succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, + deferSerialization, serializationCtx))) { + if (deferSerialization) + return success(); + + typeIDMap[type] = typeID; + + if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands))) + return failure(); + + if (recursiveStructInfos.count(type) != 0) { + // This recursive struct type is emitted already, now the OpTypePointer + // instructions referring to recursive references are emitted as well. + for (auto &ptrInfo : recursiveStructInfos[type]) { + // TODO: This might not work if more than 1 recursive reference is + // present in the struct. + SmallVector ptrOperands; + ptrOperands.push_back(ptrInfo.pointerTypeID); + ptrOperands.push_back(static_cast(ptrInfo.storageClass)); + ptrOperands.push_back(typeIDMap[type]); + + if (failed(encodeInstructionInto( + typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands))) + return failure(); + } + + recursiveStructInfos[type].clear(); + } + + return success(); + } + + return failure(); +} + +LogicalResult Serializer::prepareBasicType( + Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, + SmallVectorImpl &operands, bool &deferSerialization, + llvm::SetVector &serializationCtx) { + deferSerialization = false; + + if (isVoidType(type)) { + typeEnum = spirv::Opcode::OpTypeVoid; + return success(); + } + + if (auto intType = type.dyn_cast()) { + if (intType.getWidth() == 1) { + typeEnum = spirv::Opcode::OpTypeBool; + return success(); + } + + typeEnum = spirv::Opcode::OpTypeInt; + operands.push_back(intType.getWidth()); + // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics + // to preserve or validate. + // 0 indicates unsigned, or no signedness semantics + // 1 indicates signed semantics." + operands.push_back(intType.isSigned() ? 1 : 0); + return success(); + } + + if (auto floatType = type.dyn_cast()) { + typeEnum = spirv::Opcode::OpTypeFloat; + operands.push_back(floatType.getWidth()); + return success(); + } + + if (auto vectorType = type.dyn_cast()) { + uint32_t elementTypeID = 0; + if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID, + serializationCtx))) { + return failure(); + } + typeEnum = spirv::Opcode::OpTypeVector; + operands.push_back(elementTypeID); + operands.push_back(vectorType.getNumElements()); + return success(); + } + + if (auto imageType = type.dyn_cast()) { + typeEnum = spirv::Opcode::OpTypeImage; + uint32_t sampledTypeID = 0; + if (failed(processType(loc, imageType.getElementType(), sampledTypeID))) + return failure(); + + operands.push_back(sampledTypeID); + operands.push_back(static_cast(imageType.getDim())); + operands.push_back(static_cast(imageType.getDepthInfo())); + operands.push_back(static_cast(imageType.getArrayedInfo())); + operands.push_back(static_cast(imageType.getSamplingInfo())); + operands.push_back(static_cast(imageType.getSamplerUseInfo())); + operands.push_back(static_cast(imageType.getImageFormat())); + return success(); + } + + if (auto arrayType = type.dyn_cast()) { + typeEnum = spirv::Opcode::OpTypeArray; + uint32_t elementTypeID = 0; + if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID, + serializationCtx))) { + return failure(); + } + operands.push_back(elementTypeID); + if (auto elementCountID = prepareConstantInt( + loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { + operands.push_back(elementCountID); + } + return processTypeDecoration(loc, arrayType, resultID); + } + + if (auto ptrType = type.dyn_cast()) { + uint32_t pointeeTypeID = 0; + spirv::StructType pointeeStruct = + ptrType.getPointeeType().dyn_cast(); + + if (pointeeStruct && pointeeStruct.isIdentified() && + serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { + // A recursive reference to an enclosing struct is found. + // + // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage + // class as operands. + SmallVector forwardPtrOperands; + forwardPtrOperands.push_back(resultID); + forwardPtrOperands.push_back( + static_cast(ptrType.getStorageClass())); + + encodeInstructionInto(typesGlobalValues, + spirv::Opcode::OpTypeForwardPointer, + forwardPtrOperands); + + // 2. Find the pointee (enclosing) struct. + auto structType = spirv::StructType::getIdentified( + module.getContext(), pointeeStruct.getIdentifier()); + + if (!structType) + return failure(); + + // 3. Mark the OpTypePointer that is supposed to be emitted by this call + // as deferred. + deferSerialization = true; + + // 4. Record the info needed to emit the deferred OpTypePointer + // instruction when the enclosing struct is completely serialized. + recursiveStructInfos[structType].push_back( + {resultID, ptrType.getStorageClass()}); + } else { + if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID, + serializationCtx))) + return failure(); + } + + typeEnum = spirv::Opcode::OpTypePointer; + operands.push_back(static_cast(ptrType.getStorageClass())); + operands.push_back(pointeeTypeID); + return success(); + } + + if (auto runtimeArrayType = type.dyn_cast()) { + uint32_t elementTypeID = 0; + if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(), + elementTypeID, serializationCtx))) { + return failure(); + } + typeEnum = spirv::Opcode::OpTypeRuntimeArray; + operands.push_back(elementTypeID); + return processTypeDecoration(loc, runtimeArrayType, resultID); + } + + if (auto structType = type.dyn_cast()) { + if (structType.isIdentified()) { + processName(resultID, structType.getIdentifier()); + serializationCtx.insert(structType.getIdentifier()); + } + + bool hasOffset = structType.hasOffset(); + for (auto elementIndex : + llvm::seq(0, structType.getNumElements())) { + uint32_t elementTypeID = 0; + if (failed(processTypeImpl(loc, structType.getElementType(elementIndex), + elementTypeID, serializationCtx))) { + return failure(); + } + operands.push_back(elementTypeID); + if (hasOffset) { + // Decorate each struct member with an offset + spirv::StructType::MemberDecorationInfo offsetDecoration{ + elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, + static_cast(structType.getMemberOffset(elementIndex))}; + if (failed(processMemberDecoration(resultID, offsetDecoration))) { + return emitError(loc, "cannot decorate ") + << elementIndex << "-th member of " << structType + << " with its offset"; + } + } + } + SmallVector memberDecorations; + structType.getMemberDecorations(memberDecorations); + + for (auto &memberDecoration : memberDecorations) { + if (failed(processMemberDecoration(resultID, memberDecoration))) { + return emitError(loc, "cannot decorate ") + << static_cast(memberDecoration.memberIndex) + << "-th member of " << structType << " with " + << stringifyDecoration(memberDecoration.decoration); + } + } + + typeEnum = spirv::Opcode::OpTypeStruct; + + if (structType.isIdentified()) + serializationCtx.remove(structType.getIdentifier()); + + return success(); + } + + if (auto cooperativeMatrixType = + type.dyn_cast()) { + uint32_t elementTypeID = 0; + if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), + elementTypeID, serializationCtx))) { + return failure(); + } + typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; + auto getConstantOp = [&](uint32_t id) { + auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); + return prepareConstantInt(loc, attr); + }; + operands.push_back(elementTypeID); + operands.push_back( + getConstantOp(static_cast(cooperativeMatrixType.getScope()))); + operands.push_back(getConstantOp(cooperativeMatrixType.getRows())); + operands.push_back(getConstantOp(cooperativeMatrixType.getColumns())); + return success(); + } + + if (auto matrixType = type.dyn_cast()) { + uint32_t elementTypeID = 0; + if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, + serializationCtx))) { + return failure(); + } + typeEnum = spirv::Opcode::OpTypeMatrix; + operands.push_back(elementTypeID); + operands.push_back(matrixType.getNumColumns()); + return success(); + } + + // TODO: Handle other types. + return emitError(loc, "unhandled type in serialization: ") << type; +} + +LogicalResult +Serializer::prepareFunctionType(Location loc, FunctionType type, + spirv::Opcode &typeEnum, + SmallVectorImpl &operands) { + typeEnum = spirv::Opcode::OpTypeFunction; + assert(type.getNumResults() <= 1 && + "serialization supports only a single return value"); + uint32_t resultID = 0; + if (failed(processType( + loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), + resultID))) { + return failure(); + } + operands.push_back(resultID); + for (auto &res : type.getInputs()) { + uint32_t argTypeID = 0; + if (failed(processType(loc, res, argTypeID))) { + return failure(); + } + operands.push_back(argTypeID); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Constant +//===----------------------------------------------------------------------===// + +uint32_t Serializer::prepareConstant(Location loc, Type constType, + Attribute valueAttr) { + if (auto id = prepareConstantScalar(loc, valueAttr)) { + return id; + } + + // This is a composite literal. We need to handle each component separately + // and then emit an OpConstantComposite for the whole. + + if (auto id = getConstantID(valueAttr)) { + return id; + } + + uint32_t typeID = 0; + if (failed(processType(loc, constType, typeID))) { + return 0; + } + + uint32_t resultID = 0; + if (auto attr = valueAttr.dyn_cast()) { + int rank = attr.getType().dyn_cast().getRank(); + SmallVector index(rank); + resultID = prepareDenseElementsConstant(loc, constType, attr, + /*dim=*/0, index); + } else if (auto arrayAttr = valueAttr.dyn_cast()) { + resultID = prepareArrayConstant(loc, constType, arrayAttr); + } + + if (resultID == 0) { + emitError(loc, "cannot serialize attribute: ") << valueAttr; + return 0; + } + + constIDMap[valueAttr] = resultID; + return resultID; +} + +uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, + ArrayAttr attr) { + uint32_t typeID = 0; + if (failed(processType(loc, constType, typeID))) { + return 0; + } + + uint32_t resultID = getNextID(); + SmallVector operands = {typeID, resultID}; + operands.reserve(attr.size() + 2); + auto elementType = constType.cast().getElementType(); + for (Attribute elementAttr : attr) { + if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { + operands.push_back(elementID); + } else { + return 0; + } + } + spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; + encodeInstructionInto(typesGlobalValues, opcode, operands); + + return resultID; +} + +// TODO: Turn the below function into iterative function, instead of +// recursive function. +uint32_t +Serializer::prepareDenseElementsConstant(Location loc, Type constType, + DenseElementsAttr valueAttr, int dim, + MutableArrayRef index) { + auto shapedType = valueAttr.getType().dyn_cast(); + assert(dim <= shapedType.getRank()); + if (shapedType.getRank() == dim) { + if (auto attr = valueAttr.dyn_cast()) { + return attr.getType().getElementType().isInteger(1) + ? prepareConstantBool(loc, attr.getValue(index)) + : prepareConstantInt(loc, attr.getValue(index)); + } + if (auto attr = valueAttr.dyn_cast()) { + return prepareConstantFp(loc, attr.getValue(index)); + } + return 0; + } + + uint32_t typeID = 0; + if (failed(processType(loc, constType, typeID))) { + return 0; + } + + uint32_t resultID = getNextID(); + SmallVector operands = {typeID, resultID}; + operands.reserve(shapedType.getDimSize(dim) + 2); + auto elementType = constType.cast().getElementType(0); + for (int i = 0; i < shapedType.getDimSize(dim); ++i) { + index[dim] = i; + if (auto elementID = prepareDenseElementsConstant( + loc, elementType, valueAttr, dim + 1, index)) { + operands.push_back(elementID); + } else { + return 0; + } + } + spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; + encodeInstructionInto(typesGlobalValues, opcode, operands); + + return resultID; +} + +uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, + bool isSpec) { + if (auto floatAttr = valueAttr.dyn_cast()) { + return prepareConstantFp(loc, floatAttr, isSpec); + } + if (auto boolAttr = valueAttr.dyn_cast()) { + return prepareConstantBool(loc, boolAttr, isSpec); + } + if (auto intAttr = valueAttr.dyn_cast()) { + return prepareConstantInt(loc, intAttr, isSpec); + } + + return 0; +} + +uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, + bool isSpec) { + if (!isSpec) { + // We can de-duplicate normal constants, but not specialization constants. + if (auto id = getConstantID(boolAttr)) { + return id; + } + } + + // Process the type for this bool literal + uint32_t typeID = 0; + if (failed(processType(loc, boolAttr.getType(), typeID))) { + return 0; + } + + auto resultID = getNextID(); + auto opcode = boolAttr.getValue() + ? (isSpec ? spirv::Opcode::OpSpecConstantTrue + : spirv::Opcode::OpConstantTrue) + : (isSpec ? spirv::Opcode::OpSpecConstantFalse + : spirv::Opcode::OpConstantFalse); + encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); + + if (!isSpec) { + constIDMap[boolAttr] = resultID; + } + return resultID; +} + +uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, + bool isSpec) { + if (!isSpec) { + // We can de-duplicate normal constants, but not specialization constants. + if (auto id = getConstantID(intAttr)) { + return id; + } + } + + // Process the type for this integer literal + uint32_t typeID = 0; + if (failed(processType(loc, intAttr.getType(), typeID))) { + return 0; + } + + auto resultID = getNextID(); + APInt value = intAttr.getValue(); + unsigned bitwidth = value.getBitWidth(); + bool isSigned = value.isSignedIntN(bitwidth); + + auto opcode = + isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; + + // According to SPIR-V spec, "When the type's bit width is less than 32-bits, + // the literal's value appears in the low-order bits of the word, and the + // high-order bits must be 0 for a floating-point type, or 0 for an integer + // type with Signedness of 0, or sign extended when Signedness is 1." + if (bitwidth == 32 || bitwidth == 16) { + uint32_t word = 0; + if (isSigned) { + word = static_cast(value.getSExtValue()); + } else { + word = static_cast(value.getZExtValue()); + } + encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); + } + // According to SPIR-V spec: "When the type's bit width is larger than one + // word, the literal’s low-order words appear first." + else if (bitwidth == 64) { + struct DoubleWord { + uint32_t word1; + uint32_t word2; + } words; + if (isSigned) { + words = llvm::bit_cast(value.getSExtValue()); + } else { + words = llvm::bit_cast(value.getZExtValue()); + } + encodeInstructionInto(typesGlobalValues, opcode, + {typeID, resultID, words.word1, words.word2}); + } else { + std::string valueStr; + llvm::raw_string_ostream rss(valueStr); + value.print(rss, /*isSigned=*/false); + + emitError(loc, "cannot serialize ") + << bitwidth << "-bit integer literal: " << rss.str(); + return 0; + } + + if (!isSpec) { + constIDMap[intAttr] = resultID; + } + return resultID; +} + +uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, + bool isSpec) { + if (!isSpec) { + // We can de-duplicate normal constants, but not specialization constants. + if (auto id = getConstantID(floatAttr)) { + return id; + } + } + + // Process the type for this float literal + uint32_t typeID = 0; + if (failed(processType(loc, floatAttr.getType(), typeID))) { + return 0; + } + + auto resultID = getNextID(); + APFloat value = floatAttr.getValue(); + APInt intValue = value.bitcastToAPInt(); + + auto opcode = + isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; + + if (&value.getSemantics() == &APFloat::IEEEsingle()) { + uint32_t word = llvm::bit_cast(value.convertToFloat()); + encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); + } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { + struct DoubleWord { + uint32_t word1; + uint32_t word2; + } words = llvm::bit_cast(value.convertToDouble()); + encodeInstructionInto(typesGlobalValues, opcode, + {typeID, resultID, words.word1, words.word2}); + } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { + uint32_t word = + static_cast(value.bitcastToAPInt().getZExtValue()); + encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); + } else { + std::string valueStr; + llvm::raw_string_ostream rss(valueStr); + value.print(rss); + + emitError(loc, "cannot serialize ") + << floatAttr.getType() << "-typed float literal: " << rss.str(); + return 0; + } + + if (!isSpec) { + constIDMap[floatAttr] = resultID; + } + return resultID; +} + +//===----------------------------------------------------------------------===// +// Control flow +//===----------------------------------------------------------------------===// + +uint32_t Serializer::getOrCreateBlockID(Block *block) { + if (uint32_t id = getBlockID(block)) + return id; + return blockIDMap[block] = getNextID(); +} + +LogicalResult +Serializer::processBlock(Block *block, bool omitLabel, + function_ref actionBeforeTerminator) { + LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); + LLVM_DEBUG(block->print(llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << '\n'); + if (!omitLabel) { + uint32_t blockID = getOrCreateBlockID(block); + LLVM_DEBUG(llvm::dbgs() + << "[block] " << block << " (id = " << blockID << ")\n"); + + // Emit OpLabel for this block. + encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID}); + } + + // Emit OpPhi instructions for block arguments, if any. + if (failed(emitPhiForBlockArguments(block))) + return failure(); + + // Process each op in this block except the terminator. + for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { + if (failed(processOperation(&op))) + return failure(); + } + + // Process the terminator. + if (actionBeforeTerminator) + actionBeforeTerminator(); + if (failed(processOperation(&block->back()))) + return failure(); + + return success(); +} + +LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { + // Nothing to do if this block has no arguments or it's the entry block, which + // always has the same arguments as the function signature. + if (block->args_empty() || block->isEntryBlock()) + return success(); + + // If the block has arguments, we need to create SPIR-V OpPhi instructions. + // A SPIR-V OpPhi instruction is of the syntax: + // OpPhi | result type | result | (value , parent block ) pair + // So we need to collect all predecessor blocks and the arguments they send + // to this block. + SmallVector, 4> predecessors; + for (Block *predecessor : block->getPredecessors()) { + auto *terminator = predecessor->getTerminator(); + // The predecessor here is the immediate one according to MLIR's IR + // structure. It does not directly map to the incoming parent block for the + // OpPhi instructions at SPIR-V binary level. This is because structured + // control flow ops are serialized to multiple SPIR-V blocks. If there is a + // spv.selection/spv.loop op in the MLIR predecessor block, the branch op + // jumping to the OpPhi's block then resides in the previous structured + // control flow op's merge block. + predecessor = getPhiIncomingBlock(predecessor); + if (auto branchOp = dyn_cast(terminator)) { + predecessors.emplace_back(predecessor, branchOp.operand_begin()); + } else { + return terminator->emitError("unimplemented terminator for Phi creation"); + } + } + + // Then create OpPhi instruction for each of the block argument. + for (auto argIndex : llvm::seq(0, block->getNumArguments())) { + BlockArgument arg = block->getArgument(argIndex); + + // Get the type and result for this OpPhi instruction. + uint32_t phiTypeID = 0; + if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID))) + return failure(); + uint32_t phiID = getNextID(); + + LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' + << arg << " (id = " << phiID << ")\n"); + + // Prepare the (value , parent block ) pairs. + SmallVector phiArgs; + phiArgs.push_back(phiTypeID); + phiArgs.push_back(phiID); + + for (auto predIndex : llvm::seq(0, predecessors.size())) { + Value value = *(predecessors[predIndex].second + argIndex); + uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); + LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId + << ") value " << value << ' '); + // Each pair is a value ... + uint32_t valueId = getValueID(value); + if (valueId == 0) { + // The op generating this value hasn't been visited yet so we don't have + // an assigned yet. Record this to fix up later. + LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); + deferredPhiValues[value].push_back(functionBody.size() + 1 + + phiArgs.size()); + } else { + LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); + } + phiArgs.push_back(valueId); + // ... and a parent block . + phiArgs.push_back(predBlockId); + } + + encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); + valueIDMap[arg] = phiID; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// Operation +//===----------------------------------------------------------------------===// + +LogicalResult Serializer::encodeExtensionInstruction( + Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, + ArrayRef operands) { + // Check if the extension has been imported. + auto &setID = extendedInstSetIDMap[extensionSetName]; + if (!setID) { + setID = getNextID(); + SmallVector importOperands; + importOperands.push_back(setID); + if (failed( + spirv::encodeStringLiteralInto(importOperands, extensionSetName)) || + failed(encodeInstructionInto( + extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) { + return failure(); + } + } + + // The first two operands are the result type and result . The set + // and the opcode need to be insert after this. + if (operands.size() < 2) { + return op->emitError("extended instructions must have a result encoding"); + } + SmallVector extInstOperands; + extInstOperands.reserve(operands.size() + 2); + extInstOperands.append(operands.begin(), std::next(operands.begin(), 2)); + extInstOperands.push_back(setID); + extInstOperands.push_back(extensionOpcode); + extInstOperands.append(std::next(operands.begin(), 2), operands.end()); + return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, + extInstOperands); +} + +LogicalResult Serializer::processOperation(Operation *opInst) { + LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); + + // First dispatch the ops that do not directly mirror an instruction from + // the SPIR-V spec. + return TypeSwitch(opInst) + .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) + .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) + .Case([&](spirv::BranchConditionalOp op) { + return processBranchConditionalOp(op); + }) + .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) + .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) + .Case([&](spirv::GlobalVariableOp op) { + return processGlobalVariableOp(op); + }) + .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) + .Case([&](spirv::ModuleEndOp) { return success(); }) + .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) + .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) + .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) + .Case([&](spirv::SpecConstantCompositeOp op) { + return processSpecConstantCompositeOp(op); + }) + .Case([&](spirv::SpecConstantOperationOp op) { + return processSpecConstantOperationOp(op); + }) + .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) + .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) + + // Then handle all the ops that directly mirror SPIR-V instructions with + // auto-generated methods. + .Default( + [&](Operation *op) { return dispatchToAutogenSerialization(op); }); +} + +LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op, + StringRef extInstSet, + uint32_t opcode) { + SmallVector operands; + Location loc = op->getLoc(); + + uint32_t resultID = 0; + if (op->getNumResults() != 0) { + uint32_t resultTypeID = 0; + if (failed(processType(loc, op->getResult(0).getType(), resultTypeID))) + return failure(); + operands.push_back(resultTypeID); + + resultID = getNextID(); + operands.push_back(resultID); + valueIDMap[op->getResult(0)] = resultID; + }; + + for (Value operand : op->getOperands()) + operands.push_back(getValueID(operand)); + + emitDebugLine(functionBody, loc); + + if (extInstSet.empty()) { + encodeInstructionInto(functionBody, static_cast(opcode), + operands); + } else { + encodeExtensionInstruction(op, extInstSet, opcode, operands); + } + + if (op->getNumResults() != 0) { + for (auto attr : op->getAttrs()) { + if (failed(processDecoration(loc, resultID, attr))) + return failure(); + } + } + + return success(); +} + +LogicalResult Serializer::emitDecoration(uint32_t target, + spirv::Decoration decoration, + ArrayRef params) { + uint32_t wordCount = 3 + params.size(); + decorations.push_back( + spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate)); + decorations.push_back(target); + decorations.push_back(static_cast(decoration)); + decorations.append(params.begin(), params.end()); + return success(); +} + +LogicalResult Serializer::emitDebugLine(SmallVectorImpl &binary, + Location loc) { + if (!emitDebugInfo) + return success(); + + if (lastProcessedWasMergeInst) { + lastProcessedWasMergeInst = false; + return success(); + } + + auto fileLoc = loc.dyn_cast(); + if (fileLoc) + encodeInstructionInto(binary, spirv::Opcode::OpLine, + {fileID, fileLoc.getLine(), fileLoc.getColumn()}); + return success(); +} +} // namespace spirv +} // namespace mlir