diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt
--- a/mlir/lib/Target/CMakeLists.txt
+++ b/mlir/lib/Target/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(SPIRV) # SPIR-V binary utils, (de)serialization, registration
+
 add_mlir_translation_library(MLIRTargetLLVMIRModuleTranslation
   LLVMIR/DebugTranslation.cpp
   LLVMIR/ModuleTranslation.cpp
@@ -132,52 +134,3 @@
   MLIRROCDLIR
   MLIRTargetLLVMIRModuleTranslation
   )
-
-add_mlir_translation_library(MLIRSPIRVBinaryUtils
-  SPIRV/SPIRVBinaryUtils.cpp
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRSPIRV
-  MLIRSupport
-  )
-
-add_mlir_translation_library(MLIRSPIRVSerialization
-  SPIRV/Serialization.cpp
-
-  DEPENDS
-  MLIRSPIRVSerializationGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRSPIRV
-  MLIRSPIRVBinaryUtils
-  MLIRSupport
-  MLIRTranslation
-  )
-
-add_mlir_translation_library(MLIRSPIRVDeserialization
-  SPIRV/Deserialization.cpp
-
-  DEPENDS
-  MLIRSPIRVSerializationGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRSPIRV
-  MLIRSPIRVBinaryUtils
-  MLIRSupport
-  MLIRTranslation
-  )
-
-add_mlir_translation_library(MLIRSPIRVTranslateRegistration
-  SPIRV/TranslateRegistration.cpp
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRSPIRV
-  MLIRSPIRVSerialization
-  MLIRSPIRVDeserialization
-  MLIRSupport
-  MLIRTranslation
-  )
diff --git a/mlir/lib/Target/SPIRV/CMakeLists.txt b/mlir/lib/Target/SPIRV/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Target/SPIRV/CMakeLists.txt
@@ -0,0 +1,28 @@
+add_subdirectory(Deserialization)
+add_subdirectory(Serialization)
+
+set(LLVM_OPTIONAL_SOURCES # Two libraries below share this directory.
+  SPIRVBinaryUtils.cpp      # Built only by MLIRSPIRVBinaryUtils.
+  TranslateRegistration.cpp # Built only by MLIRSPIRVTranslateRegistration.
+  )
+
+add_mlir_translation_library(MLIRSPIRVBinaryUtils
+  SPIRVBinaryUtils.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSPIRV
+  MLIRSupport
+  )
+
+add_mlir_translation_library(MLIRSPIRVTranslateRegistration
+  TranslateRegistration.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSPIRV
+  MLIRSPIRVSerialization
+  MLIRSPIRVDeserialization
+  MLIRSupport
+  MLIRTranslation
+  )
diff --git a/mlir/lib/Target/SPIRV/Deserialization/CMakeLists.txt
b/mlir/lib/Target/SPIRV/Deserialization/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Target/SPIRV/Deserialization/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_translation_library(MLIRSPIRVDeserialization
+  DeserializeOps.cpp
+  Deserializer.cpp
+  Deserialization.cpp
+
+  DEPENDS
+  MLIRSPIRVSerializationGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSPIRV
+  MLIRSPIRVBinaryUtils
+  MLIRSupport
+  MLIRTranslation
+  )
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp
@@ -0,0 +1,23 @@
+//===- Deserialization.cpp - MLIR SPIR-V Deserialization ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/SPIRV/Deserialization.h"
+
+#include "Deserializer.h"
+
+namespace mlir {
+spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef<uint32_t> binary,
+                                               MLIRContext *context) {
+  Deserializer deserializer(binary, context);
+
+  if (failed(deserializer.deserialize()))
+    return nullptr;
+
+  return deserializer.collect();
+}
+} // namespace mlir
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -0,0 +1,578 @@
+//===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Deserializer methods for SPIR-V binary instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Deserializer.h"
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "spirv-deserialization"
+
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
+/// Extracts the opcode from the given first word of a SPIR-V instruction.
+static inline spirv::Opcode extractOpcode(uint32_t word) {
+  return static_cast<spirv::Opcode>(word & 0xffff);
+}
+
+//===----------------------------------------------------------------------===//
+// Instruction
+//===----------------------------------------------------------------------===//
+
+Value spirv::Deserializer::getValue(uint32_t id) {
+  if (auto constInfo = getConstant(id)) {
+    // Materialize a `spv.constant` op at every use site.
+    return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
+                                               constInfo->first);
+  }
+  if (auto varOp = getGlobalVariable(id)) {
+    auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
+        unknownLoc, varOp.type(),
+        opBuilder.getSymbolRefAttr(varOp.getOperation()));
+    return addressOfOp.pointer();
+  }
+  if (auto constOp = getSpecConstant(id)) {
+    auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
+        unknownLoc, constOp.default_value().getType(),
+        opBuilder.getSymbolRefAttr(constOp.getOperation()));
+    return referenceOfOp.reference();
+  }
+  if (auto constCompositeOp = getSpecConstantComposite(id)) {
+    auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
+        unknownLoc, constCompositeOp.type(),
+        opBuilder.getSymbolRefAttr(constCompositeOp.getOperation()));
+    return referenceOfOp.reference();
+  }
+  if (auto undef = getUndefType(id)) {
+    return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
+  }
+  return valueMap.lookup(id);
+}
+
+LogicalResult
+spirv::Deserializer::sliceInstruction(spirv::Opcode &opcode,
+                                      ArrayRef<uint32_t> &operands,
+                                      Optional<spirv::Opcode> expectedOpcode) {
+  auto binarySize = binary.size();
+  if (curOffset >= binarySize) {
+    return emitError(unknownLoc, "expected ")
+           << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
+                              : "more")
+           << " instruction";
+  }
+
+  // For each instruction, get its word count from the first word to slice it
+  // from the stream properly, and then dispatch to the instruction handler.
+
+  uint32_t wordCount = binary[curOffset] >> 16;
+
+  if (wordCount == 0)
+    return emitError(unknownLoc, "word count cannot be zero");
+
+  uint32_t nextOffset = curOffset + wordCount;
+  if (nextOffset > binarySize)
+    return emitError(unknownLoc, "insufficient words for the last instruction");
+
+  opcode = extractOpcode(binary[curOffset]);
+  operands = binary.slice(curOffset + 1, wordCount - 1);
+  curOffset = nextOffset;
+  return success();
+}
+
+LogicalResult spirv::Deserializer::processInstruction(
+    spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
+  LLVM_DEBUG(llvm::dbgs() << "[inst] processing instruction "
+                          << spirv::stringifyOpcode(opcode) << "\n");
+
+  // First dispatch all the instructions whose opcode does not correspond to
+  // those that have a direct mirror in the SPIR-V dialect
+  switch (opcode) {
+  case spirv::Opcode::OpCapability:
+    return processCapability(operands);
+  case spirv::Opcode::OpExtension:
+    return processExtension(operands);
+  case spirv::Opcode::OpExtInst:
+    return processExtInst(operands);
+  case spirv::Opcode::OpExtInstImport:
+    return processExtInstImport(operands);
+  case spirv::Opcode::OpMemberName:
+    return processMemberName(operands);
+  case spirv::Opcode::OpMemoryModel:
+    return processMemoryModel(operands);
+  case spirv::Opcode::OpEntryPoint:
+  case spirv::Opcode::OpExecutionMode:
+    if (deferInstructions) {
+      deferredInstructions.emplace_back(opcode, operands);
+      return success();
+    }
+    break;
+  case spirv::Opcode::OpVariable:
+    if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
+      return processGlobalVariable(operands);
+    }
+    break;
+  case spirv::Opcode::OpLine:
+    return processDebugLine(operands);
+  case spirv::Opcode::OpNoLine:
+    return clearDebugLine();
+  case spirv::Opcode::OpName:
+    return processName(operands);
+  case spirv::Opcode::OpString:
+    return processDebugString(operands);
+  case spirv::Opcode::OpModuleProcessed:
+  case spirv::Opcode::OpSource:
+  case spirv::Opcode::OpSourceContinued:
+  case spirv::Opcode::OpSourceExtension:
+    // TODO: This is debug information embedded in the binary which should be
+    // translated into the spv.module.
+    return success();
+  case spirv::Opcode::OpTypeVoid:
+  case spirv::Opcode::OpTypeBool:
+  case spirv::Opcode::OpTypeInt:
+  case spirv::Opcode::OpTypeFloat:
+  case spirv::Opcode::OpTypeVector:
+  case spirv::Opcode::OpTypeMatrix:
+  case spirv::Opcode::OpTypeArray:
+  case spirv::Opcode::OpTypeFunction:
+  case spirv::Opcode::OpTypeRuntimeArray:
+  case spirv::Opcode::OpTypeStruct:
+  case spirv::Opcode::OpTypePointer:
+  case spirv::Opcode::OpTypeCooperativeMatrixNV:
+    return processType(opcode, operands);
+  case spirv::Opcode::OpTypeForwardPointer:
+    return processTypeForwardPointer(operands);
+  case spirv::Opcode::OpConstant:
+    return processConstant(operands, /*isSpec=*/false);
+  case spirv::Opcode::OpSpecConstant:
+    return processConstant(operands, /*isSpec=*/true);
+  case spirv::Opcode::OpConstantComposite:
+    return processConstantComposite(operands);
+  case spirv::Opcode::OpSpecConstantComposite:
+    return processSpecConstantComposite(operands);
+  case spirv::Opcode::OpConstantTrue:
+    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
+  case spirv::Opcode::OpSpecConstantTrue:
+    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
+  case spirv::Opcode::OpConstantFalse:
+    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
+  case spirv::Opcode::OpSpecConstantFalse:
+    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
+  case spirv::Opcode::OpConstantNull:
+    return processConstantNull(operands);
+  case spirv::Opcode::OpDecorate:
+    return processDecoration(operands);
+  case spirv::Opcode::OpMemberDecorate:
+    return processMemberDecoration(operands);
+  case spirv::Opcode::OpFunction:
+    return processFunction(operands);
+  case spirv::Opcode::OpLabel:
+    return processLabel(operands);
+  case spirv::Opcode::OpBranch:
+    return processBranch(operands);
+  case spirv::Opcode::OpBranchConditional:
+    return processBranchConditional(operands);
+  case spirv::Opcode::OpSelectionMerge:
+    return processSelectionMerge(operands);
+  case spirv::Opcode::OpLoopMerge:
+    return processLoopMerge(operands);
+  case spirv::Opcode::OpPhi:
+    return processPhi(operands);
+  case spirv::Opcode::OpUndef:
+    return processUndef(operands);
+  default:
+    break;
+  }
+  return dispatchToAutogenDeserialization(opcode, operands);
+}
+
+LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
+    ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
+    unsigned numOperands) {
+  SmallVector<Type, 1> resultTypes;
+  uint32_t valueID = 0;
+
+  size_t wordIndex = 0;
+  if (hasResult) {
+    if (wordIndex >= words.size())
+      return emitError(unknownLoc,
+                       "expected result type <id> while deserializing for ")
+             << opName;
+
+    // Decode the type <id>
+    auto type = getType(words[wordIndex]);
+    if (!type)
+      return emitError(unknownLoc, "unknown type result <id> : ")
+             << words[wordIndex];
+    resultTypes.push_back(type);
+    ++wordIndex;
+
+    // Decode the result <id>
+    if (wordIndex >= words.size())
+      return emitError(unknownLoc,
+                       "expected result <id> while deserializing for ")
+             << opName;
+    valueID = words[wordIndex];
+    ++wordIndex;
+  }
+
+  SmallVector<Value, 4> operands;
+  SmallVector<NamedAttribute, 4> attributes;
+
+  // Decode operands
+  size_t operandIndex = 0;
+  for (; operandIndex < numOperands && wordIndex < words.size();
+       ++operandIndex, ++wordIndex) {
+    auto arg = getValue(words[wordIndex]);
+    if (!arg)
+      return emitError(unknownLoc, "unknown result <id> : ")
+             << words[wordIndex];
+    operands.push_back(arg);
+  }
+  if (operandIndex != numOperands) {
+    return emitError(
+               unknownLoc,
+               "found less operands than expected when deserializing for ")
+           << opName << "; only " << operandIndex << " of " << numOperands
+           << " processed";
+  }
+  if (wordIndex != words.size()) {
+    return emitError(
+               unknownLoc,
+               "found more operands than expected when deserializing for ")
+           << opName << "; only " << wordIndex << " of " << words.size()
+           << " processed";
+  }
+
+  // Attach attributes from decorations
+  if (decorations.count(valueID)) {
+    auto attrs = decorations[valueID].getAttrs();
+    attributes.append(attrs.begin(), attrs.end());
+  }
+
+  // Create the op and update bookkeeping maps
+  Location loc = createFileLineColLoc(opBuilder);
+  OperationState opState(loc, opName);
+  opState.addOperands(operands);
+  if (hasResult)
+    opState.addTypes(resultTypes);
+  opState.addAttributes(attributes);
+  Operation *op = opBuilder.createOperation(opState);
+  if (hasResult)
+    valueMap[valueID] = op->getResult(0);
+
+  if (op->hasTrait<OpTrait::IsTerminator>())
+    clearDebugLine();
+
+  return success();
+}
+
+LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
+  if (operands.size() != 2) {
+    return emitError(unknownLoc, "OpUndef instruction must have two operands");
+  }
+  auto type = getType(operands[0]);
+  if (!type) {
+    return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
+  }
+  undefMap[operands[1]] = type;
+  return success();
+}
+
+LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
+  if (operands.size() < 4) {
+    return emitError(unknownLoc,
+                     "OpExtInst must have at least 4 operands, result type "
+                     "<id>, result <id>, set <id> and instruction opcode");
+  }
+  if (!extendedInstSets.count(operands[2])) {
+    return emitError(unknownLoc, "undefined set <id> in OpExtInst");
+  }
+  SmallVector<uint32_t, 4> slicedOperands;
+  slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
+  slicedOperands.append(std::next(operands.begin(), 4), operands.end());
+  return dispatchToExtensionSetAutogenDeserialization(
+      extendedInstSets[operands[2]], operands[3], slicedOperands);
+}
+
+namespace mlir {
+namespace spirv {
+
+/// Deserializes an OpEntryPoint instruction into a spirv::EntryPointOp.
+template <>
+LogicalResult
+Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
+  unsigned wordIndex = 0;
+  if (wordIndex >= words.size()) {
+    return emitError(unknownLoc,
+                     "missing Execution Model specification in OpEntryPoint");
+  }
+  auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]);
+  if (wordIndex >= words.size()) {
+    return emitError(unknownLoc, "missing <id> in OpEntryPoint");
+  }
+  // Get the function <id>
+  auto fnID = words[wordIndex++];
+  // Get the function name; the literal always occupies at least one word.
+  if (wordIndex >= words.size()) {
+    return emitError(unknownLoc, "missing function name in OpEntryPoint");
+  }
+  auto fnName = decodeStringLiteral(words, wordIndex);
+  // Verify that the function <id> matches the fnName
+  auto parsedFunc = getFunction(fnID);
+  if (!parsedFunc) {
+    return emitError(unknownLoc, "no function matching <id> ") << fnID;
+  }
+  if (parsedFunc.getName() != fnName) {
+    return emitError(unknownLoc, "function name mismatch between OpEntryPoint "
+                                 "and OpFunction with <id> ")
+           << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
+  }
+  SmallVector<Attribute, 4> interface;
+  while (wordIndex < words.size()) {
+    auto arg = getGlobalVariable(words[wordIndex]);
+    if (!arg) {
+      return emitError(unknownLoc, "undefined result <id> ")
+             << words[wordIndex] << " while decoding OpEntryPoint";
+    }
+    interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
+    wordIndex++;
+  }
+  opBuilder.create<spirv::EntryPointOp>(unknownLoc, execModel,
+                                        opBuilder.getSymbolRefAttr(fnName),
+                                        opBuilder.getArrayAttr(interface));
+  return success();
+}
+
+/// Deserializes an OpExecutionMode instruction into a spirv::ExecutionModeOp.
+template <>
+LogicalResult
+Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
+  unsigned wordIndex = 0;
+  if (wordIndex >= words.size()) {
+    return emitError(unknownLoc,
+                     "missing function result <id> in OpExecutionMode");
+  }
+  // Get the function <id> to get the name of the function
+  auto fnID = words[wordIndex++];
+  auto fn = getFunction(fnID);
+  if (!fn) {
+    return emitError(unknownLoc, "no function matching <id> ") << fnID;
+  }
+  // Get the Execution mode
+  if (wordIndex >= words.size()) {
+    return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
+  }
+  auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]);
+
+  // Get the values
+  SmallVector<Attribute, 4> attrListElems;
+  while (wordIndex < words.size()) {
+    attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
+  }
+  auto values = opBuilder.getArrayAttr(attrListElems);
+  opBuilder.create<spirv::ExecutionModeOp>(
+      unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values);
+  return success();
+}
+
+/// Deserializes an OpControlBarrier instruction into a spirv::ControlBarrierOp.
+template <>
+LogicalResult
+Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
+  if (operands.size() != 3) {
+    return emitError(
+        unknownLoc,
+        "OpControlBarrier must have execution scope <id>, memory scope <id> "
+        "and memory semantics <id>");
+  }
+
+  SmallVector<IntegerAttr, 3> argAttrs;
+  for (auto operand : operands) {
+    auto argAttr = getConstantInt(operand);
+    if (!argAttr) {
+      return emitError(unknownLoc,
+                       "expected 32-bit integer constant from <id> ")
+             << operand << " for OpControlBarrier";
+    }
+    argAttrs.push_back(argAttr);
+  }
+
+  opBuilder.create<spirv::ControlBarrierOp>(unknownLoc, argAttrs[0],
+                                            argAttrs[1], argAttrs[2]);
+  return success();
+}
+
+/// Deserializes an OpFunctionCall instruction into a spirv::FunctionCallOp.
+template <>
+LogicalResult
+Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
+  if (operands.size() < 3) {
+    return emitError(unknownLoc,
+                     "OpFunctionCall must have at least 3 operands");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  // Use null type to mean no result type.
+  if (isVoidType(resultType))
+    resultType = nullptr;
+
+  auto resultID = operands[1];
+  auto functionID = operands[2];
+
+  auto functionName = getFunctionSymbol(functionID);
+
+  SmallVector<Value, 4> arguments;
+  for (auto operand : llvm::drop_begin(operands, 3)) {
+    auto value = getValue(operand);
+    if (!value) {
+      return emitError(unknownLoc, "unknown <id> ")
+             << operand << " used by OpFunctionCall";
+    }
+    arguments.push_back(value);
+  }
+
+  auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
+      unknownLoc, resultType, opBuilder.getSymbolRefAttr(functionName),
+      arguments);
+
+  if (resultType)
+    valueMap[resultID] = opFunctionCall.getResult(0);
+  return success();
+}
+
+/// Deserializes an OpMemoryBarrier instruction into a spirv::MemoryBarrierOp.
+template <>
+LogicalResult
+Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
+  if (operands.size() != 2) {
+    return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> "
+                                 "and memory semantics <id>");
+  }
+
+  SmallVector<IntegerAttr, 2> argAttrs;
+  for (auto operand : operands) {
+    auto argAttr = getConstantInt(operand);
+    if (!argAttr) {
+      return emitError(unknownLoc,
+                       "expected 32-bit integer constant from <id> ")
+             << operand << " for OpMemoryBarrier";
+    }
+    argAttrs.push_back(argAttr);
+  }
+
+  opBuilder.create<spirv::MemoryBarrierOp>(unknownLoc, argAttrs[0],
+                                           argAttrs[1]);
+  return success();
+}
+
+/// Deserializes an OpCopyMemory instruction into a spirv::CopyMemoryOp.
+template <>
+LogicalResult
+Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
+  SmallVector<Type, 1> resultTypes;
+  size_t wordIndex = 0;
+  SmallVector<Value, 4> operands;
+  SmallVector<NamedAttribute, 4> attributes;
+
+  if (wordIndex < words.size()) {
+    auto arg = getValue(words[wordIndex]);
+
+    if (!arg) {
+      return emitError(unknownLoc, "unknown result <id> : ")
+             << words[wordIndex];
+    }
+
+    operands.push_back(arg);
+    wordIndex++;
+  }
+
+  if (wordIndex < words.size()) {
+    auto arg = getValue(words[wordIndex]);
+
+    if (!arg) {
+      return emitError(unknownLoc, "unknown result <id> : ")
+             << words[wordIndex];
+    }
+
+    operands.push_back(arg);
+    wordIndex++;
+  }
+
+  // Memory Access operands are bit masks: an alignment literal follows a mask
+  // whenever its Aligned bit is set (e.g. Volatile|Aligned), not only when the
+  // mask is exactly Aligned.
+  auto hasAlignedBit = [](uint32_t mask) {
+    return (mask & static_cast<uint32_t>(spirv::MemoryAccess::Aligned)) != 0;
+  };
+
+  bool isAlignedAttr = false;
+
+  if (wordIndex < words.size()) {
+    auto attrValue = words[wordIndex++];
+    attributes.push_back(opBuilder.getNamedAttr(
+        "memory_access", opBuilder.getI32IntegerAttr(attrValue)));
+    isAlignedAttr = hasAlignedBit(attrValue);
+  }
+
+  if (isAlignedAttr && wordIndex < words.size()) {
+    attributes.push_back(opBuilder.getNamedAttr(
+        "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
+  }
+
+  // The optional second mask (for the source) follows the same encoding.
+  bool isSourceAlignedAttr = false;
+
+  if (wordIndex < words.size()) {
+    auto attrValue = words[wordIndex++];
+    attributes.push_back(opBuilder.getNamedAttr(
+        "source_memory_access", opBuilder.getI32IntegerAttr(attrValue)));
+    isSourceAlignedAttr = hasAlignedBit(attrValue);
+  }
+
+  if (isSourceAlignedAttr && wordIndex < words.size()) {
+    attributes.push_back(opBuilder.getNamedAttr(
+        "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
+  }
+
+  if (wordIndex != words.size()) {
+    return emitError(unknownLoc,
+                     "found more operands than expected when deserializing "
+                     "spirv::CopyMemoryOp, only ")
+           << wordIndex << " of " << words.size() << " processed";
+  }
+
+  Location loc = createFileLineColLoc(opBuilder);
+  opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
+
+  return success();
+}
+
+// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
+// various Deserializer::processOp<...>() specializations.
+#define GET_DESERIALIZATION_FNS
+#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
+
+} // namespace spirv
+} // namespace mlir
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -0,0 +1,591 @@
+//===- Deserializer.h - MLIR SPIR-V Deserializer ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the SPIR-V binary to MLIR SPIR-V module deserializer.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_SPIRV_DESERIALIZER_H
+#define MLIR_TARGET_SPIRV_DESERIALIZER_H
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVModule.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/IR/Builders.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringRef.h"
+#include <cstdint>
+
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
+/// Decodes a string literal in `words` starting at `wordIndex`. Update the
+/// latter to point to the position in words after the string literal.
+/// Never reads past the end of `words`: an unterminated literal yields the
+/// remaining bytes, and `wordIndex` never advances beyond `words.size()`.
+static inline llvm::StringRef
+decodeStringLiteral(llvm::ArrayRef<uint32_t> words, unsigned &wordIndex) {
+  if (wordIndex >= words.size())
+    return {};
+  const char *begin = reinterpret_cast<const char *>(words.data() + wordIndex);
+  llvm::StringRef bytes(begin, (words.size() - wordIndex) * sizeof(uint32_t));
+  llvm::StringRef str = bytes.substr(0, bytes.find('\0'));
+  wordIndex += str.size() / 4 + 1;
+  if (wordIndex > words.size())
+    wordIndex = words.size();
+  return str;
+}
+
+namespace mlir {
+namespace spirv {
+
+//===----------------------------------------------------------------------===//
+// Utility Definitions
+//===----------------------------------------------------------------------===//
+
+/// A struct for containing a header block's merge and continue targets.
+///
+/// This struct is used to track original structured control flow info from
+/// SPIR-V blob. This info will be used to create spv.selection/spv.loop
+/// later.
+struct BlockMergeInfo {
+  Block *mergeBlock;
+  Block *continueBlock; // nullptr for spv.selection
+  Location loc;
+  uint32_t control;
+
+  BlockMergeInfo(Location location, uint32_t control)
+      : mergeBlock(nullptr), continueBlock(nullptr), loc(location),
+        control(control) {}
+  BlockMergeInfo(Location location, uint32_t control, Block *m,
+                 Block *c = nullptr)
+      : mergeBlock(m), continueBlock(c), loc(location), control(control) {}
+};
+
+/// A struct for containing OpLine instruction information.
+struct DebugLine { + uint32_t fileID; + uint32_t line; + uint32_t col; + + DebugLine(uint32_t fileIDNum, uint32_t lineNum, uint32_t colNum) + : fileID(fileIDNum), line(lineNum), col(colNum) {} +}; + +/// Map from a selection/loop's header block to its merge (and continue) target. +using BlockMergeInfoMap = DenseMap; + +/// A "deferred struct type" is a struct type with one or more member types not +/// known when the Deserializer first encounters the struct. This happens, for +/// example, with recursive structs where a pointer to the struct type is +/// forward declared through OpTypeForwardPointer in the SPIR-V module before +/// the struct declaration; the actual pointer to struct type should be defined +/// later through an OpTypePointer. For example, the following C struct: +/// +/// struct A { +/// A* next; +/// }; +/// +/// would be represented in the SPIR-V module as: +/// +/// OpName %A "A" +/// OpTypeForwardPointer %APtr Generic +/// %A = OpTypeStruct %APtr +/// %APtr = OpTypePointer Generic %A +/// +/// This means that the spirv::StructType cannot be fully constructed directly +/// when the Deserializer encounters it. Instead we create a +/// DeferredStructTypeInfo that contains all the information we know about the +/// spirv::StructType. Once all forward references for the struct are resolved, +/// the struct's body is set with all member info. +struct DeferredStructTypeInfo { + spirv::StructType deferredStructType; + + // A list of all unresolved member types for the struct. First element of each + // item is operand ID, second element is member index in the struct. + SmallVector, 0> unresolvedMemberTypes; + + // The list of member types. For unresolved members, this list contains + // place-holder empty types that will be updated later. 
+ SmallVector memberTypes; + SmallVector offsetInfo; + SmallVector memberDecorationsInfo; +}; + +//===----------------------------------------------------------------------===// +// Deserializer Declaration +//===----------------------------------------------------------------------===// + +/// A SPIR-V module serializer. +/// +/// A SPIR-V binary module is a single linear stream of instructions; each +/// instruction is composed of 32-bit words. The first word of an instruction +/// records the total number of words of that instruction using the 16 +/// higher-order bits. So this deserializer uses that to get instruction +/// boundary and parse instructions and build a SPIR-V ModuleOp gradually. +/// +// TODO: clean up created ops on errors +class Deserializer { +public: + /// Creates a deserializer for the given SPIR-V `binary` module. + /// The SPIR-V ModuleOp will be created into `context. + explicit Deserializer(ArrayRef binary, MLIRContext *context); + + /// Deserializes the remembered SPIR-V binary module. + LogicalResult deserialize(); + + /// Collects the final SPIR-V ModuleOp. + spirv::OwningSPIRVModuleRef collect(); + +private: + //===--------------------------------------------------------------------===// + // Module structure + //===--------------------------------------------------------------------===// + + /// Initializes the `module` ModuleOp in this deserializer instance. + spirv::OwningSPIRVModuleRef createModuleOp(); + + /// Processes SPIR-V module header in `binary`. + LogicalResult processHeader(); + + /// Processes the SPIR-V OpCapability with `operands` and updates bookkeeping + /// in the deserializer. + LogicalResult processCapability(ArrayRef operands); + + /// Processes the SPIR-V OpExtension with `operands` and updates bookkeeping + /// in the deserializer. + LogicalResult processExtension(ArrayRef words); + + /// Processes the SPIR-V OpExtInstImport with `operands` and updates + /// bookkeeping in the deserializer. 
+ LogicalResult processExtInstImport(ArrayRef words); + + /// Attaches (version, capabilities, extensions) triple to `module` as an + /// attribute. + void attachVCETriple(); + + /// Processes the SPIR-V OpMemoryModel with `operands` and updates `module`. + LogicalResult processMemoryModel(ArrayRef operands); + + /// Process SPIR-V OpName with `operands`. + LogicalResult processName(ArrayRef operands); + + /// Processes an OpDecorate instruction. + LogicalResult processDecoration(ArrayRef words); + + // Processes an OpMemberDecorate instruction. + LogicalResult processMemberDecoration(ArrayRef words); + + /// Processes an OpMemberName instruction. + LogicalResult processMemberName(ArrayRef words); + + /// Gets the function op associated with a result of OpFunction. + spirv::FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); } + + /// Processes the SPIR-V function at the current `offset` into `binary`. + /// The operands to the OpFunction instruction is passed in as ``operands`. + /// This method processes each instruction inside the function and dispatches + /// them to their handler method accordingly. + LogicalResult processFunction(ArrayRef operands); + + /// Processes OpFunctionEnd and finalizes function. This wires up block + /// argument created from OpPhi instructions and also structurizes control + /// flow. + LogicalResult processFunctionEnd(ArrayRef operands); + + /// Gets the constant's attribute and type associated with the given . + Optional> getConstant(uint32_t id); + + /// Gets the constant's integer attribute with the given . Returns a null + /// IntegerAttr if the given is not registered or does not correspond to an + /// integer constant. + IntegerAttr getConstantInt(uint32_t id); + + /// Returns a symbol to be used for the function name with the given + /// result . This tries to use the function's OpName if + /// exists; otherwise creates one based on the . 
+ std::string getFunctionSymbol(uint32_t id); + + /// Returns a symbol to be used for the specialization constant with the given + /// result . This tries to use the specialization constant's OpName if + /// exists; otherwise creates one based on the . + std::string getSpecConstantSymbol(uint32_t id); + + /// Gets the specialization constant with the given result . + spirv::SpecConstantOp getSpecConstant(uint32_t id) { + return specConstMap.lookup(id); + } + + /// Gets the composite specialization constant with the given result . + spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id) { + return specConstCompositeMap.lookup(id); + } + + /// Creates a spirv::SpecConstantOp. + spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, + Attribute defaultValue); + + /// Processes the OpVariable instructions at current `offset` into `binary`. + /// It is expected that this method is used for variables that are to be + /// defined at module scope and will be deserialized into a spv.globalVariable + /// instruction. + LogicalResult processGlobalVariable(ArrayRef operands); + + /// Gets the global variable associated with a result of OpVariable. + spirv::GlobalVariableOp getGlobalVariable(uint32_t id) { + return globalVariableMap.lookup(id); + } + + //===--------------------------------------------------------------------===// + // Type + //===--------------------------------------------------------------------===// + + /// Gets type for a given result . + Type getType(uint32_t id) { return typeMap.lookup(id); } + + /// Get the type associated with the result of an OpUndef. + Type getUndefType(uint32_t id) { return undefMap.lookup(id); } + + /// Returns true if the given `type` is for SPIR-V void type. + bool isVoidType(Type type) const { return type.isa(); } + + /// Processes a SPIR-V type instruction with given `opcode` and `operands` and + /// registers the type into `module`. 
+ LogicalResult processType(spirv::Opcode opcode, ArrayRef operands); + + LogicalResult processOpTypePointer(ArrayRef operands); + + LogicalResult processArrayType(ArrayRef operands); + + LogicalResult processCooperativeMatrixType(ArrayRef operands); + + LogicalResult processFunctionType(ArrayRef operands); + + LogicalResult processRuntimeArrayType(ArrayRef operands); + + LogicalResult processStructType(ArrayRef operands); + + LogicalResult processMatrixType(ArrayRef operands); + + LogicalResult processTypeForwardPointer(ArrayRef operands); + + //===--------------------------------------------------------------------===// + // Constant + //===--------------------------------------------------------------------===// + + /// Processes a SPIR-V Op{|Spec}Constant instruction with the given + /// `operands`. `isSpec` indicates whether this is a specialization constant. + LogicalResult processConstant(ArrayRef operands, bool isSpec); + + /// Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the + /// given `operands`. `isSpec` indicates whether this is a specialization + /// constant. + LogicalResult processConstantBool(bool isTrue, ArrayRef operands, + bool isSpec); + + /// Processes a SPIR-V OpConstantComposite instruction with the given + /// `operands`. + LogicalResult processConstantComposite(ArrayRef operands); + + LogicalResult processSpecConstantComposite(ArrayRef operands); + + /// Processes a SPIR-V OpConstantNull instruction with the given `operands`. + LogicalResult processConstantNull(ArrayRef operands); + + //===--------------------------------------------------------------------===// + // Debug + //===--------------------------------------------------------------------===// + + /// Discontinues any source-level location information that might be active + /// from a previous OpLine instruction. + LogicalResult clearDebugLine(); + + /// Creates a FileLineColLoc with the OpLine location information. 
+ Location createFileLineColLoc(OpBuilder opBuilder); + + /// Processes a SPIR-V OpLine instruction with the given `operands`. + LogicalResult processDebugLine(ArrayRef operands); + + /// Processes a SPIR-V OpString instruction with the given `operands`. + LogicalResult processDebugString(ArrayRef operands); + + //===--------------------------------------------------------------------===// + // Control flow + //===--------------------------------------------------------------------===// + + /// Returns the block for the given label . + Block *getBlock(uint32_t id) const { return blockMap.lookup(id); } + + // In SPIR-V, structured control flow is explicitly declared using merge + // instructions (OpSelectionMerge and OpLoopMerge). In the SPIR-V dialect, + // we use spv.selection and spv.loop to group structured control flow. + // The deserializer need to turn structured control flow marked with merge + // instructions into using spv.selection/spv.loop ops. + // + // Because structured control flow can nest and the basic block order have + // flexibility, we cannot isolate a structured selection/loop without + // deserializing all the blocks. So we use the following approach: + // + // 1. Deserialize all basic blocks in a function and create MLIR blocks for + // them into the function's region. In the meanwhile, keep a map between + // selection/loop header blocks to their corresponding merge (and continue) + // target blocks. + // 2. For each selection/loop header block, recursively get all basic blocks + // reachable (except the merge block) and put them in a newly created + // spv.selection/spv.loop's region. Structured control flow guarantees + // that we enter and exit in structured ways and the construct is nestable. + // 3. Put the new spv.selection/spv.loop op at the beginning of the old merge + // block and redirect all branches to the old header block to the old + // merge block (which contains the spv.selection/spv.loop op now). 
+ + /// For OpPhi instructions, we use block arguments to represent them. OpPhi + /// encodes a list of (value, predecessor) pairs. At the time of handling the + /// block containing an OpPhi instruction, the predecessor block might not be + /// processed yet, also the value sent by it. So we need to defer handling + /// the block argument from the predecessors. We use the following approach: + /// + /// 1. For each OpPhi instruction, add a block argument to the current block + /// in construction. Record the block argument in `valueMap` so its uses + /// can be resolved. For the list of (value, predecessor) pairs, update + /// `blockPhiInfo` for bookkeeping. + /// 2. After processing all blocks, loop over `blockPhiInfo` to fix up each + /// block recorded there to create the proper block arguments on their + /// terminators. + + /// A data structure for containing a SPIR-V block's phi info. It will be + /// represented as block argument in SPIR-V dialect. + using BlockPhiInfo = + SmallVector; // The result of the values sent + + /// Gets or creates the block corresponding to the given label . The newly + /// created block will always be placed at the end of the current function. + Block *getOrCreateBlock(uint32_t id); + + LogicalResult processBranch(ArrayRef operands); + + LogicalResult processBranchConditional(ArrayRef operands); + + /// Processes a SPIR-V OpLabel instruction with the given `operands`. + LogicalResult processLabel(ArrayRef operands); + + /// Processes a SPIR-V OpSelectionMerge instruction with the given `operands`. + LogicalResult processSelectionMerge(ArrayRef operands); + + /// Processes a SPIR-V OpLoopMerge instruction with the given `operands`. + LogicalResult processLoopMerge(ArrayRef operands); + + /// Processes a SPIR-V OpPhi instruction with the given `operands`. + LogicalResult processPhi(ArrayRef operands); + + /// Creates block arguments on predecessors previously recorded when handling + /// OpPhi instructions. 
+ LogicalResult wireUpBlockArgument(); + + /// Extracts blocks belonging to a structured selection/loop into a + /// spv.selection/spv.loop op. This method iterates until all blocks + /// declared as selection/loop headers are handled. + LogicalResult structurizeControlFlow(); + + //===--------------------------------------------------------------------===// + // Instruction + //===--------------------------------------------------------------------===// + + /// Get the Value associated with a result . + /// + /// This method materializes normal constants and inserts "casting" ops + /// (`spv.mlir.addressof` and `spv.mlir.referenceof`) to turn an symbol into a + /// SSA value for handling uses of module scope constants/variables in + /// functions. + Value getValue(uint32_t id); + + /// Slices the first instruction out of `binary` and returns its opcode and + /// operands via `opcode` and `operands` respectively. Returns failure if + /// there is no more remaining instructions (`expectedOpcode` will be used to + /// compose the error message) or the next instruction is malformed. + LogicalResult + sliceInstruction(spirv::Opcode &opcode, ArrayRef &operands, + Optional expectedOpcode = llvm::None); + + /// Processes a SPIR-V instruction with the given `opcode` and `operands`. + /// This method is the main entrance for handling SPIR-V instruction; it + /// checks the instruction opcode and dispatches to the corresponding handler. + /// Processing of Some instructions (like OpEntryPoint and OpExecutionMode) + /// might need to be deferred, since they contain forward references to s + /// in the deserialized binary, but module in SPIR-V dialect expects these to + /// be ssa-uses. + LogicalResult processInstruction(spirv::Opcode opcode, + ArrayRef operands, + bool deferInstructions = true); + + /// Processes a SPIR-V instruction from the given `operands`. It should + /// deserialize into an op with the given `opName` and `numOperands`. 
+ /// This method is a generic one for dispatching any SPIR-V ops without + /// variadic operands and attributes in TableGen definitions. + LogicalResult processOpWithoutGrammarAttr(ArrayRef words, + StringRef opName, bool hasResult, + unsigned numOperands); + + /// Processes a OpUndef instruction. Adds a spv.Undef operation at the current + /// insertion point. + LogicalResult processUndef(ArrayRef operands); + + /// Method to dispatch to the specialized deserialization function for an + /// operation in SPIR-V dialect that is a mirror of an instruction in the + /// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for + /// all operations in SPIR-V dialect that have hasOpcode == 1. + LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode, + ArrayRef words); + + /// Processes a SPIR-V OpExtInst with given `operands`. This slices the + /// entries of `operands` that specify the extended instruction set and + /// the instruction opcode. The op deserializer is then invoked using the + /// other entries. + LogicalResult processExtInst(ArrayRef operands); + + /// Dispatches the deserialization of extended instruction set operation based + /// on the extended instruction set name, and instruction opcode. This is + /// autogenerated from ODS. + LogicalResult + dispatchToExtensionSetAutogenDeserialization(StringRef extensionSetName, + uint32_t instructionID, + ArrayRef words); + + /// Method to deserialize an operation in the SPIR-V dialect that is a mirror + /// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode + /// == 1 and autogenSerialization == 1 in ODS. + template LogicalResult processOp(ArrayRef words) { + return emitError(unknownLoc, "unsupported deserialization for ") + << OpTy::getOperationName() << " op"; + } + +private: + /// The SPIR-V binary module. + ArrayRef binary; + + /// Contains the data of the OpLine instruction which precedes the current + /// processing instruction. 
+ llvm::Optional debugLine; + + /// The current word offset into the binary module. + unsigned curOffset = 0; + + /// MLIRContext to create SPIR-V ModuleOp into. + MLIRContext *context; + + // TODO: create Location subclass for binary blob + Location unknownLoc; + + /// The SPIR-V ModuleOp. + spirv::OwningSPIRVModuleRef module; + + /// The current function under construction. + Optional curFunction; + + /// The current block under construction. + Block *curBlock = nullptr; + + OpBuilder opBuilder; + + spirv::Version version; + + /// The list of capabilities used by the module. + llvm::SmallSetVector capabilities; + + /// The list of extensions used by the module. + llvm::SmallSetVector extensions; + + // Result to type mapping. + DenseMap typeMap; + + // Result to constant attribute and type mapping. + /// + /// In the SPIR-V binary format, all constants are placed in the module and + /// shared by instructions at module level and in subsequent functions. But in + /// the SPIR-V dialect, we materialize the constant to where it's used in the + /// function. So when seeing a constant instruction in the binary format, we + /// don't immediately emit a constant op into the module, we keep its value + /// (and type) here. Later when it's used, we materialize the constant. + DenseMap> constantMap; + + // Result to spec constant mapping. + DenseMap specConstMap; + + // Result to composite spec constant mapping. + DenseMap specConstCompositeMap; + + // Result to variable mapping. + DenseMap globalVariableMap; + + // Result to function mapping. + DenseMap funcMap; + + // Result to block mapping. + DenseMap blockMap; + + // Header block to its merge (and continue) target mapping. + BlockMergeInfoMap blockMergeInfo; + + // Block to its phi (block argument) mapping. + DenseMap blockPhiInfo; + + // Result to value mapping. + DenseMap valueMap; + + // Mapping from result to undef value of a type. + DenseMap undefMap; + + // Result to name mapping. 
+ DenseMap nameMap; + + // Result to debug info mapping. + DenseMap debugInfoMap; + + // Result to decorations mapping. + DenseMap decorations; + + // Result to type decorations. + DenseMap typeDecorations; + + // Result to member decorations. + // decorated-struct-type- -> + // (struct-member-index -> (decoration -> decoration-operands)) + DenseMap>>> + memberDecorationMap; + + // Result to member name. + // struct-type- -> (struct-member-index -> name) + DenseMap> memberNameMap; + + // Result to extended instruction set name. + DenseMap extendedInstSets; + + // List of instructions that are processed in a deferred fashion (after an + // initial processing of the entire binary). Some operations like + // OpEntryPoint, and OpExecutionMode use forward references to function + // s. In SPIR-V dialect the corresponding operations (spv.EntryPoint and + // spv.ExecutionMode) need these references resolved. So these instructions + // are deserialized and stored for processing once the entire binary is + // processed. + SmallVector>, 4> + deferredInstructions; + + /// A list of IDs for all types forward-declared through OpTypeForwardPointer + /// instructions. + llvm::SetVector typeForwardPointerIDs; + + /// A list of all structs which have unresolved member types. 
+ SmallVector deferredStructTypesInfos; +}; + +} // namespace spirv +} // namespace mlir + +#endif // MLIR_TARGET_SPIRV_DESERIALIZER_H diff --git a/mlir/lib/Target/SPIRV/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp rename from mlir/lib/Target/SPIRV/Deserialization.cpp rename to mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1,4 +1,4 @@ -//===- Deserializer.cpp - MLIR SPIR-V Deserialization ---------------------===// +//===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,11 +6,11 @@ // //===----------------------------------------------------------------------===// // -// This file defines the SPIR-V binary to MLIR SPIR-V module deserialization. +// This file defines the SPIR-V binary to MLIR SPIR-V module deserializer. // //===----------------------------------------------------------------------===// -#include "mlir/Target/SPIRV/Deserialization.h" +#include "Deserializer.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVModule.h" @@ -23,7 +23,6 @@ #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" -#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/bit.h" @@ -38,577 +37,22 @@ // Utility Functions //===----------------------------------------------------------------------===// -/// Decodes a string literal in `words` starting at `wordIndex`. Update the -/// latter to point to the position in words after the string literal. 
-static inline StringRef decodeStringLiteral(ArrayRef words, - unsigned &wordIndex) { - StringRef str(reinterpret_cast(words.data() + wordIndex)); - wordIndex += str.size() / 4 + 1; - return str; -} - -/// Extracts the opcode from the given first word of a SPIR-V instruction. -static inline spirv::Opcode extractOpcode(uint32_t word) { - return static_cast(word & 0xffff); -} - /// Returns true if the given `block` is a function entry block. static inline bool isFnEntryBlock(Block *block) { return block->isEntryBlock() && isa_and_nonnull(block->getParentOp()); } -namespace { -//===----------------------------------------------------------------------===// -// Utility Definitions -//===----------------------------------------------------------------------===// - -/// A struct for containing a header block's merge and continue targets. -/// -/// This struct is used to track original structured control flow info from -/// SPIR-V blob. This info will be used to create spv.selection/spv.loop -/// later. -struct BlockMergeInfo { - Block *mergeBlock; - Block *continueBlock; // nullptr for spv.selection - Location loc; - uint32_t control; - - BlockMergeInfo(Location location, uint32_t control) - : mergeBlock(nullptr), continueBlock(nullptr), loc(location), - control(control) {} - BlockMergeInfo(Location location, uint32_t control, Block *m, - Block *c = nullptr) - : mergeBlock(m), continueBlock(c), loc(location), control(control) {} -}; - -/// A struct for containing OpLine instruction information. -struct DebugLine { - uint32_t fileID; - uint32_t line; - uint32_t col; - - DebugLine(uint32_t fileIDNum, uint32_t lineNum, uint32_t colNum) - : fileID(fileIDNum), line(lineNum), col(colNum) {} -}; - -/// Map from a selection/loop's header block to its merge (and continue) target. -using BlockMergeInfoMap = DenseMap; - -/// A "deferred struct type" is a struct type with one or more member types not -/// known when the Deserializer first encounters the struct. 
This happens, for -/// example, with recursive structs where a pointer to the struct type is -/// forward declared through OpTypeForwardPointer in the SPIR-V module before -/// the struct declaration; the actual pointer to struct type should be defined -/// later through an OpTypePointer. For example, the following C struct: -/// -/// struct A { -/// A* next; -/// }; -/// -/// would be represented in the SPIR-V module as: -/// -/// OpName %A "A" -/// OpTypeForwardPointer %APtr Generic -/// %A = OpTypeStruct %APtr -/// %APtr = OpTypePointer Generic %A -/// -/// This means that the spirv::StructType cannot be fully constructed directly -/// when the Deserializer encounters it. Instead we create a -/// DeferredStructTypeInfo that contains all the information we know about the -/// spirv::StructType. Once all forward references for the struct are resolved, -/// the struct's body is set with all member info. -struct DeferredStructTypeInfo { - spirv::StructType deferredStructType; - - // A list of all unresolved member types for the struct. First element of each - // item is operand ID, second element is member index in the struct. - SmallVector, 0> unresolvedMemberTypes; - - // The list of member types. For unresolved members, this list contains - // place-holder empty types that will be updated later. - SmallVector memberTypes; - SmallVector offsetInfo; - SmallVector memberDecorationsInfo; -}; - -//===----------------------------------------------------------------------===// -// Deserializer Declaration -//===----------------------------------------------------------------------===// - -/// A SPIR-V module serializer. -/// -/// A SPIR-V binary module is a single linear stream of instructions; each -/// instruction is composed of 32-bit words. The first word of an instruction -/// records the total number of words of that instruction using the 16 -/// higher-order bits. 
So this deserializer uses that to get instruction -/// boundary and parse instructions and build a SPIR-V ModuleOp gradually. -/// -// TODO: clean up created ops on errors -class Deserializer { -public: - /// Creates a deserializer for the given SPIR-V `binary` module. - /// The SPIR-V ModuleOp will be created into `context. - explicit Deserializer(ArrayRef binary, MLIRContext *context); - - /// Deserializes the remembered SPIR-V binary module. - LogicalResult deserialize(); - - /// Collects the final SPIR-V ModuleOp. - spirv::OwningSPIRVModuleRef collect(); - -private: - //===--------------------------------------------------------------------===// - // Module structure - //===--------------------------------------------------------------------===// - - /// Initializes the `module` ModuleOp in this deserializer instance. - spirv::OwningSPIRVModuleRef createModuleOp(); - - /// Processes SPIR-V module header in `binary`. - LogicalResult processHeader(); - - /// Processes the SPIR-V OpCapability with `operands` and updates bookkeeping - /// in the deserializer. - LogicalResult processCapability(ArrayRef operands); - - /// Processes the SPIR-V OpExtension with `operands` and updates bookkeeping - /// in the deserializer. - LogicalResult processExtension(ArrayRef words); - - /// Processes the SPIR-V OpExtInstImport with `operands` and updates - /// bookkeeping in the deserializer. - LogicalResult processExtInstImport(ArrayRef words); - - /// Attaches (version, capabilities, extensions) triple to `module` as an - /// attribute. - void attachVCETriple(); - - /// Processes the SPIR-V OpMemoryModel with `operands` and updates `module`. - LogicalResult processMemoryModel(ArrayRef operands); - - /// Process SPIR-V OpName with `operands`. - LogicalResult processName(ArrayRef operands); - - /// Processes an OpDecorate instruction. - LogicalResult processDecoration(ArrayRef words); - - // Processes an OpMemberDecorate instruction. 
- LogicalResult processMemberDecoration(ArrayRef words); - - /// Processes an OpMemberName instruction. - LogicalResult processMemberName(ArrayRef words); - - /// Gets the function op associated with a result of OpFunction. - spirv::FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); } - - /// Processes the SPIR-V function at the current `offset` into `binary`. - /// The operands to the OpFunction instruction is passed in as ``operands`. - /// This method processes each instruction inside the function and dispatches - /// them to their handler method accordingly. - LogicalResult processFunction(ArrayRef operands); - - /// Processes OpFunctionEnd and finalizes function. This wires up block - /// argument created from OpPhi instructions and also structurizes control - /// flow. - LogicalResult processFunctionEnd(ArrayRef operands); - - /// Gets the constant's attribute and type associated with the given . - Optional> getConstant(uint32_t id); - - /// Gets the constant's integer attribute with the given . Returns a null - /// IntegerAttr if the given is not registered or does not correspond to an - /// integer constant. - IntegerAttr getConstantInt(uint32_t id); - - /// Returns a symbol to be used for the function name with the given - /// result . This tries to use the function's OpName if - /// exists; otherwise creates one based on the . - std::string getFunctionSymbol(uint32_t id); - - /// Returns a symbol to be used for the specialization constant with the given - /// result . This tries to use the specialization constant's OpName if - /// exists; otherwise creates one based on the . - std::string getSpecConstantSymbol(uint32_t id); - - /// Gets the specialization constant with the given result . - spirv::SpecConstantOp getSpecConstant(uint32_t id) { - return specConstMap.lookup(id); - } - - /// Gets the composite specialization constant with the given result . 
- spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id) { - return specConstCompositeMap.lookup(id); - } - - /// Creates a spirv::SpecConstantOp. - spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, - Attribute defaultValue); - - /// Processes the OpVariable instructions at current `offset` into `binary`. - /// It is expected that this method is used for variables that are to be - /// defined at module scope and will be deserialized into a spv.globalVariable - /// instruction. - LogicalResult processGlobalVariable(ArrayRef operands); - - /// Gets the global variable associated with a result of OpVariable. - spirv::GlobalVariableOp getGlobalVariable(uint32_t id) { - return globalVariableMap.lookup(id); - } - - //===--------------------------------------------------------------------===// - // Type - //===--------------------------------------------------------------------===// - - /// Gets type for a given result . - Type getType(uint32_t id) { return typeMap.lookup(id); } - - /// Get the type associated with the result of an OpUndef. - Type getUndefType(uint32_t id) { return undefMap.lookup(id); } - - /// Returns true if the given `type` is for SPIR-V void type. - bool isVoidType(Type type) const { return type.isa(); } - - /// Processes a SPIR-V type instruction with given `opcode` and `operands` and - /// registers the type into `module`. 
- LogicalResult processType(spirv::Opcode opcode, ArrayRef operands); - - LogicalResult processOpTypePointer(ArrayRef operands); - - LogicalResult processArrayType(ArrayRef operands); - - LogicalResult processCooperativeMatrixType(ArrayRef operands); - - LogicalResult processFunctionType(ArrayRef operands); - - LogicalResult processRuntimeArrayType(ArrayRef operands); - - LogicalResult processStructType(ArrayRef operands); - - LogicalResult processMatrixType(ArrayRef operands); - - //===--------------------------------------------------------------------===// - // Constant - //===--------------------------------------------------------------------===// - - /// Processes a SPIR-V Op{|Spec}Constant instruction with the given - /// `operands`. `isSpec` indicates whether this is a specialization constant. - LogicalResult processConstant(ArrayRef operands, bool isSpec); - - /// Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the - /// given `operands`. `isSpec` indicates whether this is a specialization - /// constant. - LogicalResult processConstantBool(bool isTrue, ArrayRef operands, - bool isSpec); - - /// Processes a SPIR-V OpConstantComposite instruction with the given - /// `operands`. - LogicalResult processConstantComposite(ArrayRef operands); - - LogicalResult processSpecConstantComposite(ArrayRef operands); - - /// Processes a SPIR-V OpConstantNull instruction with the given `operands`. - LogicalResult processConstantNull(ArrayRef operands); - - //===--------------------------------------------------------------------===// - // Debug - //===--------------------------------------------------------------------===// - - /// Discontinues any source-level location information that might be active - /// from a previous OpLine instruction. - LogicalResult clearDebugLine(); - - /// Creates a FileLineColLoc with the OpLine location information. 
- Location createFileLineColLoc(OpBuilder opBuilder); - - /// Processes a SPIR-V OpLine instruction with the given `operands`. - LogicalResult processDebugLine(ArrayRef operands); - - /// Processes a SPIR-V OpString instruction with the given `operands`. - LogicalResult processDebugString(ArrayRef operands); - - //===--------------------------------------------------------------------===// - // Control flow - //===--------------------------------------------------------------------===// - - /// Returns the block for the given label . - Block *getBlock(uint32_t id) const { return blockMap.lookup(id); } - - // In SPIR-V, structured control flow is explicitly declared using merge - // instructions (OpSelectionMerge and OpLoopMerge). In the SPIR-V dialect, - // we use spv.selection and spv.loop to group structured control flow. - // The deserializer need to turn structured control flow marked with merge - // instructions into using spv.selection/spv.loop ops. - // - // Because structured control flow can nest and the basic block order have - // flexibility, we cannot isolate a structured selection/loop without - // deserializing all the blocks. So we use the following approach: - // - // 1. Deserialize all basic blocks in a function and create MLIR blocks for - // them into the function's region. In the meanwhile, keep a map between - // selection/loop header blocks to their corresponding merge (and continue) - // target blocks. - // 2. For each selection/loop header block, recursively get all basic blocks - // reachable (except the merge block) and put them in a newly created - // spv.selection/spv.loop's region. Structured control flow guarantees - // that we enter and exit in structured ways and the construct is nestable. - // 3. Put the new spv.selection/spv.loop op at the beginning of the old merge - // block and redirect all branches to the old header block to the old - // merge block (which contains the spv.selection/spv.loop op now). 
- - /// For OpPhi instructions, we use block arguments to represent them. OpPhi - /// encodes a list of (value, predecessor) pairs. At the time of handling the - /// block containing an OpPhi instruction, the predecessor block might not be - /// processed yet, also the value sent by it. So we need to defer handling - /// the block argument from the predecessors. We use the following approach: - /// - /// 1. For each OpPhi instruction, add a block argument to the current block - /// in construction. Record the block argument in `valueMap` so its uses - /// can be resolved. For the list of (value, predecessor) pairs, update - /// `blockPhiInfo` for bookkeeping. - /// 2. After processing all blocks, loop over `blockPhiInfo` to fix up each - /// block recorded there to create the proper block arguments on their - /// terminators. - - /// A data structure for containing a SPIR-V block's phi info. It will be - /// represented as block argument in SPIR-V dialect. - using BlockPhiInfo = - SmallVector; // The result of the values sent - - /// Gets or creates the block corresponding to the given label . The newly - /// created block will always be placed at the end of the current function. - Block *getOrCreateBlock(uint32_t id); - - LogicalResult processBranch(ArrayRef operands); - - LogicalResult processBranchConditional(ArrayRef operands); - - /// Processes a SPIR-V OpLabel instruction with the given `operands`. - LogicalResult processLabel(ArrayRef operands); - - /// Processes a SPIR-V OpSelectionMerge instruction with the given `operands`. - LogicalResult processSelectionMerge(ArrayRef operands); - - /// Processes a SPIR-V OpLoopMerge instruction with the given `operands`. - LogicalResult processLoopMerge(ArrayRef operands); - - /// Processes a SPIR-V OpPhi instruction with the given `operands`. - LogicalResult processPhi(ArrayRef operands); - - /// Creates block arguments on predecessors previously recorded when handling - /// OpPhi instructions. 
- LogicalResult wireUpBlockArgument(); - - /// Extracts blocks belonging to a structured selection/loop into a - /// spv.selection/spv.loop op. This method iterates until all blocks - /// declared as selection/loop headers are handled. - LogicalResult structurizeControlFlow(); - - //===--------------------------------------------------------------------===// - // Instruction - //===--------------------------------------------------------------------===// - - /// Get the Value associated with a result . - /// - /// This method materializes normal constants and inserts "casting" ops - /// (`spv.mlir.addressof` and `spv.mlir.referenceof`) to turn an symbol into a - /// SSA value for handling uses of module scope constants/variables in - /// functions. - Value getValue(uint32_t id); - - /// Slices the first instruction out of `binary` and returns its opcode and - /// operands via `opcode` and `operands` respectively. Returns failure if - /// there is no more remaining instructions (`expectedOpcode` will be used to - /// compose the error message) or the next instruction is malformed. - LogicalResult - sliceInstruction(spirv::Opcode &opcode, ArrayRef &operands, - Optional expectedOpcode = llvm::None); - - /// Processes a SPIR-V instruction with the given `opcode` and `operands`. - /// This method is the main entrance for handling SPIR-V instruction; it - /// checks the instruction opcode and dispatches to the corresponding handler. - /// Processing of Some instructions (like OpEntryPoint and OpExecutionMode) - /// might need to be deferred, since they contain forward references to s - /// in the deserialized binary, but module in SPIR-V dialect expects these to - /// be ssa-uses. - LogicalResult processInstruction(spirv::Opcode opcode, - ArrayRef operands, - bool deferInstructions = true); - - /// Processes a SPIR-V instruction from the given `operands`. It should - /// deserialize into an op with the given `opName` and `numOperands`. 
- /// This method is a generic one for dispatching any SPIR-V ops without - /// variadic operands and attributes in TableGen definitions. - LogicalResult processOpWithoutGrammarAttr(ArrayRef words, - StringRef opName, bool hasResult, - unsigned numOperands); - - /// Processes a OpUndef instruction. Adds a spv.Undef operation at the current - /// insertion point. - LogicalResult processUndef(ArrayRef operands); - - LogicalResult processTypeForwardPointer(ArrayRef operands); - - /// Method to dispatch to the specialized deserialization function for an - /// operation in SPIR-V dialect that is a mirror of an instruction in the - /// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for - /// all operations in SPIR-V dialect that have hasOpcode == 1. - LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode, - ArrayRef words); - - /// Processes a SPIR-V OpExtInst with given `operands`. This slices the - /// entries of `operands` that specify the extended instruction set and - /// the instruction opcode. The op deserializer is then invoked using the - /// other entries. - LogicalResult processExtInst(ArrayRef operands); - - /// Dispatches the deserialization of extended instruction set operation based - /// on the extended instruction set name, and instruction opcode. This is - /// autogenerated from ODS. - LogicalResult - dispatchToExtensionSetAutogenDeserialization(StringRef extensionSetName, - uint32_t instructionID, - ArrayRef words); - - /// Method to deserialize an operation in the SPIR-V dialect that is a mirror - /// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode - /// == 1 and autogenSerialization == 1 in ODS. - template - LogicalResult processOp(ArrayRef words) { - return emitError(unknownLoc, "unsupported deserialization for ") - << OpTy::getOperationName() << " op"; - } - -private: - /// The SPIR-V binary module. 
- ArrayRef binary; - - /// Contains the data of the OpLine instruction which precedes the current - /// processing instruction. - llvm::Optional debugLine; - - /// The current word offset into the binary module. - unsigned curOffset = 0; - - /// MLIRContext to create SPIR-V ModuleOp into. - MLIRContext *context; - - // TODO: create Location subclass for binary blob - Location unknownLoc; - - /// The SPIR-V ModuleOp. - spirv::OwningSPIRVModuleRef module; - - /// The current function under construction. - Optional curFunction; - - /// The current block under construction. - Block *curBlock = nullptr; - - OpBuilder opBuilder; - - spirv::Version version; - - /// The list of capabilities used by the module. - llvm::SmallSetVector capabilities; - - /// The list of extensions used by the module. - llvm::SmallSetVector extensions; - - // Result to type mapping. - DenseMap typeMap; - - // Result to constant attribute and type mapping. - /// - /// In the SPIR-V binary format, all constants are placed in the module and - /// shared by instructions at module level and in subsequent functions. But in - /// the SPIR-V dialect, we materialize the constant to where it's used in the - /// function. So when seeing a constant instruction in the binary format, we - /// don't immediately emit a constant op into the module, we keep its value - /// (and type) here. Later when it's used, we materialize the constant. - DenseMap> constantMap; - - // Result to spec constant mapping. - DenseMap specConstMap; - - // Result to composite spec constant mapping. - DenseMap specConstCompositeMap; - - // Result to variable mapping. - DenseMap globalVariableMap; - - // Result to function mapping. - DenseMap funcMap; - - // Result to block mapping. - DenseMap blockMap; - - // Header block to its merge (and continue) target mapping. - BlockMergeInfoMap blockMergeInfo; - - // Block to its phi (block argument) mapping. - DenseMap blockPhiInfo; - - // Result to value mapping. 
- DenseMap valueMap; - - // Mapping from result to undef value of a type. - DenseMap undefMap; - - // Result to name mapping. - DenseMap nameMap; - - // Result to debug info mapping. - DenseMap debugInfoMap; - - // Result to decorations mapping. - DenseMap decorations; - - // Result to type decorations. - DenseMap typeDecorations; - - // Result to member decorations. - // decorated-struct-type- -> - // (struct-member-index -> (decoration -> decoration-operands)) - DenseMap>>> - memberDecorationMap; - - // Result to member name. - // struct-type- -> (struct-member-index -> name) - DenseMap> memberNameMap; - - // Result to extended instruction set name. - DenseMap extendedInstSets; - - // List of instructions that are processed in a deferred fashion (after an - // initial processing of the entire binary). Some operations like - // OpEntryPoint, and OpExecutionMode use forward references to function - // s. In SPIR-V dialect the corresponding operations (spv.EntryPoint and - // spv.ExecutionMode) need these references resolved. So these instructions - // are deserialized and stored for processing once the entire binary is - // processed. - SmallVector>, 4> - deferredInstructions; - - /// A list of IDs for all types forward-declared through OpTypeForwardPointer - /// instructions. - llvm::SetVector typeForwardPointerIDs; - - /// A list of all structs which have unresolved member types. 
- SmallVector deferredStructTypesInfos; -}; -} // namespace - //===----------------------------------------------------------------------===// // Deserializer Method Definitions //===----------------------------------------------------------------------===// -Deserializer::Deserializer(ArrayRef binary, MLIRContext *context) +spirv::Deserializer::Deserializer(ArrayRef binary, + MLIRContext *context) : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)), module(createModuleOp()), opBuilder(module->body()) {} -LogicalResult Deserializer::deserialize() { +LogicalResult spirv::Deserializer::deserialize() { LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n"); if (failed(processHeader())) @@ -642,7 +86,7 @@ return success(); } -spirv::OwningSPIRVModuleRef Deserializer::collect() { +spirv::OwningSPIRVModuleRef spirv::Deserializer::collect() { return std::move(module); } @@ -650,14 +94,14 @@ // Module structure //===----------------------------------------------------------------------===// -spirv::OwningSPIRVModuleRef Deserializer::createModuleOp() { +spirv::OwningSPIRVModuleRef spirv::Deserializer::createModuleOp() { OpBuilder builder(context); OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); spirv::ModuleOp::build(builder, state); return cast(Operation::create(state)); } -LogicalResult Deserializer::processHeader() { +LogicalResult spirv::Deserializer::processHeader() { if (binary.size() < spirv::kHeaderWordCount) return emitError(unknownLoc, "SPIR-V binary module must have a 5-word header"); @@ -696,7 +140,8 @@ return success(); } -LogicalResult Deserializer::processCapability(ArrayRef operands) { +LogicalResult +spirv::Deserializer::processCapability(ArrayRef operands) { if (operands.size() != 1) return emitError(unknownLoc, "OpMemoryModel must have one parameter"); @@ -708,7 +153,7 @@ return success(); } -LogicalResult Deserializer::processExtension(ArrayRef words) { +LogicalResult 
spirv::Deserializer::processExtension(ArrayRef words) { if (words.empty()) { return emitError( unknownLoc, @@ -728,7 +173,8 @@ return success(); } -LogicalResult Deserializer::processExtInstImport(ArrayRef words) { +LogicalResult +spirv::Deserializer::processExtInstImport(ArrayRef words) { if (words.size() < 2) { return emitError(unknownLoc, "OpExtInstImport must have a result and a literal " @@ -744,14 +190,15 @@ return success(); } -void Deserializer::attachVCETriple() { +void spirv::Deserializer::attachVCETriple() { (*module)->setAttr( spirv::ModuleOp::getVCETripleAttrName(), spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(), extensions.getArrayRef(), context)); } -LogicalResult Deserializer::processMemoryModel(ArrayRef operands) { +LogicalResult +spirv::Deserializer::processMemoryModel(ArrayRef operands) { if (operands.size() != 2) return emitError(unknownLoc, "OpMemoryModel must have two operands"); @@ -765,7 +212,7 @@ return success(); } -LogicalResult Deserializer::processDecoration(ArrayRef words) { +LogicalResult spirv::Deserializer::processDecoration(ArrayRef words) { // TODO: This function should also be auto-generated. For now, since only a // few decorations are processed/handled in a meaningful manner, going with a // manual implementation. 
@@ -839,7 +286,8 @@ return success(); } -LogicalResult Deserializer::processMemberDecoration(ArrayRef words) { +LogicalResult +spirv::Deserializer::processMemberDecoration(ArrayRef words) { // The binary layout of OpMemberDecorate is different comparing to OpDecorate if (words.size() < 3) { return emitError(unknownLoc, @@ -860,7 +308,7 @@ return success(); } -LogicalResult Deserializer::processMemberName(ArrayRef words) { +LogicalResult spirv::Deserializer::processMemberName(ArrayRef words) { if (words.size() < 3) { return emitError(unknownLoc, "OpMemberName must have at least 3 operands"); } @@ -874,7 +322,8 @@ return success(); } -LogicalResult Deserializer::processFunction(ArrayRef operands) { +LogicalResult +spirv::Deserializer::processFunction(ArrayRef operands) { if (curFunction) { return emitError(unknownLoc, "found function inside function"); } @@ -1011,7 +460,8 @@ return processFunctionEnd(instOperands); } -LogicalResult Deserializer::processFunctionEnd(ArrayRef operands) { +LogicalResult +spirv::Deserializer::processFunctionEnd(ArrayRef operands) { // Process OpFunctionEnd. 
if (!operands.empty()) { return emitError(unknownLoc, "unexpected operands for OpFunctionEnd"); @@ -1029,14 +479,15 @@ return success(); } -Optional> Deserializer::getConstant(uint32_t id) { +Optional> +spirv::Deserializer::getConstant(uint32_t id) { auto constIt = constantMap.find(id); if (constIt == constantMap.end()) return llvm::None; return constIt->getSecond(); } -std::string Deserializer::getFunctionSymbol(uint32_t id) { +std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) { auto funcName = nameMap.lookup(id).str(); if (funcName.empty()) { funcName = "spirv_fn_" + std::to_string(id); @@ -1044,7 +495,7 @@ return funcName; } -std::string Deserializer::getSpecConstantSymbol(uint32_t id) { +std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) { auto constName = nameMap.lookup(id).str(); if (constName.empty()) { constName = "spirv_spec_const_" + std::to_string(id); @@ -1052,9 +503,9 @@ return constName; } -spirv::SpecConstantOp Deserializer::createSpecConstant(Location loc, - uint32_t resultID, - Attribute defaultValue) { +spirv::SpecConstantOp +spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID, + Attribute defaultValue) { auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); auto op = opBuilder.create(unknownLoc, symName, defaultValue); @@ -1066,7 +517,8 @@ return op; } -LogicalResult Deserializer::processGlobalVariable(ArrayRef operands) { +LogicalResult +spirv::Deserializer::processGlobalVariable(ArrayRef operands) { unsigned wordIndex = 0; if (operands.size() < 3) { return emitError( @@ -1137,7 +589,7 @@ return success(); } -IntegerAttr Deserializer::getConstantInt(uint32_t id) { +IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) { auto constInfo = getConstant(id); if (!constInfo) { return nullptr; @@ -1145,7 +597,7 @@ return constInfo->first.dyn_cast(); } -LogicalResult Deserializer::processName(ArrayRef operands) { +LogicalResult spirv::Deserializer::processName(ArrayRef 
operands) { if (operands.size() < 2) { return emitError(unknownLoc, "OpName needs at least 2 operands"); } @@ -1167,8 +619,8 @@ // Type //===----------------------------------------------------------------------===// -LogicalResult Deserializer::processType(spirv::Opcode opcode, - ArrayRef operands) { +LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, + ArrayRef operands) { if (operands.empty()) { return emitError(unknownLoc, "type instruction with opcode ") << spirv::stringifyOpcode(opcode) << " needs at least one "; @@ -1263,7 +715,8 @@ return success(); } -LogicalResult Deserializer::processOpTypePointer(ArrayRef operands) { +LogicalResult +spirv::Deserializer::processOpTypePointer(ArrayRef operands) { if (operands.size() != 3) return emitError(unknownLoc, "OpTypePointer must have two parameters"); @@ -1316,7 +769,8 @@ return success(); } -LogicalResult Deserializer::processArrayType(ArrayRef operands) { +LogicalResult +spirv::Deserializer::processArrayType(ArrayRef operands) { if (operands.size() != 3) { return emitError(unknownLoc, "OpTypeArray must have element type and count parameters"); @@ -1348,7 +802,8 @@ return success(); } -LogicalResult Deserializer::processFunctionType(ArrayRef operands) { +LogicalResult +spirv::Deserializer::processFunctionType(ArrayRef operands) { assert(!operands.empty() && "No operands for processing function type"); if (operands.size() == 1) { return emitError(unknownLoc, "missing return type for OpTypeFunction"); @@ -1374,7 +829,7 @@ } LogicalResult -Deserializer::processCooperativeMatrixType(ArrayRef operands) { +spirv::Deserializer::processCooperativeMatrixType(ArrayRef operands) { if (operands.size() != 5) { return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element " "type and row x column parameters"); @@ -1403,7 +858,7 @@ } LogicalResult -Deserializer::processRuntimeArrayType(ArrayRef operands) { +spirv::Deserializer::processRuntimeArrayType(ArrayRef operands) { if (operands.size() != 2) { 
return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands"); } @@ -1418,7 +873,8 @@ return success(); } -LogicalResult Deserializer::processStructType(ArrayRef operands) { +LogicalResult +spirv::Deserializer::processStructType(ArrayRef operands) { // TODO: Find a way to handle identified structs when debug info is stripped. if (operands.empty()) { @@ -1505,7 +961,8 @@ return success(); } -LogicalResult Deserializer::processMatrixType(ArrayRef operands) { +LogicalResult +spirv::Deserializer::processMatrixType(ArrayRef operands) { if (operands.size() != 3) { // Three operands are needed: result_id, column_type, and column_count return emitError(unknownLoc, "OpTypeMatrix must have 3 operands" @@ -1524,12 +981,25 @@ return success(); } +LogicalResult +spirv::Deserializer::processTypeForwardPointer(ArrayRef operands) { + if (operands.size() != 2) + return emitError(unknownLoc, + "OpTypeForwardPointer instruction must have two operands"); + + typeForwardPointerIDs.insert(operands[0]); + // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer + // instruction that defines the actual type. + + return success(); +} + //===----------------------------------------------------------------------===// // Constant //===----------------------------------------------------------------------===// -LogicalResult Deserializer::processConstant(ArrayRef operands, - bool isSpec) { +LogicalResult spirv::Deserializer::processConstant(ArrayRef operands, + bool isSpec) { StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant"; if (operands.size() < 2) { @@ -1642,9 +1112,8 @@ "scalar integer or floating-point type"); } -LogicalResult Deserializer::processConstantBool(bool isTrue, - ArrayRef operands, - bool isSpec) { +LogicalResult spirv::Deserializer::processConstantBool( + bool isTrue, ArrayRef operands, bool isSpec) { if (operands.size() != 2) { return emitError(unknownLoc, "Op") << (isSpec ? 
"Spec" : "") << "Constant" @@ -1666,7 +1135,7 @@ } LogicalResult -Deserializer::processConstantComposite(ArrayRef operands) { +spirv::Deserializer::processConstantComposite(ArrayRef operands) { if (operands.size() < 2) { return emitError(unknownLoc, "OpConstantComposite must have type and result "); @@ -1711,7 +1180,7 @@ } LogicalResult -Deserializer::processSpecConstantComposite(ArrayRef operands) { +spirv::Deserializer::processSpecConstantComposite(ArrayRef operands) { if (operands.size() < 2) { return emitError(unknownLoc, "OpConstantComposite must have type and result "); @@ -1745,7 +1214,8 @@ return success(); } -LogicalResult Deserializer::processConstantNull(ArrayRef operands) { +LogicalResult +spirv::Deserializer::processConstantNull(ArrayRef operands) { if (operands.size() != 2) { return emitError(unknownLoc, "OpConstantNull must have type and result "); @@ -1774,7 +1244,7 @@ // Control flow //===----------------------------------------------------------------------===// -Block *Deserializer::getOrCreateBlock(uint32_t id) { +Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) { if (auto *block = getBlock(id)) { LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id = " << id << " @ " << block << "\n"); @@ -1790,7 +1260,7 @@ return blockMap[id] = block; } -LogicalResult Deserializer::processBranch(ArrayRef operands) { +LogicalResult spirv::Deserializer::processBranch(ArrayRef operands) { if (!curBlock) { return emitError(unknownLoc, "OpBranch must appear inside a block"); } @@ -1811,7 +1281,7 @@ } LogicalResult -Deserializer::processBranchConditional(ArrayRef operands) { +spirv::Deserializer::processBranchConditional(ArrayRef operands) { if (!curBlock) { return emitError(unknownLoc, "OpBranchConditional must appear inside a block"); @@ -1844,7 +1314,7 @@ return success(); } -LogicalResult Deserializer::processLabel(ArrayRef operands) { +LogicalResult spirv::Deserializer::processLabel(ArrayRef operands) { if (!curFunction) { return 
emitError(unknownLoc, "OpLabel must appear inside a function"); } @@ -1866,7 +1336,8 @@ return success(); } -LogicalResult Deserializer::processSelectionMerge(ArrayRef operands) { +LogicalResult +spirv::Deserializer::processSelectionMerge(ArrayRef operands) { if (!curBlock) { return emitError(unknownLoc, "OpSelectionMerge must appear in a block"); } @@ -1891,7 +1362,8 @@ return success(); } -LogicalResult Deserializer::processLoopMerge(ArrayRef operands) { +LogicalResult +spirv::Deserializer::processLoopMerge(ArrayRef operands) { if (!curBlock) { return emitError(unknownLoc, "OpLoopMerge must appear in a block"); } @@ -1917,7 +1389,7 @@ return success(); } -LogicalResult Deserializer::processPhi(ArrayRef operands) { +LogicalResult spirv::Deserializer::processPhi(ArrayRef operands) { if (!curBlock) { return emitError(unknownLoc, "OpPhi must appear in a block"); } @@ -1961,7 +1433,7 @@ /// This method will also update `mergeInfo` by remapping all blocks inside to /// the newly cloned ones inside structured control flow op's regions. 
static LogicalResult structurize(Location loc, uint32_t control, - BlockMergeInfoMap &mergeInfo, + spirv::BlockMergeInfoMap &mergeInfo, Block *headerBlock, Block *mergeBlock, Block *continueBlock) { return ControlFlowStructurizer(loc, control, mergeInfo, headerBlock, @@ -1971,7 +1443,7 @@ private: ControlFlowStructurizer(Location loc, uint32_t control, - BlockMergeInfoMap &mergeInfo, Block *header, + spirv::BlockMergeInfoMap &mergeInfo, Block *header, Block *merge, Block *cont) : location(loc), control(control), blockMergeInfo(mergeInfo), headerBlock(header), mergeBlock(merge), continueBlock(cont) {} @@ -1990,7 +1462,7 @@ Location location; uint32_t control; - BlockMergeInfoMap &blockMergeInfo; + spirv::BlockMergeInfoMap &blockMergeInfo; Block *headerBlock; Block *mergeBlock; @@ -2214,7 +1686,7 @@ return success(); } -LogicalResult Deserializer::wireUpBlockArgument() { +LogicalResult spirv::Deserializer::wireUpBlockArgument() { LLVM_DEBUG(llvm::dbgs() << "[phi] start wiring up block arguments\n"); OpBuilder::InsertionGuard guard(opBuilder); @@ -2263,7 +1735,7 @@ return success(); } -LogicalResult Deserializer::structurizeControlFlow() { +LogicalResult spirv::Deserializer::structurizeControlFlow() { LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n"); while (!blockMergeInfo.empty()) { @@ -2303,7 +1775,7 @@ // Debug //===----------------------------------------------------------------------===// -Location Deserializer::createFileLineColLoc(OpBuilder opBuilder) { +Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) { if (!debugLine) return unknownLoc; @@ -2314,7 +1786,8 @@ debugLine->line, debugLine->col); } -LogicalResult Deserializer::processDebugLine(ArrayRef operands) { +LogicalResult +spirv::Deserializer::processDebugLine(ArrayRef operands) { // According to SPIR-V spec: // "This location information applies to the instructions physically // following this instruction, up to the first occurrence of any of the @@ -2326,12 
+1799,13 @@ return success(); } -LogicalResult Deserializer::clearDebugLine() { +LogicalResult spirv::Deserializer::clearDebugLine() { debugLine = llvm::None; return success(); } -LogicalResult Deserializer::processDebugString(ArrayRef operands) { +LogicalResult +spirv::Deserializer::processDebugString(ArrayRef operands) { if (operands.size() < 2) return emitError(unknownLoc, "OpString needs at least 2 operands"); @@ -2349,552 +1823,3 @@ debugInfoMap[operands[0]] = debugString; return success(); } - -//===----------------------------------------------------------------------===// -// Instruction -//===----------------------------------------------------------------------===// - -Value Deserializer::getValue(uint32_t id) { - if (auto constInfo = getConstant(id)) { - // Materialize a `spv.constant` op at every use site. - return opBuilder.create(unknownLoc, constInfo->second, - constInfo->first); - } - if (auto varOp = getGlobalVariable(id)) { - auto addressOfOp = opBuilder.create( - unknownLoc, varOp.type(), - opBuilder.getSymbolRefAttr(varOp.getOperation())); - return addressOfOp.pointer(); - } - if (auto constOp = getSpecConstant(id)) { - auto referenceOfOp = opBuilder.create( - unknownLoc, constOp.default_value().getType(), - opBuilder.getSymbolRefAttr(constOp.getOperation())); - return referenceOfOp.reference(); - } - if (auto constCompositeOp = getSpecConstantComposite(id)) { - auto referenceOfOp = opBuilder.create( - unknownLoc, constCompositeOp.type(), - opBuilder.getSymbolRefAttr(constCompositeOp.getOperation())); - return referenceOfOp.reference(); - } - if (auto undef = getUndefType(id)) { - return opBuilder.create(unknownLoc, undef); - } - return valueMap.lookup(id); -} - -LogicalResult -Deserializer::sliceInstruction(spirv::Opcode &opcode, - ArrayRef &operands, - Optional expectedOpcode) { - auto binarySize = binary.size(); - if (curOffset >= binarySize) { - return emitError(unknownLoc, "expected ") - << (expectedOpcode ? 
spirv::stringifyOpcode(*expectedOpcode) - : "more") - << " instruction"; - } - - // For each instruction, get its word count from the first word to slice it - // from the stream properly, and then dispatch to the instruction handler. - - uint32_t wordCount = binary[curOffset] >> 16; - - if (wordCount == 0) - return emitError(unknownLoc, "word count cannot be zero"); - - uint32_t nextOffset = curOffset + wordCount; - if (nextOffset > binarySize) - return emitError(unknownLoc, "insufficient words for the last instruction"); - - opcode = extractOpcode(binary[curOffset]); - operands = binary.slice(curOffset + 1, wordCount - 1); - curOffset = nextOffset; - return success(); -} - -LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, - ArrayRef operands, - bool deferInstructions) { - LLVM_DEBUG(llvm::dbgs() << "[inst] processing instruction " - << spirv::stringifyOpcode(opcode) << "\n"); - - // First dispatch all the instructions whose opcode does not correspond to - // those that have a direct mirror in the SPIR-V dialect - switch (opcode) { - case spirv::Opcode::OpCapability: - return processCapability(operands); - case spirv::Opcode::OpExtension: - return processExtension(operands); - case spirv::Opcode::OpExtInst: - return processExtInst(operands); - case spirv::Opcode::OpExtInstImport: - return processExtInstImport(operands); - case spirv::Opcode::OpMemberName: - return processMemberName(operands); - case spirv::Opcode::OpMemoryModel: - return processMemoryModel(operands); - case spirv::Opcode::OpEntryPoint: - case spirv::Opcode::OpExecutionMode: - if (deferInstructions) { - deferredInstructions.emplace_back(opcode, operands); - return success(); - } - break; - case spirv::Opcode::OpVariable: - if (isa(opBuilder.getBlock()->getParentOp())) { - return processGlobalVariable(operands); - } - break; - case spirv::Opcode::OpLine: - return processDebugLine(operands); - case spirv::Opcode::OpNoLine: - return clearDebugLine(); - case spirv::Opcode::OpName: - 
return processName(operands); - case spirv::Opcode::OpString: - return processDebugString(operands); - case spirv::Opcode::OpModuleProcessed: - case spirv::Opcode::OpSource: - case spirv::Opcode::OpSourceContinued: - case spirv::Opcode::OpSourceExtension: - // TODO: This is debug information embedded in the binary which should be - // translated into the spv.module. - return success(); - case spirv::Opcode::OpTypeVoid: - case spirv::Opcode::OpTypeBool: - case spirv::Opcode::OpTypeInt: - case spirv::Opcode::OpTypeFloat: - case spirv::Opcode::OpTypeVector: - case spirv::Opcode::OpTypeMatrix: - case spirv::Opcode::OpTypeArray: - case spirv::Opcode::OpTypeFunction: - case spirv::Opcode::OpTypeRuntimeArray: - case spirv::Opcode::OpTypeStruct: - case spirv::Opcode::OpTypePointer: - case spirv::Opcode::OpTypeCooperativeMatrixNV: - return processType(opcode, operands); - case spirv::Opcode::OpConstant: - return processConstant(operands, /*isSpec=*/false); - case spirv::Opcode::OpSpecConstant: - return processConstant(operands, /*isSpec=*/true); - case spirv::Opcode::OpConstantComposite: - return processConstantComposite(operands); - case spirv::Opcode::OpSpecConstantComposite: - return processSpecConstantComposite(operands); - case spirv::Opcode::OpConstantTrue: - return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false); - case spirv::Opcode::OpSpecConstantTrue: - return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true); - case spirv::Opcode::OpConstantFalse: - return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false); - case spirv::Opcode::OpSpecConstantFalse: - return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true); - case spirv::Opcode::OpConstantNull: - return processConstantNull(operands); - case spirv::Opcode::OpDecorate: - return processDecoration(operands); - case spirv::Opcode::OpMemberDecorate: - return processMemberDecoration(operands); - case spirv::Opcode::OpFunction: - return 
processFunction(operands); - case spirv::Opcode::OpLabel: - return processLabel(operands); - case spirv::Opcode::OpBranch: - return processBranch(operands); - case spirv::Opcode::OpBranchConditional: - return processBranchConditional(operands); - case spirv::Opcode::OpSelectionMerge: - return processSelectionMerge(operands); - case spirv::Opcode::OpLoopMerge: - return processLoopMerge(operands); - case spirv::Opcode::OpPhi: - return processPhi(operands); - case spirv::Opcode::OpUndef: - return processUndef(operands); - case spirv::Opcode::OpTypeForwardPointer: - return processTypeForwardPointer(operands); - default: - break; - } - return dispatchToAutogenDeserialization(opcode, operands); -} - -LogicalResult -Deserializer::processOpWithoutGrammarAttr(ArrayRef words, - StringRef opName, bool hasResult, - unsigned numOperands) { - SmallVector resultTypes; - uint32_t valueID = 0; - - size_t wordIndex= 0; - if (hasResult) { - if (wordIndex >= words.size()) - return emitError(unknownLoc, - "expected result type while deserializing for ") - << opName; - - // Decode the type - auto type = getType(words[wordIndex]); - if (!type) - return emitError(unknownLoc, "unknown type result : ") - << words[wordIndex]; - resultTypes.push_back(type); - ++wordIndex; - - // Decode the result - if (wordIndex >= words.size()) - return emitError(unknownLoc, - "expected result while deserializing for ") - << opName; - valueID = words[wordIndex]; - ++wordIndex; - } - - SmallVector operands; - SmallVector attributes; - - // Decode operands - size_t operandIndex = 0; - for (; operandIndex < numOperands && wordIndex < words.size(); - ++operandIndex, ++wordIndex) { - auto arg = getValue(words[wordIndex]); - if (!arg) - return emitError(unknownLoc, "unknown result : ") << words[wordIndex]; - operands.push_back(arg); - } - if (operandIndex != numOperands) { - return emitError( - unknownLoc, - "found less operands than expected when deserializing for ") - << opName << "; only " << operandIndex << " 
of " << numOperands - << " processed"; - } - if (wordIndex != words.size()) { - return emitError( - unknownLoc, - "found more operands than expected when deserializing for ") - << opName << "; only " << wordIndex << " of " << words.size() - << " processed"; - } - - // Attach attributes from decorations - if (decorations.count(valueID)) { - auto attrs = decorations[valueID].getAttrs(); - attributes.append(attrs.begin(), attrs.end()); - } - - // Create the op and update bookkeeping maps - Location loc = createFileLineColLoc(opBuilder); - OperationState opState(loc, opName); - opState.addOperands(operands); - if (hasResult) - opState.addTypes(resultTypes); - opState.addAttributes(attributes); - Operation *op = opBuilder.createOperation(opState); - if (hasResult) - valueMap[valueID] = op->getResult(0); - - if (op->hasTrait()) - clearDebugLine(); - - return success(); -} - -LogicalResult Deserializer::processUndef(ArrayRef operands) { - if (operands.size() != 2) { - return emitError(unknownLoc, "OpUndef instruction must have two operands"); - } - auto type = getType(operands[0]); - if (!type) { - return emitError(unknownLoc, "unknown type with OpUndef instruction"); - } - undefMap[operands[1]] = type; - return success(); -} - -LogicalResult -Deserializer::processTypeForwardPointer(ArrayRef operands) { - if (operands.size() != 2) - return emitError(unknownLoc, - "OpTypeForwardPointer instruction must have two operands"); - - typeForwardPointerIDs.insert(operands[0]); - // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer - // instruction that defines the actual type. 
- - return success(); -} - -LogicalResult Deserializer::processExtInst(ArrayRef operands) { - if (operands.size() < 4) { - return emitError(unknownLoc, - "OpExtInst must have at least 4 operands, result type " - ", result , set and instruction opcode"); - } - if (!extendedInstSets.count(operands[2])) { - return emitError(unknownLoc, "undefined set in OpExtInst"); - } - SmallVector slicedOperands; - slicedOperands.append(operands.begin(), std::next(operands.begin(), 2)); - slicedOperands.append(std::next(operands.begin(), 4), operands.end()); - return dispatchToExtensionSetAutogenDeserialization( - extendedInstSets[operands[2]], operands[3], slicedOperands); -} - -namespace { - -template <> -LogicalResult -Deserializer::processOp(ArrayRef words) { - unsigned wordIndex = 0; - if (wordIndex >= words.size()) { - return emitError(unknownLoc, - "missing Execution Model specification in OpEntryPoint"); - } - auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]); - if (wordIndex >= words.size()) { - return emitError(unknownLoc, "missing in OpEntryPoint"); - } - // Get the function - auto fnID = words[wordIndex++]; - // Get the function name - auto fnName = decodeStringLiteral(words, wordIndex); - // Verify that the function matches the fnName - auto parsedFunc = getFunction(fnID); - if (!parsedFunc) { - return emitError(unknownLoc, "no function matching ") << fnID; - } - if (parsedFunc.getName() != fnName) { - return emitError(unknownLoc, "function name mismatch between OpEntryPoint " - "and OpFunction with ") - << fnID << ": " << fnName << " vs. 
" << parsedFunc.getName(); - } - SmallVector interface; - while (wordIndex < words.size()) { - auto arg = getGlobalVariable(words[wordIndex]); - if (!arg) { - return emitError(unknownLoc, "undefined result ") - << words[wordIndex] << " while decoding OpEntryPoint"; - } - interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation())); - wordIndex++; - } - opBuilder.create(unknownLoc, execModel, - opBuilder.getSymbolRefAttr(fnName), - opBuilder.getArrayAttr(interface)); - return success(); -} - -template <> -LogicalResult -Deserializer::processOp(ArrayRef words) { - unsigned wordIndex = 0; - if (wordIndex >= words.size()) { - return emitError(unknownLoc, - "missing function result in OpExecutionMode"); - } - // Get the function to get the name of the function - auto fnID = words[wordIndex++]; - auto fn = getFunction(fnID); - if (!fn) { - return emitError(unknownLoc, "no function matching ") << fnID; - } - // Get the Execution mode - if (wordIndex >= words.size()) { - return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode"); - } - auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]); - - // Get the values - SmallVector attrListElems; - while (wordIndex < words.size()) { - attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++])); - } - auto values = opBuilder.getArrayAttr(attrListElems); - opBuilder.create( - unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values); - return success(); -} - -template <> -LogicalResult -Deserializer::processOp(ArrayRef operands) { - if (operands.size() != 3) { - return emitError( - unknownLoc, - "OpControlBarrier must have execution scope , memory scope " - "and memory semantics "); - } - - SmallVector argAttrs; - for (auto operand : operands) { - auto argAttr = getConstantInt(operand); - if (!argAttr) { - return emitError(unknownLoc, - "expected 32-bit integer constant from ") - << operand << " for OpControlBarrier"; - } - argAttrs.push_back(argAttr); - } - - 
opBuilder.create(unknownLoc, argAttrs[0], - argAttrs[1], argAttrs[2]); - return success(); -} - -template <> -LogicalResult -Deserializer::processOp(ArrayRef operands) { - if (operands.size() < 3) { - return emitError(unknownLoc, - "OpFunctionCall must have at least 3 operands"); - } - - Type resultType = getType(operands[0]); - if (!resultType) { - return emitError(unknownLoc, "undefined result type from ") - << operands[0]; - } - - // Use null type to mean no result type. - if (isVoidType(resultType)) - resultType = nullptr; - - auto resultID = operands[1]; - auto functionID = operands[2]; - - auto functionName = getFunctionSymbol(functionID); - - SmallVector arguments; - for (auto operand : llvm::drop_begin(operands, 3)) { - auto value = getValue(operand); - if (!value) { - return emitError(unknownLoc, "unknown ") - << operand << " used by OpFunctionCall"; - } - arguments.push_back(value); - } - - auto opFunctionCall = opBuilder.create( - unknownLoc, resultType, opBuilder.getSymbolRefAttr(functionName), - arguments); - - if (resultType) - valueMap[resultID] = opFunctionCall.getResult(0); - return success(); -} - -template <> -LogicalResult -Deserializer::processOp(ArrayRef operands) { - if (operands.size() != 2) { - return emitError(unknownLoc, "OpMemoryBarrier must have memory scope " - "and memory semantics "); - } - - SmallVector argAttrs; - for (auto operand : operands) { - auto argAttr = getConstantInt(operand); - if (!argAttr) { - return emitError(unknownLoc, - "expected 32-bit integer constant from ") - << operand << " for OpMemoryBarrier"; - } - argAttrs.push_back(argAttr); - } - - opBuilder.create(unknownLoc, argAttrs[0], - argAttrs[1]); - return success(); -} - -template <> -LogicalResult -Deserializer::processOp(ArrayRef words) { - SmallVector resultTypes; - size_t wordIndex = 0; - SmallVector operands; - SmallVector attributes; - - if (wordIndex < words.size()) { - auto arg = getValue(words[wordIndex]); - - if (!arg) { - return emitError(unknownLoc, 
"unknown result : ") - << words[wordIndex]; - } - - operands.push_back(arg); - wordIndex++; - } - - if (wordIndex < words.size()) { - auto arg = getValue(words[wordIndex]); - - if (!arg) { - return emitError(unknownLoc, "unknown result : ") - << words[wordIndex]; - } - - operands.push_back(arg); - wordIndex++; - } - - bool isAlignedAttr = false; - - if (wordIndex < words.size()) { - auto attrValue = words[wordIndex++]; - attributes.push_back(opBuilder.getNamedAttr( - "memory_access", opBuilder.getI32IntegerAttr(attrValue))); - isAlignedAttr = (attrValue == 2); - } - - if (isAlignedAttr && wordIndex < words.size()) { - attributes.push_back(opBuilder.getNamedAttr( - "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); - } - - if (wordIndex < words.size()) { - attributes.push_back(opBuilder.getNamedAttr( - "source_memory_access", - opBuilder.getI32IntegerAttr(words[wordIndex++]))); - } - - if (wordIndex < words.size()) { - attributes.push_back(opBuilder.getNamedAttr( - "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); - } - - if (wordIndex != words.size()) { - return emitError(unknownLoc, - "found more operands than expected when deserializing " - "spirv::CopyMemoryOp, only ") - << wordIndex << " of " << words.size() << " processed"; - } - - Location loc = createFileLineColLoc(opBuilder); - opBuilder.create(loc, resultTypes, operands, attributes); - - return success(); -} - -// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and -// various Deserializer::processOp<...>() specializations. 
-#define GET_DESERIALIZATION_FNS -#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" - -} // namespace - -namespace mlir { -spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef binary, - MLIRContext *context) { - Deserializer deserializer(binary, context); - - if (failed(deserializer.deserialize())) - return nullptr; - - return deserializer.collect(); -} -} // namespace mlir diff --git a/mlir/lib/Target/SPIRV/Serialization/CMakeLists.txt b/mlir/lib/Target/SPIRV/Serialization/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/SPIRV/Serialization/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_translation_library(MLIRSPIRVSerialization + Serialization.cpp + + DEPENDS + MLIRSPIRVSerializationGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRSPIRV + MLIRSPIRVBinaryUtils + MLIRSupport + MLIRTranslation + ) + + diff --git a/mlir/lib/Target/SPIRV/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp rename from mlir/lib/Target/SPIRV/Serialization.cpp rename to mlir/lib/Target/SPIRV/Serialization/Serialization.cpp diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -996,11 +996,10 @@ /// based on the `opcode`. 
static void initDispatchDeserializationFn(StringRef opcode, StringRef words, raw_ostream &os) { - os << formatv( - "LogicalResult " - "Deserializer::dispatchToAutogenDeserialization(spirv::Opcode {0}, " - "ArrayRef {1}) {{\n", - opcode, words); + os << formatv("LogicalResult spirv::Deserializer::" + "dispatchToAutogenDeserialization(spirv::Opcode {0}," + " ArrayRef {1}) {{\n", + opcode, words); os << formatv(" switch ({0}) {{\n", opcode); } @@ -1043,8 +1042,8 @@ StringRef instructionID, StringRef words, raw_ostream &os) { - os << formatv("LogicalResult " - "Deserializer::dispatchToExtensionSetAutogenDeserialization(" + os << formatv("LogicalResult spirv::Deserializer::" + "dispatchToExtensionSetAutogenDeserialization(" "StringRef {0}, uint32_t {1}, ArrayRef {2}) {{\n", extensionSetName, instructionID, words); }