diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h
@@ -0,0 +1,29 @@
+//===- SPIRVModule.h - SPIR-V Module Utilities ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_SPIRVMODULE_H
+#define MLIR_DIALECT_SPIRV_SPIRVMODULE_H
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/OwningOpRefBase.h"
+
+namespace mlir {
+namespace spirv {
+
+/// This class acts as an owning reference to a SPIR-V module, and will
+/// automatically destroy the held module on destruction if the held module
+/// is valid.
+class OwningSPIRVModuleRef : public OwningOpRefBase<spirv::ModuleOp> {
+public:
+  using OwningOpRefBase<spirv::ModuleOp>::OwningOpRefBase;
+};
+
+} // end namespace spirv
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_SPIRVMODULE_H
diff --git a/mlir/include/mlir/Dialect/SPIRV/Serialization.h b/mlir/include/mlir/Dialect/SPIRV/Serialization.h
--- a/mlir/include/mlir/Dialect/SPIRV/Serialization.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Serialization.h
@@ -22,6 +22,7 @@
 namespace spirv {
 
 class ModuleOp;
+class OwningSPIRVModuleRef;
 
 /// Serializes the given SPIR-V `module` and writes to `binary`. On failure,
 /// reports errors to the error handler registered with the MLIR context for
@@ -31,9 +32,10 @@
 
 /// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp
 /// in the given `context`. Returns the ModuleOp on success; otherwise, reports
-/// errors to the error handler registered with `context` and returns
-/// llvm::None.
-Optional<ModuleOp> deserialize(ArrayRef<uint32_t> binary, MLIRContext *context);
+/// errors to the error handler registered with `context` and returns a null
+/// module.
+OwningSPIRVModuleRef deserialize(ArrayRef<uint32_t> binary,
+                                 MLIRContext *context);
 
 } // end namespace spirv
 } // end namespace mlir
diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h
--- a/mlir/include/mlir/IR/Module.h
+++ b/mlir/include/mlir/IR/Module.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_IR_MODULE_H
 #define MLIR_IR_MODULE_H
 
+#include "mlir/IR/OwningOpRefBase.h"
 #include "mlir/IR/SymbolTable.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
@@ -122,40 +123,10 @@
 };
 
 /// This class acts as an owning reference to a module, and will automatically
-/// destroy the held module if valid.
-class OwningModuleRef {
+/// destroy the held module on destruction if the held module is valid.
+class OwningModuleRef : public OwningOpRefBase<ModuleOp> {
 public:
-  OwningModuleRef(std::nullptr_t = nullptr) {}
-  OwningModuleRef(ModuleOp module) : module(module) {}
-  OwningModuleRef(OwningModuleRef &&other) : module(other.release()) {}
-  ~OwningModuleRef() {
-    if (module)
-      module.erase();
-  }
-
-  // Assign from another module reference.
-  OwningModuleRef &operator=(OwningModuleRef &&other) {
-    if (module)
-      module.erase();
-    module = other.release();
-    return *this;
-  }
-
-  /// Allow accessing the internal module.
-  ModuleOp get() const { return module; }
-  ModuleOp operator*() const { return module; }
-  ModuleOp *operator->() { return &module; }
-  explicit operator bool() const { return module; }
-
-  /// Release the referenced module.
-  ModuleOp release() {
-    ModuleOp released;
-    std::swap(released, module);
-    return released;
-  }
-
-private:
-  ModuleOp module;
+  using OwningOpRefBase<ModuleOp>::OwningOpRefBase;
 };
 
 } // end namespace mlir
diff --git a/mlir/include/mlir/IR/OwningOpRefBase.h b/mlir/include/mlir/IR/OwningOpRefBase.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/IR/OwningOpRefBase.h
@@ -0,0 +1,75 @@
+//===- OwningOpRefBase.h - MLIR OwningOpRefBase -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides a base class for owning references to operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_OWNINGOPREFBASE_H
+#define MLIR_IR_OWNINGOPREFBASE_H
+
+#include <cstddef>
+#include <utility>
+
+namespace mlir {
+
+/// This class acts as an owning reference to an op, and will automatically
+/// destroy the held op on destruction if the held op is valid.
+///
+/// `OpTy` is expected to be a copyable, default-constructible handle type
+/// that converts to bool and provides an `erase()` method.
+///
+/// Note that OpBuilder and related functionality should be highly preferred
+/// instead, and this should only be used in situations where existing solutions
+/// are not viable.
+template <typename OpTy>
+class OwningOpRefBase {
+public:
+  OwningOpRefBase(std::nullptr_t = nullptr) {}
+  OwningOpRefBase(OpTy op) : op(op) {}
+  OwningOpRefBase(OwningOpRefBase &&other) noexcept : op(other.release()) {}
+  OwningOpRefBase(const OwningOpRefBase &) = delete;
+  OwningOpRefBase &operator=(const OwningOpRefBase &) = delete;
+  ~OwningOpRefBase() {
+    if (op)
+      op.erase();
+  }
+
+  /// Assign from another op reference, destroying the currently held op (if
+  /// any). Self-assignment must be a no-op: erasing first and then taking
+  /// `other.release()` would otherwise keep a handle to the erased op.
+  OwningOpRefBase &operator=(OwningOpRefBase &&other) {
+    if (this == &other)
+      return *this;
+    if (op)
+      op.erase();
+    op = other.release();
+    return *this;
+  }
+
+  /// Allow accessing the internal op.
+  OpTy get() const { return op; }
+  OpTy operator*() const { return op; }
+  OpTy *operator->() { return &op; }
+  explicit operator bool() const { return op; }
+
+  /// Release the referenced op without destroying it; the caller takes over
+  /// ownership and this object is reset to a default-constructed `OpTy`.
+  OpTy release() {
+    OpTy released;
+    std::swap(released, op);
+    return released;
+  }
+
+private:
+  OpTy op;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_IR_OWNINGOPREFBASE_H
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/BlockAndValueMapping.h"
@@ -1745,783 +1746,4 @@
 
   // For each (value, predecessor) pair, insert the value to the predecessor's
   // blockPhiInfo entry so later we can fix the block argument there.
-  for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
-    uint32_t value = operands[i];
-    Block *predecessor = getOrCreateBlock(operands[i + 1]);
-    blockPhiInfo[predecessor].push_back(value);
-    LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor
-                            << " with arg id = " << value << '\n');
-  }
-
-  return success();
-}
-
-namespace {
-/// A class for putting all blocks in a structured selection/loop in a
-/// spv.selection/spv.loop op.
-class ControlFlowStructurizer {
-public:
-  /// Structurizes the loop at the given `headerBlock`.
-  ///
-  /// This method will create an spv.loop op in the `mergeBlock` and move all
-  /// blocks in the structured loop into the spv.loop's region. All branches to
-  /// the `headerBlock` will be redirected to the `mergeBlock`.
- /// This method will also update `mergeInfo` by remapping all blocks inside to - /// the newly cloned ones inside structured control flow op's regions. - static LogicalResult structurize(Location loc, BlockMergeInfoMap &mergeInfo, - Block *headerBlock, Block *mergeBlock, - Block *continueBlock) { - return ControlFlowStructurizer(loc, mergeInfo, headerBlock, mergeBlock, - continueBlock) - .structurizeImpl(); - } - -private: - ControlFlowStructurizer(Location loc, BlockMergeInfoMap &mergeInfo, - Block *header, Block *merge, Block *cont) - : location(loc), blockMergeInfo(mergeInfo), headerBlock(header), - mergeBlock(merge), continueBlock(cont) {} - - /// Creates a new spv.selection op at the beginning of the `mergeBlock`. - spirv::SelectionOp createSelectionOp(); - - /// Creates a new spv.loop op at the beginning of the `mergeBlock`. - spirv::LoopOp createLoopOp(); - - /// Collects all blocks reachable from `headerBlock` except `mergeBlock`. - void collectBlocksInConstruct(); - - LogicalResult structurizeImpl(); - - Location location; - - BlockMergeInfoMap &blockMergeInfo; - - Block *headerBlock; - Block *mergeBlock; - Block *continueBlock; // nullptr for spv.selection - - llvm::SetVector constructBlocks; -}; -} // namespace - -spirv::SelectionOp ControlFlowStructurizer::createSelectionOp() { - // Create a builder and set the insertion point to the beginning of the - // merge block so that the newly created SelectionOp will be inserted there. - OpBuilder builder(&mergeBlock->front()); - - auto control = builder.getI32IntegerAttr( - static_cast(spirv::SelectionControl::None)); - auto selectionOp = builder.create(location, control); - selectionOp.addMergeBlock(); - - return selectionOp; -} - -spirv::LoopOp ControlFlowStructurizer::createLoopOp() { - // Create a builder and set the insertion point to the beginning of the - // merge block so that the newly created LoopOp will be inserted there. 
- OpBuilder builder(&mergeBlock->front()); - - // TODO(antiagainst): handle loop control properly - auto loopOp = builder.create(location); - loopOp.addEntryAndMergeBlock(); - - return loopOp; -} - -void ControlFlowStructurizer::collectBlocksInConstruct() { - assert(constructBlocks.empty() && "expected empty constructBlocks"); - - // Put the header block in the work list first. - constructBlocks.insert(headerBlock); - - // For each item in the work list, add its successors excluding the merge - // block. - for (unsigned i = 0; i < constructBlocks.size(); ++i) { - for (auto *successor : constructBlocks[i]->getSuccessors()) - if (successor != mergeBlock) - constructBlocks.insert(successor); - } -} - -LogicalResult ControlFlowStructurizer::structurizeImpl() { - Operation *op = nullptr; - bool isLoop = continueBlock != nullptr; - if (isLoop) { - if (auto loopOp = createLoopOp()) - op = loopOp.getOperation(); - } else { - if (auto selectionOp = createSelectionOp()) - op = selectionOp.getOperation(); - } - if (!op) - return failure(); - Region &body = op->getRegion(0); - - BlockAndValueMapping mapper; - // All references to the old merge block should be directed to the - // selection/loop merge block in the SelectionOp/LoopOp's region. - mapper.map(mergeBlock, &body.back()); - - collectBlocksInConstruct(); - - // We've identified all blocks belonging to the selection/loop's region. Now - // need to "move" them into the selection/loop. Instead of really moving the - // blocks, in the following we copy them and remap all values and branches. - // This is because: - // * Inserting a block into a region requires the block not in any region - // before. But selections/loops can nest so we can create selection/loop ops - // in a nested manner, which means some blocks may already be in a - // selection/loop region when to be moved again. 
- // * It's much trickier to fix up the branches into and out of the loop's - // region: we need to treat not-moved blocks and moved blocks differently: - // Not-moved blocks jumping to the loop header block need to jump to the - // merge point containing the new loop op but not the loop continue block's - // back edge. Moved blocks jumping out of the loop need to jump to the - // merge block inside the loop region but not other not-moved blocks. - // We cannot use replaceAllUsesWith clearly and it's harder to follow the - // logic. - - // Create a corresponding block in the SelectionOp/LoopOp's region for each - // block in this loop construct. - OpBuilder builder(body); - for (auto *block : constructBlocks) { - // Create a block and insert it before the selection/loop merge block in the - // SelectionOp/LoopOp's region. - auto *newBlock = builder.createBlock(&body.back()); - mapper.map(block, newBlock); - LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock - << " from block " << block << "\n"); - if (!isFnEntryBlock(block)) { - for (BlockArgument blockArg : block->getArguments()) { - auto newArg = newBlock->addArgument(blockArg.getType()); - mapper.map(blockArg, newArg); - LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg - << " to " << newArg << '\n'); - } - } else { - LLVM_DEBUG(llvm::dbgs() - << "[cf] block " << block << " is a function entry block\n"); - } - for (auto &op : *block) - newBlock->push_back(op.clone(mapper)); - } - - // Go through all ops and remap the operands. 
- auto remapOperands = [&](Operation *op) { - for (auto &operand : op->getOpOperands()) - if (auto mappedOp = mapper.lookupOrNull(operand.get())) - operand.set(mappedOp); - for (auto &succOp : op->getBlockOperands()) - if (auto mappedOp = mapper.lookupOrNull(succOp.get())) - succOp.set(mappedOp); - }; - for (auto &block : body) { - block.walk(remapOperands); - } - - // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to - // the selection/loop construct into its region. Next we need to fix the - // connections between this new SelectionOp/LoopOp with existing blocks. - - // All existing incoming branches should go to the merge block, where the - // SelectionOp/LoopOp resides right now. - headerBlock->replaceAllUsesWith(mergeBlock); - - if (isLoop) { - // The loop selection/loop header block may have block arguments. Since now - // we place the selection/loop op inside the old merge block, we need to - // make sure the old merge block has the same block argument list. - assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported"); - for (BlockArgument blockArg : headerBlock->getArguments()) { - mergeBlock->addArgument(blockArg.getType()); - } - - // If the loop header block has block arguments, make sure the spv.branch op - // matches. - SmallVector blockArgs; - if (!headerBlock->args_empty()) - blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()}; - - // The loop entry block should have a unconditional branch jumping to the - // loop header block. - builder.setInsertionPointToEnd(&body.front()); - builder.create(location, mapper.lookupOrNull(headerBlock), - ArrayRef(blockArgs)); - } - - // All the blocks cloned into the SelectionOp/LoopOp's region can now be - // cleaned up. - LLVM_DEBUG(llvm::dbgs() << "[cf] cleaning up blocks after clone\n"); - // First we need to drop all operands' references inside all blocks. This is - // needed because we can have blocks referencing SSA values from one another. 
- for (auto *block : constructBlocks) - block->dropAllReferences(); - - // Then erase all old blocks. - for (auto *block : constructBlocks) { - // We've cloned all blocks belonging to this construct into the structured - // control flow op's region. Among these blocks, some may compose another - // selection/loop. If so, they will be recorded within blockMergeInfo. - // We need to update the pointers there to the newly remapped ones so we can - // continue structurizing them later. - // TODO(antiagainst): The asserts in the following assumes input SPIR-V blob - // forms correctly nested selection/loop constructs. We should relax this - // and support error cases better. - auto it = blockMergeInfo.find(block); - if (it != blockMergeInfo.end()) { - Block *newHeader = mapper.lookupOrNull(block); - assert(newHeader && "nested loop header block should be remapped!"); - - Block *newContinue = it->second.continueBlock; - if (newContinue) { - newContinue = mapper.lookupOrNull(newContinue); - assert(newContinue && "nested loop continue block should be remapped!"); - } - - Block *newMerge = it->second.mergeBlock; - if (Block *mappedTo = mapper.lookupOrNull(newMerge)) - newMerge = mappedTo; - - // Keep original location for nested selection/loop ops. - Location loc = it->second.loc; - // The iterator should be erased before adding a new entry into - // blockMergeInfo to avoid iterator invalidation. - blockMergeInfo.erase(it); - blockMergeInfo.try_emplace(newHeader, loc, newMerge, newContinue); - } - - // The structured selection/loop's entry block does not have arguments. - // If the function's header block is also part of the structured control - // flow, we cannot just simply erase it because it may contain arguments - // matching the function signature and used by the cloned blocks. 
- if (isFnEntryBlock(block)) { - LLVM_DEBUG(llvm::dbgs() << "[cf] changing entry block " << block - << " to only contain a spv.Branch op\n"); - // Still keep the function entry block for the potential block arguments, - // but replace all ops inside with a branch to the merge block. - block->clear(); - builder.setInsertionPointToEnd(block); - builder.create(location, mergeBlock); - } else { - LLVM_DEBUG(llvm::dbgs() << "[cf] erasing block " << block << "\n"); - block->erase(); - } - } - - LLVM_DEBUG( - llvm::dbgs() << "[cf] after structurizing construct with header block " - << headerBlock << ":\n" - << *op << '\n'); - - return success(); -} - -LogicalResult Deserializer::wireUpBlockArgument() { - LLVM_DEBUG(llvm::dbgs() << "[phi] start wiring up block arguments\n"); - - OpBuilder::InsertionGuard guard(opBuilder); - - for (const auto &info : blockPhiInfo) { - Block *block = info.first; - const BlockPhiInfo &phiInfo = info.second; - LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n"); - LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << '\n'); - - // Set insertion point to before this block's terminator early because we - // may materialize ops via getValue() call. - auto *op = block->getTerminator(); - opBuilder.setInsertionPoint(op); - - SmallVector blockArgs; - blockArgs.reserve(phiInfo.size()); - for (uint32_t valueId : phiInfo) { - if (Value value = getValue(valueId)) { - blockArgs.push_back(value); - LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value - << " id = " << valueId << '\n'); - } else { - return emitError(unknownLoc, "OpPhi references undefined value!"); - } - } - - if (auto branchOp = dyn_cast(op)) { - // Replace the previous branch op with a new one with block arguments. 
- opBuilder.create(branchOp.getLoc(), branchOp.getTarget(), - blockArgs); - branchOp.erase(); - } else { - return emitError(unknownLoc, "unimplemented terminator for Phi creation"); - } - - LLVM_DEBUG(llvm::dbgs() << "[phi] after creating block argument:\n"); - LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << '\n'); - } - blockPhiInfo.clear(); - - LLVM_DEBUG(llvm::dbgs() << "[phi] completed wiring up block arguments\n"); - return success(); -} - -LogicalResult Deserializer::structurizeControlFlow() { - LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n"); - - while (!blockMergeInfo.empty()) { - Block *headerBlock = blockMergeInfo.begin()->first; - BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second; - - LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << ":\n"); - LLVM_DEBUG(headerBlock->print(llvm::dbgs())); - - auto *mergeBlock = mergeInfo.mergeBlock; - assert(mergeBlock && "merge block cannot be nullptr"); - if (!mergeBlock->args_empty()) - return emitError(unknownLoc, "OpPhi in loop merge block unimplemented"); - LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << ":\n"); - LLVM_DEBUG(mergeBlock->print(llvm::dbgs())); - - auto *continueBlock = mergeInfo.continueBlock; - if (continueBlock) { - LLVM_DEBUG(llvm::dbgs() - << "[cf] continue block " << continueBlock << ":\n"); - LLVM_DEBUG(continueBlock->print(llvm::dbgs())); - } - // Erase this case before calling into structurizer, who will update - // blockMergeInfo. 
- blockMergeInfo.erase(blockMergeInfo.begin()); - if (failed(ControlFlowStructurizer::structurize(mergeInfo.loc, - blockMergeInfo, headerBlock, - mergeBlock, continueBlock))) - return failure(); - } - - LLVM_DEBUG(llvm::dbgs() << "[cf] completed structurizing control flow\n"); - return success(); -} - -//===----------------------------------------------------------------------===// -// Debug -//===----------------------------------------------------------------------===// - -Location Deserializer::createFileLineColLoc(OpBuilder opBuilder) { - if (!debugLine) - return unknownLoc; - - auto fileName = debugInfoMap.lookup(debugLine->fileID).str(); - if (fileName.empty()) - fileName = ""; - return opBuilder.getFileLineColLoc(opBuilder.getIdentifier(fileName), - debugLine->line, debugLine->col); -} - -LogicalResult Deserializer::processDebugLine(ArrayRef operands) { - // According to SPIR-V spec: - // "This location information applies to the instructions physically - // following this instruction, up to the first occurrence of any of the - // following: the next end of block, the next OpLine instruction, or the next - // OpNoLine instruction." 
- if (operands.size() != 3) - return emitError(unknownLoc, "OpLine must have 3 operands"); - debugLine = DebugLine(operands[0], operands[1], operands[2]); - return success(); -} - -LogicalResult Deserializer::clearDebugLine() { - debugLine = llvm::None; - return success(); -} - -LogicalResult Deserializer::processDebugString(ArrayRef operands) { - if (operands.size() < 2) - return emitError(unknownLoc, "OpString needs at least 2 operands"); - - if (!debugInfoMap.lookup(operands[0]).empty()) - return emitError(unknownLoc, - "duplicate debug string found for result ") - << operands[0]; - - unsigned wordIndex = 1; - StringRef debugString = decodeStringLiteral(operands, wordIndex); - if (wordIndex != operands.size()) - return emitError(unknownLoc, - "unexpected trailing words in OpString instruction"); - - debugInfoMap[operands[0]] = debugString; - return success(); -} - -//===----------------------------------------------------------------------===// -// Instruction -//===----------------------------------------------------------------------===// - -Value Deserializer::getValue(uint32_t id) { - if (auto constInfo = getConstant(id)) { - // Materialize a `spv.constant` op at every use site. 
- return opBuilder.create(unknownLoc, constInfo->second, - constInfo->first); - } - if (auto varOp = getGlobalVariable(id)) { - auto addressOfOp = opBuilder.create( - unknownLoc, varOp.type(), - opBuilder.getSymbolRefAttr(varOp.getOperation())); - return addressOfOp.pointer(); - } - if (auto constOp = getSpecConstant(id)) { - auto referenceOfOp = opBuilder.create( - unknownLoc, constOp.default_value().getType(), - opBuilder.getSymbolRefAttr(constOp.getOperation())); - return referenceOfOp.reference(); - } - if (auto undef = getUndefType(id)) { - return opBuilder.create(unknownLoc, undef); - } - return valueMap.lookup(id); -} - -LogicalResult -Deserializer::sliceInstruction(spirv::Opcode &opcode, - ArrayRef &operands, - Optional expectedOpcode) { - auto binarySize = binary.size(); - if (curOffset >= binarySize) { - return emitError(unknownLoc, "expected ") - << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode) - : "more") - << " instruction"; - } - - // For each instruction, get its word count from the first word to slice it - // from the stream properly, and then dispatch to the instruction handler. 
- - uint32_t wordCount = binary[curOffset] >> 16; - - if (wordCount == 0) - return emitError(unknownLoc, "word count cannot be zero"); - - uint32_t nextOffset = curOffset + wordCount; - if (nextOffset > binarySize) - return emitError(unknownLoc, "insufficient words for the last instruction"); - - opcode = extractOpcode(binary[curOffset]); - operands = binary.slice(curOffset + 1, wordCount - 1); - curOffset = nextOffset; - return success(); -} - -LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, - ArrayRef operands, - bool deferInstructions) { - LLVM_DEBUG(llvm::dbgs() << "[inst] processing instruction " - << spirv::stringifyOpcode(opcode) << "\n"); - - // First dispatch all the instructions whose opcode does not correspond to - // those that have a direct mirror in the SPIR-V dialect - switch (opcode) { - case spirv::Opcode::OpCapability: - return processCapability(operands); - case spirv::Opcode::OpExtension: - return processExtension(operands); - case spirv::Opcode::OpExtInst: - return processExtInst(operands); - case spirv::Opcode::OpExtInstImport: - return processExtInstImport(operands); - case spirv::Opcode::OpMemberName: - return processMemberName(operands); - case spirv::Opcode::OpMemoryModel: - return processMemoryModel(operands); - case spirv::Opcode::OpEntryPoint: - case spirv::Opcode::OpExecutionMode: - if (deferInstructions) { - deferredInstructions.emplace_back(opcode, operands); - return success(); - } - break; - case spirv::Opcode::OpVariable: - if (isa(opBuilder.getBlock()->getParentOp())) { - return processGlobalVariable(operands); - } - break; - case spirv::Opcode::OpLine: - return processDebugLine(operands); - case spirv::Opcode::OpNoLine: - return clearDebugLine(); - case spirv::Opcode::OpName: - return processName(operands); - case spirv::Opcode::OpString: - return processDebugString(operands); - case spirv::Opcode::OpModuleProcessed: - case spirv::Opcode::OpSource: - case spirv::Opcode::OpSourceContinued: - case 
spirv::Opcode::OpSourceExtension: - // TODO: This is debug information embedded in the binary which should be - // translated into the spv.module. - return success(); - case spirv::Opcode::OpTypeVoid: - case spirv::Opcode::OpTypeBool: - case spirv::Opcode::OpTypeInt: - case spirv::Opcode::OpTypeFloat: - case spirv::Opcode::OpTypeVector: - case spirv::Opcode::OpTypeMatrix: - case spirv::Opcode::OpTypeArray: - case spirv::Opcode::OpTypeFunction: - case spirv::Opcode::OpTypeRuntimeArray: - case spirv::Opcode::OpTypeStruct: - case spirv::Opcode::OpTypePointer: - case spirv::Opcode::OpTypeCooperativeMatrixNV: - return processType(opcode, operands); - case spirv::Opcode::OpConstant: - return processConstant(operands, /*isSpec=*/false); - case spirv::Opcode::OpSpecConstant: - return processConstant(operands, /*isSpec=*/true); - case spirv::Opcode::OpConstantComposite: - return processConstantComposite(operands); - case spirv::Opcode::OpConstantTrue: - return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false); - case spirv::Opcode::OpSpecConstantTrue: - return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true); - case spirv::Opcode::OpConstantFalse: - return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false); - case spirv::Opcode::OpSpecConstantFalse: - return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true); - case spirv::Opcode::OpConstantNull: - return processConstantNull(operands); - case spirv::Opcode::OpDecorate: - return processDecoration(operands); - case spirv::Opcode::OpMemberDecorate: - return processMemberDecoration(operands); - case spirv::Opcode::OpFunction: - return processFunction(operands); - case spirv::Opcode::OpLabel: - return processLabel(operands); - case spirv::Opcode::OpBranch: - return processBranch(operands); - case spirv::Opcode::OpBranchConditional: - return processBranchConditional(operands); - case spirv::Opcode::OpSelectionMerge: - return processSelectionMerge(operands); - case 
spirv::Opcode::OpLoopMerge: - return processLoopMerge(operands); - case spirv::Opcode::OpPhi: - return processPhi(operands); - case spirv::Opcode::OpUndef: - return processUndef(operands); - default: - break; - } - return dispatchToAutogenDeserialization(opcode, operands); -} - -LogicalResult Deserializer::processUndef(ArrayRef operands) { - if (operands.size() != 2) { - return emitError(unknownLoc, "OpUndef instruction must have two operands"); - } - auto type = getType(operands[0]); - if (!type) { - return emitError(unknownLoc, "unknown type with OpUndef instruction"); - } - undefMap[operands[1]] = type; - return success(); -} - -LogicalResult Deserializer::processExtInst(ArrayRef operands) { - if (operands.size() < 4) { - return emitError(unknownLoc, - "OpExtInst must have at least 4 operands, result type " - ", result , set and instruction opcode"); - } - if (!extendedInstSets.count(operands[2])) { - return emitError(unknownLoc, "undefined set in OpExtInst"); - } - SmallVector slicedOperands; - slicedOperands.append(operands.begin(), std::next(operands.begin(), 2)); - slicedOperands.append(std::next(operands.begin(), 4), operands.end()); - return dispatchToExtensionSetAutogenDeserialization( - extendedInstSets[operands[2]], operands[3], slicedOperands); -} - -namespace { - -template <> -LogicalResult -Deserializer::processOp(ArrayRef words) { - unsigned wordIndex = 0; - if (wordIndex >= words.size()) { - return emitError(unknownLoc, - "missing Execution Model specification in OpEntryPoint"); - } - auto exec_model = opBuilder.getI32IntegerAttr(words[wordIndex++]); - if (wordIndex >= words.size()) { - return emitError(unknownLoc, "missing in OpEntryPoint"); - } - // Get the function - auto fnID = words[wordIndex++]; - // Get the function name - auto fnName = decodeStringLiteral(words, wordIndex); - // Verify that the function matches the fnName - auto parsedFunc = getFunction(fnID); - if (!parsedFunc) { - return emitError(unknownLoc, "no function matching ") << 
fnID; - } - if (parsedFunc.getName() != fnName) { - return emitError(unknownLoc, "function name mismatch between OpEntryPoint " - "and OpFunction with ") - << fnID << ": " << fnName << " vs. " << parsedFunc.getName(); - } - SmallVector interface; - while (wordIndex < words.size()) { - auto arg = getGlobalVariable(words[wordIndex]); - if (!arg) { - return emitError(unknownLoc, "undefined result ") - << words[wordIndex] << " while decoding OpEntryPoint"; - } - interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation())); - wordIndex++; - } - opBuilder.create(unknownLoc, exec_model, - opBuilder.getSymbolRefAttr(fnName), - opBuilder.getArrayAttr(interface)); - return success(); -} - -template <> -LogicalResult -Deserializer::processOp(ArrayRef words) { - unsigned wordIndex = 0; - if (wordIndex >= words.size()) { - return emitError(unknownLoc, - "missing function result in OpExecutionMode"); - } - // Get the function to get the name of the function - auto fnID = words[wordIndex++]; - auto fn = getFunction(fnID); - if (!fn) { - return emitError(unknownLoc, "no function matching ") << fnID; - } - // Get the Execution mode - if (wordIndex >= words.size()) { - return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode"); - } - auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]); - - // Get the values - SmallVector attrListElems; - while (wordIndex < words.size()) { - attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++])); - } - auto values = opBuilder.getArrayAttr(attrListElems); - opBuilder.create( - unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values); - return success(); -} - -template <> -LogicalResult -Deserializer::processOp(ArrayRef operands) { - if (operands.size() != 3) { - return emitError( - unknownLoc, - "OpControlBarrier must have execution scope , memory scope " - "and memory semantics "); - } - - SmallVector argAttrs; - for (auto operand : operands) { - auto argAttr = 
getConstantInt(operand); - if (!argAttr) { - return emitError(unknownLoc, - "expected 32-bit integer constant from ") - << operand << " for OpControlBarrier"; - } - argAttrs.push_back(argAttr); - } - - opBuilder.create(unknownLoc, argAttrs[0], - argAttrs[1], argAttrs[2]); - return success(); -} - -template <> -LogicalResult -Deserializer::processOp(ArrayRef operands) { - if (operands.size() < 3) { - return emitError(unknownLoc, - "OpFunctionCall must have at least 3 operands"); - } - - Type resultType = getType(operands[0]); - if (!resultType) { - return emitError(unknownLoc, "undefined result type from ") - << operands[0]; - } - - // Use null type to mean no result type. - if (isVoidType(resultType)) - resultType = nullptr; - - auto resultID = operands[1]; - auto functionID = operands[2]; - - auto functionName = getFunctionSymbol(functionID); - - SmallVector arguments; - for (auto operand : llvm::drop_begin(operands, 3)) { - auto value = getValue(operand); - if (!value) { - return emitError(unknownLoc, "unknown ") - << operand << " used by OpFunctionCall"; - } - arguments.push_back(value); - } - - auto opFunctionCall = opBuilder.create( - unknownLoc, resultType, opBuilder.getSymbolRefAttr(functionName), - arguments); - - if (resultType) - valueMap[resultID] = opFunctionCall.getResult(0); - return success(); -} - -template <> -LogicalResult -Deserializer::processOp(ArrayRef operands) { - if (operands.size() != 2) { - return emitError(unknownLoc, "OpMemoryBarrier must have memory scope " - "and memory semantics "); - } - - SmallVector argAttrs; - for (auto operand : operands) { - auto argAttr = getConstantInt(operand); - if (!argAttr) { - return emitError(unknownLoc, - "expected 32-bit integer constant from ") - << operand << " for OpMemoryBarrier"; - } - argAttrs.push_back(argAttr); - } - - opBuilder.create(unknownLoc, argAttrs[0], - argAttrs[1]); - return success(); -} - -// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and -// various 
Deserializer::processOp<...>() specializations. -#define GET_DESERIALIZATION_FNS -#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc" -} // namespace - -Optional spirv::deserialize(ArrayRef binary, - MLIRContext *context) { - Deserializer deserializer(binary, context); - - if (failed(deserializer.deserialize())) - return llvm::None; - - return deserializer.collect(); -} + for (unsigned i \ No newline at end of file diff --git a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/SPIRV/SPIRVModule.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Serialization.h" #include "mlir/IR/Builders.h" @@ -49,13 +50,13 @@ auto binary = llvm::makeArrayRef(reinterpret_cast(start), size / sizeof(uint32_t)); - auto spirvModule = spirv::deserialize(binary, context); + spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context); if (!spirvModule) return {}; OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( input->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); - module->getBody()->push_front(spirvModule->getOperation()); + module->getBody()->push_front(spirvModule.release()); return module; } @@ -136,14 +137,14 @@ return failure(); // Then deserialize to get back a SPIR-V module. - auto spirvModule = spirv::deserialize(binary, context); + spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context); if (!spirvModule) return failure(); // Wrap around in a new MLIR module. 
OwningModuleRef dstModule(ModuleOp::create(FileLineColLoc::get( /*filename=*/"", /*line=*/0, /*column=*/0, context))); - dstModule->getBody()->push_front(spirvModule->getOperation()); + dstModule->getBody()->push_front(spirvModule.release()); dstModule->print(output); return mlir::success(); diff --git a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp --- a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVModule.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Serialization.h" #include "mlir/IR/Diagnostics.h" @@ -46,7 +47,7 @@ } /// Performs deserialization and returns the constructed spv.module op. - Optional deserialize() { + spirv::OwningSPIRVModuleRef deserialize() { return spirv::deserialize(binary, &context); } @@ -130,27 +131,27 @@ //===----------------------------------------------------------------------===// TEST_F(DeserializationTest, EmptyModuleFailure) { - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("SPIR-V binary module must have a 5-word header"); } TEST_F(DeserializationTest, WrongMagicNumberFailure) { addHeader(); binary.front() = 0xdeadbeef; // Change to a wrong magic number - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("incorrect magic number"); } TEST_F(DeserializationTest, OnlyHeaderSuccess) { addHeader(); - EXPECT_NE(llvm::None, deserialize()); + EXPECT_TRUE(deserialize()); } TEST_F(DeserializationTest, ZeroWordCountFailure) { addHeader(); binary.push_back(0); // OpNop with zero word count - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("word count cannot be zero"); } @@ -160,7 +161,7 @@ static_cast(spirv::Opcode::OpTypeVoid)); // 
Missing word for type - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("insufficient words for the last instruction"); } @@ -172,7 +173,7 @@ addHeader(); addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32}); - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters"); } @@ -198,7 +199,7 @@ addInstruction(spirv::Opcode::OpMemberName, operands2); binary.append(typeDecl.begin(), typeDecl.end()); - EXPECT_NE(llvm::None, deserialize()); + EXPECT_TRUE(deserialize()); } TEST_F(DeserializationTest, OpMemberNameMissingOperands) { @@ -215,7 +216,7 @@ addInstruction(spirv::Opcode::OpMemberName, operands1); binary.append(typeDecl.begin(), typeDecl.end()); - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("OpMemberName must have at least 3 operands"); } @@ -234,7 +235,7 @@ addInstruction(spirv::Opcode::OpMemberName, operands); binary.append(typeDecl.begin(), typeDecl.end()); - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("unexpected trailing words in OpMemberName instruction"); } @@ -249,7 +250,7 @@ addFunction(voidType, fnType); // Missing OpFunctionEnd - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("expected OpFunctionEnd instruction"); } @@ -261,7 +262,7 @@ addFunction(voidType, fnType); // Missing OpFunctionParameter - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("expected OpFunctionParameter instruction"); } @@ -274,7 +275,7 @@ addReturn(); addFunctionEnd(); - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("a basic block must start with OpLabel"); } @@ -287,6 +288,6 @@ addReturn(); addFunctionEnd(); - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("OpLabel should only have result "); }