diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt --- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt @@ -5,15 +5,19 @@ add_mlir_dialect_library(MLIRSPIRVDialect AtomicOps.cpp CastOps.cpp + ControlFlowOps.cpp CooperativeMatrixOps.cpp GroupOps.cpp IntegerDotProductOps.cpp JointMatrixOps.cpp + MemoryOps.cpp SPIRVAttributes.cpp SPIRVCanonicalization.cpp SPIRVGLCanonicalization.cpp SPIRVDialect.cpp SPIRVEnums.cpp + SPIRVOpAvailability.cpp + SPIRVOpDefinition.cpp SPIRVOps.cpp SPIRVParsingUtils.cpp SPIRVTypes.cpp diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp @@ -0,0 +1,562 @@ +//===- ControlFlowOps.cpp - MLIR SPIR-V Control Flow Ops -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the control flow operations in the SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Interfaces/CallInterfaces.h" + +#include "SPIRVOpUtils.h" +#include "SPIRVParsingUtils.h" + +using namespace mlir::spirv::AttrNames; + +namespace mlir::spirv { + +/// Parses Function, Selection and Loop control attributes. If no control is +/// specified, "None" is used as a default. 
+template <typename EnumAttrClass, typename EnumClass>
+static ParseResult
+parseControlAttribute(OpAsmParser &parser, OperationState &state,
+                      StringRef attrName = spirv::attributeName<EnumClass>()) {
+  if (succeeded(parser.parseOptionalKeyword(kControl))) {
+    EnumClass control;
+    if (parser.parseLParen() ||
+        spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
+        parser.parseRParen())
+      return failure();
+    return success();
+  }
+  // No `control(...)` clause was written: default the attribute to "None".
+  Builder builder = parser.getBuilder();
+  state.addAttribute(attrName,
+                     builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BranchOp
+//===----------------------------------------------------------------------===//
+
+SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  return SuccessorOperands(0, getTargetOperandsMutable());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BranchConditionalOp
+//===----------------------------------------------------------------------===//
+
+SuccessorOperands BranchConditionalOp::getSuccessorOperands(unsigned index) {
+  assert(index < 2 && "invalid successor index");
+  return SuccessorOperands(index == kTrueIndex
+                               ? getTrueTargetOperandsMutable()
+                               : getFalseTargetOperandsMutable());
+}
+
+ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  auto &builder = parser.getBuilder();
+  OpAsmParser::UnresolvedOperand condInfo;
+  Block *dest = nullptr;
+
+  // Parse the condition, which is always resolved against i1.
+  Type boolTy = builder.getI1Type();
+  if (parser.parseOperand(condInfo) ||
+      parser.resolveOperand(condInfo, boolTy, result.operands))
+    return failure();
+
+  // Parse the optional branch weights: `[` true-weight `,` false-weight `]`.
+  if (succeeded(parser.parseOptionalLSquare())) {
+    IntegerAttr trueWeight, falseWeight;
+    NamedAttrList weights;
+
+    auto i32Type = builder.getIntegerType(32);
+    if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
+        parser.parseComma() ||
+        parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
+        parser.parseRSquare())
+      return failure();
+
+    result.addAttribute(kBranchWeightAttrName,
+                        builder.getArrayAttr({trueWeight, falseWeight}));
+  }
+
+  // Parse the true branch.
+  SmallVector<Value, 4> trueOperands;
+  if (parser.parseComma() ||
+      parser.parseSuccessorAndUseList(dest, trueOperands))
+    return failure();
+  result.addSuccessors(dest);
+  result.addOperands(trueOperands);
+
+  // Parse the false branch.
+  SmallVector<Value, 4> falseOperands;
+  if (parser.parseComma() ||
+      parser.parseSuccessorAndUseList(dest, falseOperands))
+    return failure();
+  result.addSuccessors(dest);
+  result.addOperands(falseOperands);
+  result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(
+                          {1, static_cast<int32_t>(trueOperands.size()),
+                           static_cast<int32_t>(falseOperands.size())}));
+
+  return success();
+}
+
+void BranchConditionalOp::print(OpAsmPrinter &printer) {
+  printer << ' ' << getCondition();
+
+  if (auto weights = getBranchWeights()) {
+    printer << " [";
+    llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
+      printer << llvm::cast<IntegerAttr>(a).getInt();
+    });
+    printer << "]";
+  }
+
+  printer << ", ";
+  printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
+  printer << ", ";
+  printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
+}
+
+LogicalResult BranchConditionalOp::verify() {
+  if (auto weights = getBranchWeights()) {
+    if (weights->getValue().size() != 2) {
+      return emitOpError("must have exactly two branch weights");
+    }
+    if (llvm::all_of(*weights, [](Attribute attr) {
+          return llvm::cast<IntegerAttr>(attr).getValue().isZero();
+        }))
+      return emitOpError("branch weights cannot both be zero");
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.FunctionCall
+//===----------------------------------------------------------------------===//
+
+LogicalResult FunctionCallOp::verify() {
+  auto fnName = getCalleeAttr();
+
+  auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
+      SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
+  if (!funcOp) {
+    return emitOpError("callee function '")
+           << fnName.getValue() << "' not found in nearest symbol table";
+  }
+
+  auto functionType = funcOp.getFunctionType();
+
+  if (getNumResults() > 1) {
+    return emitOpError(
+               "expected callee function to have 0 or 1 result, but provided ")
+           << getNumResults();
+  }
+
+  if (functionType.getNumInputs() != getNumOperands()) {
+    return emitOpError("has incorrect number of operands for callee: expected ")
+           << functionType.getNumInputs() << ", but provided "
+           << getNumOperands();
+  }
+
+  for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
+    if (getOperand(i).getType() != functionType.getInput(i)) {
+      return emitOpError("operand type mismatch: expected operand type ")
+             << functionType.getInput(i) << ", but provided "
+             << getOperand(i).getType() << " for operand number " << i;
+    }
+  }
+
+  if (functionType.getNumResults() != getNumResults()) {
+    return emitOpError(
+               "has incorrect number of results has for callee: expected ")
+           << functionType.getNumResults() << ", but provided "
+           << getNumResults();
+  }
+
+  if (getNumResults() &&
+      (getResult(0).getType() != functionType.getResult(0))) {
+    return emitOpError("result type mismatch: expected ")
+           << functionType.getResult(0) << ", but provided "
+           << getResult(0).getType();
+  }
+
+  return success();
+}
+
+CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
+  return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
+}
+
+void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
+}
+
+Operation::operand_range FunctionCallOp::getArgOperands() {
+  return getArguments();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.mlir.loop
+//===----------------------------------------------------------------------===//
+
+void LoopOp::build(OpBuilder &builder, OperationState &state) {
+  state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
+                                         spirv::LoopControl::None));
+  state.addRegion();
+}
+
+ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
+  if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
+                                                                        result))
+    return failure();
+  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
+}
+
+void LoopOp::print(OpAsmPrinter &printer) {
+  auto control = getLoopControl();
+  if (control != spirv::LoopControl::None)
+    printer << " control(" << spirv::stringifyLoopControl(control) << ")";
+  printer << ' ';
+  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                      /*printBlockTerminators=*/true);
+}
+
+/// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
+/// given `dstBlock`.
+static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
+  // Check that there is only one op in the `srcBlock`.
+  if (!llvm::hasSingleElement(srcBlock))
+    return false;
+
+  auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
+  return branchOp && branchOp.getSuccessor() == &dstBlock;
+}
+
+/// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
+static bool isMergeBlock(Block &block) {
+  return !block.empty() && std::next(block.begin()) == block.end() &&
+         isa<spirv::MergeOp>(block.front());
+}
+
+LogicalResult LoopOp::verifyRegions() {
+  auto *op = getOperation();
+
+  // We need to verify that the blocks follow the following layout:
+  //
+  //                     +-------------+
+  //                     | entry block |
+  //                     +-------------+
+  //                            |
+  //                            v
+  //                     +-------------+
+  //                     | loop header | <-----+
+  //                     +-------------+       |
+  //                                           |
+  //                           ...             |
+  //                          \ | /            |
+  //                            v              |
+  //                    +---------------+      |
+  //                    | loop continue | -----+
+  //                    +---------------+
+  //
+  //                           ...
+  //                          \ | /
+  //                            v
+  //                     +-------------+
+  //                     | merge block |
+  //                     +-------------+
+
+  auto &region = op->getRegion(0);
+  // Allow empty region as a degenerated case, which can come from
+  // optimizations.
+  if (region.empty())
+    return success();
+
+  // The last block is the merge block.
+  Block &merge = region.back();
+  if (!isMergeBlock(merge))
+    return emitOpError("last block must be the merge block with only one "
+                       "'spirv.mlir.merge' op");
+
+  if (std::next(region.begin()) == region.end())
+    return emitOpError(
+        "must have an entry block branching to the loop header block");
+  // The first block is the entry block.
+  Block &entry = region.front();
+
+  if (std::next(region.begin(), 2) == region.end())
+    return emitOpError(
+        "must have a loop header block branched from the entry block");
+  // The second block is the loop header block.
+  Block &header = *std::next(region.begin(), 1);
+
+  if (!hasOneBranchOpTo(entry, header))
+    return emitOpError(
+        "entry block must only have one 'spirv.Branch' op to the second block");
+
+  if (std::next(region.begin(), 3) == region.end())
+    return emitOpError(
+        "requires a loop continue block branching to the loop header block");
+  // The second to last block is the loop continue block.
+  Block &cont = *std::prev(region.end(), 2);
+
+  // Make sure that we have a branch from the loop continue block to the loop
+  // header block.
+  if (llvm::none_of(
+          llvm::seq<unsigned>(0, cont.getNumSuccessors()),
+          [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
+    return emitOpError("second to last block must be the loop continue "
+                       "block that branches to the loop header block");
+
+  // Make sure that no other blocks (except the entry and loop continue block)
+  // branches to the loop header block.
+  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
+                                      std::prev(region.end(), 2))) {
+    for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
+      if (block.getSuccessor(i) == &header) {
+        return emitOpError("can only have the entry and loop continue "
+                           "block branching to the loop header block");
+      }
+    }
+  }
+
+  return success();
+}
+
+Block *LoopOp::getEntryBlock() {
+  assert(!getBody().empty() && "op region should not be empty!");
+  return &getBody().front();
+}
+
+Block *LoopOp::getHeaderBlock() {
+  assert(!getBody().empty() && "op region should not be empty!");
+  // The second block is the loop header block.
+  return &*std::next(getBody().begin());
+}
+
+Block *LoopOp::getContinueBlock() {
+  assert(!getBody().empty() && "op region should not be empty!");
+  // The second to last block is the loop continue block.
+  return &*std::prev(getBody().end(), 2);
+}
+
+Block *LoopOp::getMergeBlock() {
+  assert(!getBody().empty() && "op region should not be empty!");
+  // The last block is the loop merge block.
+  return &getBody().back();
+}
+
+void LoopOp::addEntryAndMergeBlock() {
+  assert(getBody().empty() && "entry and merge block already exist");
+  getBody().push_back(new Block());
+  auto *mergeBlock = new Block();
+  getBody().push_back(mergeBlock);
+  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
+
+  // Add a spirv.mlir.merge op into the merge block.
+  builder.create<spirv::MergeOp>(getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.mlir.merge
+//===----------------------------------------------------------------------===//
+
+LogicalResult MergeOp::verify() {
+  auto *parentOp = (*this)->getParentOp();
+  if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
+    return emitOpError(
+        "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
+
+  // TODO: This check should be done in `verifyRegions` of parent op.
+  Block &parentLastBlock = (*this)->getParentRegion()->back();
+  if (getOperation() != parentLastBlock.getTerminator())
+    return emitOpError("can only be used in the last block of "
+                       "'spirv.mlir.selection' or 'spirv.mlir.loop'");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.Return
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReturnOp::verify() {
+  // Verification is performed in spirv.func op.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ReturnValue
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReturnValueOp::verify() {
+  // Verification is performed in spirv.func op.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.Select
+//===----------------------------------------------------------------------===//
+
+LogicalResult SelectOp::verify() {
+  if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
+    auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
+    if (!resultVectorTy) {
+      return emitOpError("result expected to be of vector type when "
+                         "condition is of vector type");
+    }
+    if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
+      return emitOpError("result should have the same number of elements as "
+                         "the condition when condition is of vector type");
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.mlir.selection
+//===----------------------------------------------------------------------===//
+
+ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
+  if (parseControlAttribute<spirv::SelectionControlAttr,
+                            spirv::SelectionControl>(parser, result))
+    return failure();
+  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
+}
+
+void SelectionOp::print(OpAsmPrinter &printer) {
+  auto control = getSelectionControl();
+  if (control != spirv::SelectionControl::None)
+    printer << " control(" << spirv::stringifySelectionControl(control) << ")";
+  printer << ' ';
+  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                      /*printBlockTerminators=*/true);
+}
+
+LogicalResult SelectionOp::verifyRegions() {
+  auto *op = getOperation();
+
+  // We need to verify that the blocks follow the following layout:
+  //
+  //                     +--------------+
+  //                     | header block |
+  //                     +--------------+
+  //                          / | \
+  //                           ...
+  //
+  //
+  //         +---------+   +---------+   +---------+
+  //         | case #0 |   | case #1 |   | case #2 |  ...
+  //         +---------+   +---------+   +---------+
+  //
+  //
+  //                           ...
+  //                          \ | /
+  //                            v
+  //                     +-------------+
+  //                     | merge block |
+  //                     +-------------+
+
+  auto &region = op->getRegion(0);
+  // Allow empty region as a degenerated case, which can come from
+  // optimizations.
+  if (region.empty())
+    return success();
+
+  // The last block is the merge block.
+  if (!isMergeBlock(region.back()))
+    return emitOpError("last block must be the merge block with only one "
+                       "'spirv.mlir.merge' op");
+
+  if (std::next(region.begin()) == region.end())
+    return emitOpError("must have a selection header block");
+
+  return success();
+}
+
+Block *SelectionOp::getHeaderBlock() {
+  assert(!getBody().empty() && "op region should not be empty!");
+  // The first block is the selection header block.
+  return &getBody().front();
+}
+
+Block *SelectionOp::getMergeBlock() {
+  assert(!getBody().empty() && "op region should not be empty!");
+  // The last block is the selection merge block.
+  return &getBody().back();
+}
+
+void SelectionOp::addMergeBlock() {
+  assert(getBody().empty() && "entry and merge block already exist");
+  auto *mergeBlock = new Block();
+  getBody().push_back(mergeBlock);
+  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
+
+  // Add a spirv.mlir.merge op into the merge block.
+  builder.create<spirv::MergeOp>(getLoc());
+}
+
+SelectionOp
+SelectionOp::createIfThen(Location loc, Value condition,
+                          function_ref<void(OpBuilder &builder)> thenBody,
+                          OpBuilder &builder) {
+  auto selectionOp =
+      builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+
+  selectionOp.addMergeBlock();
+  Block *mergeBlock = selectionOp.getMergeBlock();
+  Block *thenBlock = nullptr;
+
+  // Build the "then" block.
+  {
+    OpBuilder::InsertionGuard guard(builder);
+    thenBlock = builder.createBlock(mergeBlock);
+    thenBody(builder);
+    builder.create<spirv::BranchOp>(loc, mergeBlock);
+  }
+
+  // Build the header block.
+  {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.createBlock(thenBlock);
+    builder.create<spirv::BranchConditionalOp>(
+        loc, condition, thenBlock,
+        /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
+        /*falseArguments=*/ArrayRef<Value>());
+  }
+
+  return selectionOp;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.Unreachable
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::UnreachableOp::verify() {
+  auto *block = (*this)->getBlock();
+  // Fast track: if this is in entry block, it's invalid. Otherwise, if no
+  // predecessors, it's valid.
+  if (block->isEntryBlock())
+    return emitOpError("cannot be used in reachable block");
+  if (block->hasNoPredecessors())
+    return success();
+
+  // TODO: further verification needs to analyze reachability from
+  // the entry block.
+
+  return success();
+}
+
+} // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -0,0 +1,751 @@
+//===- MemoryOps.cpp - MLIR SPIR-V Memory Ops ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the memory operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "SPIRVOpUtils.h"
+#include "SPIRVParsingUtils.h"
+
+#include "llvm/ADT/StringExtras.h"
+
+using namespace mlir::spirv::AttrNames;
+
+namespace mlir::spirv {
+
+// TODO Merge this and `parseMemoryAccessAttributes` into one template
+// parameterized by memory access attribute name and alignment. Doing so now
+// results in VS2017 producing an internal error (at the call site) that's
+// not detailed enough to understand what is happening.
+static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
+                                                     OperationState &state) {
+  // Parse an optional list of attributes starting with '['
+  if (parser.parseOptionalLSquare()) {
+    // No '[' present: the source memory access clause was omitted.
+    return success();
+  }
+
+  spirv::MemoryAccess memoryAccessAttr;
+  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
+          memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
+    return failure();
+
+  if (spirv::bitEnumContainsAll(memoryAccessAttr,
+                                spirv::MemoryAccess::Aligned)) {
+    // Parse integer attribute for alignment.
+    Attribute alignmentAttr;
+    Type i32Type = parser.getBuilder().getIntegerType(32);
+    if (parser.parseComma() ||
+        parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
+                              state.attributes)) {
+      return failure();
+    }
+  }
+  return parser.parseRSquare();
+}
+
+// TODO Merge this and `printMemoryAccessAttribute` into one template
+// parameterized by memory access attribute name and alignment. Doing so now
+// results in VS2017 producing an internal error (at the call site) that's
+// not detailed enough to understand what is happening.
+template <typename MemoryOpTy>
+static void printSourceMemoryAccessAttribute(
+    MemoryOpTy memoryOp, OpAsmPrinter &printer,
+    SmallVectorImpl<StringRef> &elidedAttrs,
+    std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
+    std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
+
+  printer << ", ";
+
+  // Print optional memory access attribute.
+  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
+                                              : memoryOp.getMemoryAccess())) {
+    elidedAttrs.push_back(kSourceMemoryAccessAttrName);
+
+    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
+
+    if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
+      // Print integer alignment attribute.
+      if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
+                                               : memoryOp.getAlignment())) {
+        elidedAttrs.push_back(kSourceAlignmentAttrName);
+        printer << ", " << *alignment;
+      }
+    }
+    printer << "]";
+  }
+  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
+}
+
+template <typename MemoryOpTy>
+static void printMemoryAccessAttribute(
+    MemoryOpTy memoryOp, OpAsmPrinter &printer,
+    SmallVectorImpl<StringRef> &elidedAttrs,
+    std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
+    std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
+  // Print optional memory access attribute.
+  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
+                                              : memoryOp.getMemoryAccess())) {
+    elidedAttrs.push_back(kMemoryAccessAttrName);
+
+    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
+
+    if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
+      // Print integer alignment attribute.
+      if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
+                                               : memoryOp.getAlignment())) {
+        elidedAttrs.push_back(kAlignmentAttrName);
+        printer << ", " << *alignment;
+      }
+    }
+    printer << "]";
+  }
+  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
+}
+
+template <typename LoadStoreOpTy>
+static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
+                                                   Value val) {
+  // ODS already checks ptr is spirv::PointerType. Just check that the pointee
+  // type of the pointer and the type of the value are the same
+  //
+  // TODO: Check that the value type satisfies restrictions of
+  // SPIR-V OpLoad/OpStore operations
+  if (val.getType() !=
+      llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
+    return op.emitOpError("mismatch in result type and pointer type");
+  }
+  return success();
+}
+
+template <typename MemoryOpTy>
+static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
+  // ODS checks for attributes values. Just need to verify that if the
+  // memory-access attribute is Aligned, then the alignment attribute must be
+  // present.
+  auto *op = memoryOp.getOperation();
+  auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
+  if (!memAccessAttr) {
+    // Alignment attribute shouldn't be present if memory access attribute is
+    // not present.
+    if (op->getAttr(kAlignmentAttrName)) {
+      return memoryOp.emitOpError(
+          "invalid alignment specification without aligned memory access "
+          "specification");
+    }
+    return success();
+  }
+
+  auto memAccess = llvm::dyn_cast<spirv::MemoryAccessAttr>(memAccessAttr);
+
+  if (!memAccess) {
+    return memoryOp.emitOpError("invalid memory access specifier: ")
+           << memAccessAttr;
+  }
+
+  if (spirv::bitEnumContainsAll(memAccess.getValue(),
+                                spirv::MemoryAccess::Aligned)) {
+    if (!op->getAttr(kAlignmentAttrName)) {
+      return memoryOp.emitOpError("missing alignment value");
+    }
+  } else {
+    if (op->getAttr(kAlignmentAttrName)) {
+      return memoryOp.emitOpError(
+          "invalid alignment specification with non-aligned memory access "
+          "specification");
+    }
+  }
+  return success();
+}
+
+// TODO Make sure to merge this and the previous function into one template
+// parameterized by memory access attribute name and alignment. Doing so now
+// results in VS2017 producing an internal error (at the call site) that's
+// not detailed enough to understand what is happening.
+template <typename MemoryOpTy>
+static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
+  // ODS checks for attributes values. Just need to verify that if the
+  // memory-access attribute is Aligned, then the alignment attribute must be
+  // present.
+  auto *op = memoryOp.getOperation();
+  auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
+  if (!memAccessAttr) {
+    // Alignment attribute shouldn't be present if memory access attribute is
+    // not present.
+    if (op->getAttr(kSourceAlignmentAttrName)) {
+      return memoryOp.emitOpError(
+          "invalid alignment specification without aligned memory access "
+          "specification");
+    }
+    return success();
+  }
+
+  auto memAccess = llvm::dyn_cast<spirv::MemoryAccessAttr>(memAccessAttr);
+
+  if (!memAccess) {
+    return memoryOp.emitOpError("invalid memory access specifier: ")
+           << memAccessAttr;
+  }
+
+  if (spirv::bitEnumContainsAll(memAccess.getValue(),
+                                spirv::MemoryAccess::Aligned)) {
+    if (!op->getAttr(kSourceAlignmentAttrName)) {
+      return memoryOp.emitOpError("missing alignment value");
+    }
+  } else {
+    if (op->getAttr(kSourceAlignmentAttrName)) {
+      return memoryOp.emitOpError(
+          "invalid alignment specification with non-aligned memory access "
+          "specification");
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AccessChainOp
+//===----------------------------------------------------------------------===//
+
+static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
+  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
+  if (!ptrType) {
+    emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
+                       "to composite type, but provided ")
+        << type;
+    return nullptr;
+  }
+
+  auto resultType = ptrType.getPointeeType();
+  auto resultStorageClass = ptrType.getStorageClass();
+  int32_t index = 0;
+
+  for (auto indexSSA : indices) {
+    auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
+    if (!cType) {
+      emitError(
+          baseLoc,
+          "'spirv.AccessChain' op cannot extract from non-composite type ")
+          << resultType << " with index " << index;
+      return nullptr;
+    }
+    index = 0;
+    if (llvm::isa<spirv::StructType>(resultType)) {
+      Operation *op = indexSSA.getDefiningOp();
+      if (!op) {
+        emitError(baseLoc, "'spirv.AccessChain' op index must be an "
+                           "integer spirv.Constant to access "
+                           "element of spirv.struct");
+        return nullptr;
+      }
+
+      // TODO: this should be relaxed to allow
+      // integer literals of other bitwidths.
+      if (failed(spirv::extractValueFromConstOp(op, index))) {
+        emitError(
+            baseLoc,
+            "'spirv.AccessChain' index must be an integer spirv.Constant to "
+            "access element of spirv.struct, but provided ")
+            << op->getName();
+        return nullptr;
+      }
+      if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
+        emitError(baseLoc, "'spirv.AccessChain' op index ")
+            << index << " out of bounds for " << resultType;
+        return nullptr;
+      }
+    }
+    resultType = cType.getElementType(index);
+  }
+  return spirv::PointerType::get(resultType, resultStorageClass);
+}
+
+void AccessChainOp::build(OpBuilder &builder, OperationState &state,
+                          Value basePtr, ValueRange indices) {
+  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
+  assert(type && "Unable to deduce return type based on basePtr and indices");
+  build(builder, state, type, basePtr, indices);
+}
+
+ParseResult AccessChainOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand ptrInfo;
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
+  Type type;
+  auto loc = parser.getCurrentLocation();
+  SmallVector<Type, 4> indicesTypes;
+
+  if (parser.parseOperand(ptrInfo) ||
+      parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
+      parser.parseColonType(type) ||
+      parser.resolveOperand(ptrInfo, type, result.operands)) {
+    return failure();
+  }
+
+  // Check that the provided indices list is not empty before parsing their
+  // type list.
+  if (indicesInfo.empty()) {
+    return mlir::emitError(result.location,
+                           "'spirv.AccessChain' op expected at "
+                           "least one index ");
+  }
+
+  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
+    return failure();
+
+  // Check that the indices types list is not empty and that it has a one-to-one
+  // mapping to the provided indices.
+  if (indicesTypes.size() != indicesInfo.size()) {
+    return mlir::emitError(
+        result.location, "'spirv.AccessChain' op indices types' count must be "
+                         "equal to indices info count");
+  }
+
+  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
+    return failure();
+
+  auto resultType = getElementPtrType(
+      type, llvm::ArrayRef(result.operands).drop_front(), result.location);
+  if (!resultType) {
+    return failure();
+  }
+
+  result.addTypes(resultType);
+  return success();
+}
+
+template <typename Op>
+static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
+  printer << ' ' << op.getBasePtr() << '[' << indices
+          << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
+}
+
+void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
+  printAccessChain(*this, getIndices(), printer);
+}
+
+template <typename Op>
+static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
+  auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
+                                      indices, accessChainOp.getLoc());
+  if (!resultType)
+    return failure();
+
+  auto providedResultType =
+      llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
+  if (!providedResultType)
+    return accessChainOp.emitOpError(
+               "result type must be a pointer, but provided")
+           << accessChainOp.getType();
+
+  if (resultType != providedResultType)
+    return accessChainOp.emitOpError("invalid result type: expected ")
+           << resultType << ", but provided " << providedResultType;
+
+  return success();
+}
+
+LogicalResult AccessChainOp::verify() {
+  return verifyAccessChain(*this, getIndices());
+}
+
+//===----------------------------------------------------------------------===// +// spirv.LoadOp +//===----------------------------------------------------------------------===// + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr, + MemoryAccessAttr memoryAccess, IntegerAttr alignment) { + auto ptrType = llvm::cast(basePtr.getType()); + build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess, + alignment); +} + +ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { + // Parse the storage class specification + spirv::StorageClass storageClass; + OpAsmParser::UnresolvedOperand ptrInfo; + Type elementType; + if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) || + parseMemoryAccessAttributes(parser, result) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || + parser.parseType(elementType)) { + return failure(); + } + + auto ptrType = spirv::PointerType::get(elementType, storageClass); + if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) { + return failure(); + } + + result.addTypes(elementType); + return success(); +} + +void LoadOp::print(OpAsmPrinter &printer) { + SmallVector elidedAttrs; + StringRef sc = stringifyStorageClass( + llvm::cast(getPtr().getType()).getStorageClass()); + printer << " \"" << sc << "\" " << getPtr(); + + printMemoryAccessAttribute(*this, printer, elidedAttrs); + + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + printer << " : " << getType(); +} + +LogicalResult LoadOp::verify() { + // SPIR-V spec : "Result Type is the type of the loaded object. It must be a + // type with fixed size; i.e., it cannot be, nor include, any + // OpTypeRuntimeArray types." 
+ if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) { + return failure(); + } + return verifyMemoryAccessAttribute(*this); +} + +//===----------------------------------------------------------------------===// +// spirv.StoreOp +//===----------------------------------------------------------------------===// + +ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { + // Parse the storage class specification + spirv::StorageClass storageClass; + SmallVector operandInfo; + auto loc = parser.getCurrentLocation(); + Type elementType; + if (parseEnumStrAttr(storageClass, parser) || + parser.parseOperandList(operandInfo, 2) || + parseMemoryAccessAttributes(parser, result) || parser.parseColon() || + parser.parseType(elementType)) { + return failure(); + } + + auto ptrType = spirv::PointerType::get(elementType, storageClass); + if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, + result.operands)) { + return failure(); + } + return success(); +} + +void StoreOp::print(OpAsmPrinter &printer) { + SmallVector elidedAttrs; + StringRef sc = stringifyStorageClass( + llvm::cast(getPtr().getType()).getStorageClass()); + printer << " \"" << sc << "\" " << getPtr() << ", " << getValue(); + + printMemoryAccessAttribute(*this, printer, elidedAttrs); + + printer << " : " << getValue().getType(); + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); +} + +LogicalResult StoreOp::verify() { + // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an + // OpTypePointer whose Type operand is the same as the type of Object." 
+ if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) + return failure(); + return verifyMemoryAccessAttribute(*this); +} + +//===----------------------------------------------------------------------===// +// spirv.CopyMemory +//===----------------------------------------------------------------------===// + +void CopyMemoryOp::print(OpAsmPrinter &printer) { + printer << ' '; + + StringRef targetStorageClass = stringifyStorageClass( + llvm::cast(getTarget().getType()).getStorageClass()); + printer << " \"" << targetStorageClass << "\" " << getTarget() << ", "; + + StringRef sourceStorageClass = stringifyStorageClass( + llvm::cast(getSource().getType()).getStorageClass()); + printer << " \"" << sourceStorageClass << "\" " << getSource(); + + SmallVector elidedAttrs; + printMemoryAccessAttribute(*this, printer, elidedAttrs); + printSourceMemoryAccessAttribute(*this, printer, elidedAttrs, + getSourceMemoryAccess(), + getSourceAlignment()); + + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + + Type pointeeType = + llvm::cast(getTarget().getType()).getPointeeType(); + printer << " : " << pointeeType; +} + +ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) { + spirv::StorageClass targetStorageClass; + OpAsmParser::UnresolvedOperand targetPtrInfo; + + spirv::StorageClass sourceStorageClass; + OpAsmParser::UnresolvedOperand sourcePtrInfo; + + Type elementType; + + if (parseEnumStrAttr(targetStorageClass, parser) || + parser.parseOperand(targetPtrInfo) || parser.parseComma() || + parseEnumStrAttr(sourceStorageClass, parser) || + parser.parseOperand(sourcePtrInfo) || + parseMemoryAccessAttributes(parser, result)) { + return failure(); + } + + if (!parser.parseOptionalComma()) { + // Parse 2nd memory access attributes. 
+ if (parseSourceMemoryAccessAttributes(parser, result)) { + return failure(); + } + } + + if (parser.parseColon() || parser.parseType(elementType)) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass); + auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass); + + if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) || + parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) { + return failure(); + } + + return success(); +} + +LogicalResult CopyMemoryOp::verify() { + Type targetType = + llvm::cast(getTarget().getType()).getPointeeType(); + + Type sourceType = + llvm::cast(getSource().getType()).getPointeeType(); + + if (targetType != sourceType) + return emitOpError("both operands must be pointers to the same type"); + + if (failed(verifyMemoryAccessAttribute(*this))) + return failure(); + + // TODO - According to the spec: + // + // If two masks are present, the first applies to Target and cannot include + // MakePointerVisible, and the second applies to Source and cannot include + // MakePointerAvailable. + // + // Add such verification here. + + return verifySourceMemoryAccessAttribute(*this); +} + +static ParseResult parsePtrAccessChainOpImpl(StringRef opName, + OpAsmParser &parser, + OperationState &state) { + OpAsmParser::UnresolvedOperand ptrInfo; + SmallVector indicesInfo; + Type type; + auto loc = parser.getCurrentLocation(); + SmallVector indicesTypes; + + if (parser.parseOperand(ptrInfo) || + parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) || + parser.parseColonType(type) || + parser.resolveOperand(ptrInfo, type, state.operands)) + return failure(); + + // Check that the provided indices list is not empty before parsing their + // type list. 
+ if (indicesInfo.empty()) + return emitError(state.location) << opName << " expected element"; + + if (parser.parseComma() || parser.parseTypeList(indicesTypes)) + return failure(); + + // Check that the indices types list is not empty and that it has a one-to-one + // mapping to the provided indices. + if (indicesTypes.size() != indicesInfo.size()) + return emitError(state.location) + << opName + << " indices types' count must be equal to indices info count"; + + if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands)) + return failure(); + + auto resultType = getElementPtrType( + type, llvm::ArrayRef(state.operands).drop_front(2), state.location); + if (!resultType) + return failure(); + + state.addTypes(resultType); + return success(); +} + +template +static auto concatElemAndIndices(Op op) { + SmallVector ret(op.getIndices().size() + 1); + ret[0] = op.getElement(); + llvm::copy(op.getIndices(), ret.begin() + 1); + return ret; +} + +//===----------------------------------------------------------------------===// +// spirv.InBoundsPtrAccessChainOp +//===----------------------------------------------------------------------===// + +void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state, + Value basePtr, Value element, + ValueRange indices) { + auto type = getElementPtrType(basePtr.getType(), indices, state.location); + assert(type && "Unable to deduce return type based on basePtr and indices"); + build(builder, state, type, basePtr, element, indices); +} + +ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser, + OperationState &result) { + return parsePtrAccessChainOpImpl( + spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result); +} + +void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) { + printAccessChain(*this, concatElemAndIndices(*this), printer); +} + +LogicalResult InBoundsPtrAccessChainOp::verify() { + return verifyAccessChain(*this, getIndices()); +} + 
+//===----------------------------------------------------------------------===// +// spirv.PtrAccessChainOp +//===----------------------------------------------------------------------===// + +void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state, + Value basePtr, Value element, ValueRange indices) { + auto type = getElementPtrType(basePtr.getType(), indices, state.location); + assert(type && "Unable to deduce return type based on basePtr and indices"); + build(builder, state, type, basePtr, element, indices); +} + +ParseResult PtrAccessChainOp::parse(OpAsmParser &parser, + OperationState &result) { + return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(), + parser, result); +} + +void PtrAccessChainOp::print(OpAsmPrinter &printer) { + printAccessChain(*this, concatElemAndIndices(*this), printer); +} + +LogicalResult PtrAccessChainOp::verify() { + return verifyAccessChain(*this, getIndices()); +} + +//===----------------------------------------------------------------------===// +// spirv.Variable +//===----------------------------------------------------------------------===// + +ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) { + // Parse optional initializer + std::optional initInfo; + if (succeeded(parser.parseOptionalKeyword("init"))) { + initInfo = OpAsmParser::UnresolvedOperand(); + if (parser.parseLParen() || parser.parseOperand(*initInfo) || + parser.parseRParen()) + return failure(); + } + + if (parseVariableDecorations(parser, result)) { + return failure(); + } + + // Parse result pointer type + Type type; + if (parser.parseColon()) + return failure(); + auto loc = parser.getCurrentLocation(); + if (parser.parseType(type)) + return failure(); + + auto ptrType = llvm::dyn_cast(type); + if (!ptrType) + return parser.emitError(loc, "expected spirv.ptr type"); + result.addTypes(ptrType); + + // Resolve the initializer operand + if (initInfo) { + if (parser.resolveOperand(*initInfo, 
ptrType.getPointeeType(), + result.operands)) + return failure(); + } + + auto attr = parser.getBuilder().getAttr( + ptrType.getStorageClass()); + result.addAttribute(spirv::attributeName(), attr); + + return success(); +} + +void VariableOp::print(OpAsmPrinter &printer) { + SmallVector elidedAttrs{ + spirv::attributeName()}; + // Print optional initializer + if (getNumOperands() != 0) + printer << " init(" << getInitializer() << ")"; + + printVariableDecorations(*this, printer, elidedAttrs); + printer << " : " << getType(); +} + +LogicalResult VariableOp::verify() { + // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the + // object. It cannot be Generic. It must be the same as the Storage Class + // operand of the Result Type." + if (getStorageClass() != spirv::StorageClass::Function) { + return emitOpError( + "can only be used to model function-level variables. Use " + "spirv.GlobalVariable for module-level variables."); + } + + auto pointerType = llvm::cast(getPointer().getType()); + if (getStorageClass() != pointerType.getStorageClass()) + return emitOpError( + "storage class must match result pointer's storage class"); + + if (getNumOperands() != 0) { + // SPIR-V spec: "Initializer must be an from a constant instruction or + // a global (module scope) OpVariable instruction". + auto *initOp = getOperand(0).getDefiningOp(); + if (!initOp || !isa(initOp)) + return emitOpError("initializer must be the result of a " + "constant or spirv.GlobalVariable op"); + } + + // TODO: generate these strings using ODS. 
+ auto *op = getOperation(); + auto descriptorSetName = llvm::convertToSnakeFromCamelCase( + stringifyDecoration(spirv::Decoration::DescriptorSet)); + auto bindingName = llvm::convertToSnakeFromCamelCase( + stringifyDecoration(spirv::Decoration::Binding)); + auto builtInName = llvm::convertToSnakeFromCamelCase( + stringifyDecoration(spirv::Decoration::BuiltIn)); + + for (const auto &attr : {descriptorSetName, bindingName, builtInName}) { + if (op->getAttr(attr)) + return emitOpError("cannot have '") + << attr << "' attribute (only allowed in spirv.GlobalVariable)"; + } + + return success(); +} + +} // namespace mlir::spirv diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpAvailability.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpAvailability.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpAvailability.cpp @@ -0,0 +1,22 @@ +//===- SPIRVOpAvailability.cpp - MLIR SPIR-V Availability Implementation --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the SPIR-V operation availability in the SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" + +// TableGen'erated operation interfaces for querying versions, extensions, and +// capabilities. +#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc" + +namespace mlir::spirv { +// TableGen'erated operation availability interface implementations. 
+#include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc" +} // namespace mlir::spirv diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp @@ -0,0 +1,76 @@ +//===- SPIRVOpDefinition.cpp - MLIR SPIR-V Op Definition Implementation ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the TableGen'erated SPIR-V op implementation in the SPIR-V dialect. +// These are placed in a separate file to reduce the total amount of code in +// SPIRVOps.cpp and make that file faster to recompile. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" + +#include "SPIRVParsingUtils.h" + +#include "mlir/IR/TypeUtilities.h" + +namespace mlir::spirv { +/// Returns true if the given op is a function-like op or nested in a +/// function-like op without a module-like op in the middle. +static bool isNestedInFunctionOpInterface(Operation *op) { + if (!op) + return false; + if (op->hasTrait()) + return false; + if (isa(op)) + return true; + return isNestedInFunctionOpInterface(op->getParentOp()); +} + +/// Returns true if the given op is an module-like op that maintains a symbol +/// table. +static bool isDirectInModuleLikeOp(Operation *op) { + return op && op->hasTrait(); +} + +/// Result of a logical op must be a scalar or vector of boolean type. 
+static Type getUnaryOpResultType(Type operandType) { + Builder builder(operandType.getContext()); + Type resultType = builder.getIntegerType(1); + if (auto vecType = llvm::dyn_cast(operandType)) + return VectorType::get(vecType.getNumElements(), resultType); + return resultType; +} + +static ParseResult parseImageOperands(OpAsmParser &parser, + spirv::ImageOperandsAttr &attr) { + // Expect image operands + if (parser.parseOptionalLSquare()) + return success(); + + spirv::ImageOperands imageOperands; + if (parseEnumStrAttr(imageOperands, parser)) + return failure(); + + attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands); + + return parser.parseRSquare(); +} + +static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp, + spirv::ImageOperandsAttr attr) { + if (attr) { + auto strImageOperands = stringifyImageOperands(attr.getValue()); + printer << "[\"" << strImageOperands << "\"]"; + } +} + +} // namespace mlir::spirv + +// TablenGen'erated operation definitions. 
+#define GET_OP_CLASSES +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc" diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h @@ -29,6 +29,9 @@ llvm_unreachable("unhandled bit width computation for type"); } +void printVariableDecorations(Operation *op, OpAsmPrinter &printer, + SmallVectorImpl &elidedAttrs); + LogicalResult extractValueFromConstOp(Operation *op, int32_t &value); LogicalResult verifyMemorySemantics(Operation *op, diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -28,7 +28,6 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" #include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/CallInterfaces.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" @@ -45,6 +44,77 @@ // Common utility functions //===----------------------------------------------------------------------===// +LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) { + auto constOp = dyn_cast_or_null(op); + if (!constOp) { + return failure(); + } + auto valueAttr = constOp.getValue(); + auto integerValueAttr = llvm::dyn_cast(valueAttr); + if (!integerValueAttr) { + return failure(); + } + + if (integerValueAttr.getType().isSignlessInteger()) + value = integerValueAttr.getInt(); + else + value = integerValueAttr.getSInt(); + + return success(); +} + +LogicalResult +spirv::verifyMemorySemantics(Operation *op, + spirv::MemorySemantics memorySemantics) { + // According to the SPIR-V specification: + // "Despite being a mask and allowing multiple bits to be combined, it is + // invalid for more than one of these four bits to be set: Acquire, Release, + // AcquireRelease, or SequentiallyConsistent. 
Requesting both Acquire and + // Release semantics is done by setting the AcquireRelease bit, not by setting + // two bits." + auto atMostOneInSet = spirv::MemorySemantics::Acquire | + spirv::MemorySemantics::Release | + spirv::MemorySemantics::AcquireRelease | + spirv::MemorySemantics::SequentiallyConsistent; + + auto bitCount = + llvm::popcount(static_cast(memorySemantics & atMostOneInSet)); + if (bitCount > 1) { + return op->emitError( + "expected at most one of these four memory constraints " + "to be set: `Acquire`, `Release`," + "`AcquireRelease` or `SequentiallyConsistent`"); + } + return success(); +} + +void spirv::printVariableDecorations(Operation *op, OpAsmPrinter &printer, + SmallVectorImpl &elidedAttrs) { + // Print optional descriptor binding + auto descriptorSetName = llvm::convertToSnakeFromCamelCase( + stringifyDecoration(spirv::Decoration::DescriptorSet)); + auto bindingName = llvm::convertToSnakeFromCamelCase( + stringifyDecoration(spirv::Decoration::Binding)); + auto descriptorSet = op->getAttrOfType(descriptorSetName); + auto binding = op->getAttrOfType(bindingName); + if (descriptorSet && binding) { + elidedAttrs.push_back(descriptorSetName); + elidedAttrs.push_back(bindingName); + printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt() + << ")"; + } + + // Print BuiltIn attribute if present + auto builtInName = llvm::convertToSnakeFromCamelCase( + stringifyDecoration(spirv::Decoration::BuiltIn)); + if (auto builtin = op->getAttrOfType(builtInName)) { + printer << " " << builtInName << "(\"" << builtin.getValue() << "\")"; + elidedAttrs.push_back(builtInName); + } + + printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); +} + static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result) { SmallVector ops; @@ -93,177 +163,6 @@ p << " : " << resultType; } -/// Returns true if the given op is a function-like op or nested in a -/// function-like op without a module-like op in the middle. 
-static bool isNestedInFunctionOpInterface(Operation *op) { - if (!op) - return false; - if (op->hasTrait()) - return false; - if (isa(op)) - return true; - return isNestedInFunctionOpInterface(op->getParentOp()); -} - -/// Returns true if the given op is an module-like op that maintains a symbol -/// table. -static bool isDirectInModuleLikeOp(Operation *op) { - return op && op->hasTrait(); -} - -LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) { - auto constOp = dyn_cast_or_null(op); - if (!constOp) { - return failure(); - } - auto valueAttr = constOp.getValue(); - auto integerValueAttr = llvm::dyn_cast(valueAttr); - if (!integerValueAttr) { - return failure(); - } - - if (integerValueAttr.getType().isSignlessInteger()) - value = integerValueAttr.getInt(); - else - value = integerValueAttr.getSInt(); - - return success(); -} - -/// Parses Function, Selection and Loop control attributes. If no control is -/// specified, "None" is used as a default. -template -static ParseResult -parseControlAttribute(OpAsmParser &parser, OperationState &state, - StringRef attrName = spirv::attributeName()) { - if (succeeded(parser.parseOptionalKeyword(kControl))) { - EnumClass control; - if (parser.parseLParen() || - spirv::parseEnumKeywordAttr(control, parser, state) || - parser.parseRParen()) - return failure(); - return success(); - } - // Set control to "None" otherwise. - Builder builder = parser.getBuilder(); - state.addAttribute(attrName, - builder.getAttr(static_cast(0))); - return success(); -} - -// TODO Make sure to merge this and the previous function into one template -// parameterized by memory access attribute name and alignment. Doing so now -// results in VS2017 in producing an internal error (at the call site) that's -// not detailed enough to understand what is happening. 
-static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, - OperationState &state) { - // Parse an optional list of attributes staring with '[' - if (parser.parseOptionalLSquare()) { - // Nothing to do - return success(); - } - - spirv::MemoryAccess memoryAccessAttr; - if (spirv::parseEnumStrAttr( - memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName)) - return failure(); - - if (spirv::bitEnumContainsAll(memoryAccessAttr, - spirv::MemoryAccess::Aligned)) { - // Parse integer attribute for alignment. - Attribute alignmentAttr; - Type i32Type = parser.getBuilder().getIntegerType(32); - if (parser.parseComma() || - parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName, - state.attributes)) { - return failure(); - } - } - return parser.parseRSquare(); -} - -template -static void printMemoryAccessAttribute( - MemoryOpTy memoryOp, OpAsmPrinter &printer, - SmallVectorImpl &elidedAttrs, - std::optional memoryAccessAtrrValue = std::nullopt, - std::optional alignmentAttrValue = std::nullopt) { - // Print optional memory access attribute. - if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue - : memoryOp.getMemoryAccess())) { - elidedAttrs.push_back(kMemoryAccessAttrName); - - printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; - - if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) { - // Print integer alignment attribute. - if (auto alignment = (alignmentAttrValue ? alignmentAttrValue - : memoryOp.getAlignment())) { - elidedAttrs.push_back(kAlignmentAttrName); - printer << ", " << *alignment; - } - } - printer << "]"; - } - elidedAttrs.push_back(spirv::attributeName()); -} - -// TODO Make sure to merge this and the previous function into one template -// parameterized by memory access attribute name and alignment. Doing so now -// results in VS2017 in producing an internal error (at the call site) that's -// not detailed enough to understand what is happening. 
-template -static void printSourceMemoryAccessAttribute( - MemoryOpTy memoryOp, OpAsmPrinter &printer, - SmallVectorImpl &elidedAttrs, - std::optional memoryAccessAtrrValue = std::nullopt, - std::optional alignmentAttrValue = std::nullopt) { - - printer << ", "; - - // Print optional memory access attribute. - if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue - : memoryOp.getMemoryAccess())) { - elidedAttrs.push_back(kSourceMemoryAccessAttrName); - - printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; - - if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) { - // Print integer alignment attribute. - if (auto alignment = (alignmentAttrValue ? alignmentAttrValue - : memoryOp.getAlignment())) { - elidedAttrs.push_back(kSourceAlignmentAttrName); - printer << ", " << *alignment; - } - } - printer << "]"; - } - elidedAttrs.push_back(spirv::attributeName()); -} - -static ParseResult parseImageOperands(OpAsmParser &parser, - spirv::ImageOperandsAttr &attr) { - // Expect image operands - if (parser.parseOptionalLSquare()) - return success(); - - spirv::ImageOperands imageOperands; - if (parseEnumStrAttr(imageOperands, parser)) - return failure(); - - attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands); - - return parser.parseRSquare(); -} - -static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp, - spirv::ImageOperandsAttr attr) { - if (attr) { - auto strImageOperands = stringifyImageOperands(attr.getValue()); - printer << "[\"" << strImageOperands << "\"]"; - } -} - template static LogicalResult verifyImageOperands(Op imageOp, spirv::ImageOperandsAttr attr, @@ -292,130 +191,6 @@ return success(); } -template -static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) { - // ODS checks for attributes values. Just need to verify that if the - // memory-access attribute is Aligned, then the alignment attribute must be - // present. 
- auto *op = memoryOp.getOperation(); - auto memAccessAttr = op->getAttr(kMemoryAccessAttrName); - if (!memAccessAttr) { - // Alignment attribute shouldn't be present if memory access attribute is - // not present. - if (op->getAttr(kAlignmentAttrName)) { - return memoryOp.emitOpError( - "invalid alignment specification without aligned memory access " - "specification"); - } - return success(); - } - - auto memAccess = llvm::cast(memAccessAttr); - - if (!memAccess) { - return memoryOp.emitOpError("invalid memory access specifier: ") - << memAccessAttr; - } - - if (spirv::bitEnumContainsAll(memAccess.getValue(), - spirv::MemoryAccess::Aligned)) { - if (!op->getAttr(kAlignmentAttrName)) { - return memoryOp.emitOpError("missing alignment value"); - } - } else { - if (op->getAttr(kAlignmentAttrName)) { - return memoryOp.emitOpError( - "invalid alignment specification with non-aligned memory access " - "specification"); - } - } - return success(); -} - -// TODO Make sure to merge this and the previous function into one template -// parameterized by memory access attribute name and alignment. Doing so now -// results in VS2017 in producing an internal error (at the call site) that's -// not detailed enough to understand what is happening. -template -static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) { - // ODS checks for attributes values. Just need to verify that if the - // memory-access attribute is Aligned, then the alignment attribute must be - // present. - auto *op = memoryOp.getOperation(); - auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName); - if (!memAccessAttr) { - // Alignment attribute shouldn't be present if memory access attribute is - // not present. 
- if (op->getAttr(kSourceAlignmentAttrName)) { - return memoryOp.emitOpError( - "invalid alignment specification without aligned memory access " - "specification"); - } - return success(); - } - - auto memAccess = llvm::cast(memAccessAttr); - - if (!memAccess) { - return memoryOp.emitOpError("invalid memory access specifier: ") - << memAccess; - } - - if (spirv::bitEnumContainsAll(memAccess.getValue(), - spirv::MemoryAccess::Aligned)) { - if (!op->getAttr(kSourceAlignmentAttrName)) { - return memoryOp.emitOpError("missing alignment value"); - } - } else { - if (op->getAttr(kSourceAlignmentAttrName)) { - return memoryOp.emitOpError( - "invalid alignment specification with non-aligned memory access " - "specification"); - } - } - return success(); -} - -LogicalResult -spirv::verifyMemorySemantics(Operation *op, - spirv::MemorySemantics memorySemantics) { - // According to the SPIR-V specification: - // "Despite being a mask and allowing multiple bits to be combined, it is - // invalid for more than one of these four bits to be set: Acquire, Release, - // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and - // Release semantics is done by setting the AcquireRelease bit, not by setting - // two bits." - auto atMostOneInSet = spirv::MemorySemantics::Acquire | - spirv::MemorySemantics::Release | - spirv::MemorySemantics::AcquireRelease | - spirv::MemorySemantics::SequentiallyConsistent; - - auto bitCount = - llvm::popcount(static_cast(memorySemantics & atMostOneInSet)); - if (bitCount > 1) { - return op->emitError( - "expected at most one of these four memory constraints " - "to be set: `Acquire`, `Release`," - "`AcquireRelease` or `SequentiallyConsistent`"); - } - return success(); -} - -template -static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, - Value val) { - // ODS already checks ptr is spirv::PointerType. 
Just check that the pointee - // type of the pointer and the type of the value are the same - // - // TODO: Check that the value type satisfies restrictions of - // SPIR-V OpLoad/OpStore operations - if (val.getType() != - llvm::cast(ptr.getType()).getPointeeType()) { - return op.emitOpError("mismatch in result type and pointer type"); - } - return success(); -} - template static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val) { @@ -430,70 +205,6 @@ return success(); } -static ParseResult parseVariableDecorations(OpAsmParser &parser, - OperationState &state) { - auto builtInName = llvm::convertToSnakeFromCamelCase( - stringifyDecoration(spirv::Decoration::BuiltIn)); - if (succeeded(parser.parseOptionalKeyword("bind"))) { - Attribute set, binding; - // Parse optional descriptor binding - auto descriptorSetName = llvm::convertToSnakeFromCamelCase( - stringifyDecoration(spirv::Decoration::DescriptorSet)); - auto bindingName = llvm::convertToSnakeFromCamelCase( - stringifyDecoration(spirv::Decoration::Binding)); - Type i32Type = parser.getBuilder().getIntegerType(32); - if (parser.parseLParen() || - parser.parseAttribute(set, i32Type, descriptorSetName, - state.attributes) || - parser.parseComma() || - parser.parseAttribute(binding, i32Type, bindingName, - state.attributes) || - parser.parseRParen()) { - return failure(); - } - } else if (succeeded(parser.parseOptionalKeyword(builtInName))) { - StringAttr builtIn; - if (parser.parseLParen() || - parser.parseAttribute(builtIn, builtInName, state.attributes) || - parser.parseRParen()) { - return failure(); - } - } - - // Parse other attributes - if (parser.parseOptionalAttrDict(state.attributes)) - return failure(); - - return success(); -} - -static void printVariableDecorations(Operation *op, OpAsmPrinter &printer, - SmallVectorImpl &elidedAttrs) { - // Print optional descriptor binding - auto descriptorSetName = llvm::convertToSnakeFromCamelCase( - 
stringifyDecoration(spirv::Decoration::DescriptorSet)); - auto bindingName = llvm::convertToSnakeFromCamelCase( - stringifyDecoration(spirv::Decoration::Binding)); - auto descriptorSet = op->getAttrOfType(descriptorSetName); - auto binding = op->getAttrOfType(bindingName); - if (descriptorSet && binding) { - elidedAttrs.push_back(descriptorSetName); - elidedAttrs.push_back(bindingName); - printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt() - << ")"; - } - - // Print BuiltIn attribute if present - auto builtInName = llvm::convertToSnakeFromCamelCase( - stringifyDecoration(spirv::Decoration::BuiltIn)); - if (auto builtin = op->getAttrOfType(builtInName)) { - printer << " " << builtInName << "(\"" << builtin.getValue() << "\")"; - elidedAttrs.push_back(builtInName); - } - - printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); -} - /// Walks the given type hierarchy with the given indices, potentially down /// to component granularity, to select an element type. Returns null type and /// emits errors with the given loc on failure. @@ -564,12 +275,6 @@ return getElementType(type, indices, errorFn); } -/// Returns true if the given `block` only contains one `spirv.mlir.merge` op. -static inline bool isMergeBlock(Block &block) { - return !block.empty() && std::next(block.begin()) == block.end() && - isa(block.front()); -} - template static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) { auto resultType = llvm::cast(op.getType()); @@ -617,15 +322,6 @@ printer << " : " << op->getResultTypes().front(); } -/// Result of a logical op must be a scalar or vector of boolean type. 
-static Type getUnaryOpResultType(Type operandType) { - Builder builder(operandType.getContext()); - Type resultType = builder.getIntegerType(1); - if (auto vecType = llvm::dyn_cast(operandType)) - return VectorType::get(vecType.getNumElements(), resultType); - return resultType; -} - static LogicalResult verifyShiftOp(Operation *op) { if (op->getOperand(0).getType() != op->getResult(0).getType()) { return op->emitError("expected the same type for the first operand and " @@ -636,152 +332,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// spirv.AccessChainOp -//===----------------------------------------------------------------------===// - -static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { - auto ptrType = llvm::dyn_cast(type); - if (!ptrType) { - emitError(baseLoc, "'spirv.AccessChain' op expected a pointer " - "to composite type, but provided ") - << type; - return nullptr; - } - - auto resultType = ptrType.getPointeeType(); - auto resultStorageClass = ptrType.getStorageClass(); - int32_t index = 0; - - for (auto indexSSA : indices) { - auto cType = llvm::dyn_cast(resultType); - if (!cType) { - emitError( - baseLoc, - "'spirv.AccessChain' op cannot extract from non-composite type ") - << resultType << " with index " << index; - return nullptr; - } - index = 0; - if (llvm::isa(resultType)) { - Operation *op = indexSSA.getDefiningOp(); - if (!op) { - emitError(baseLoc, "'spirv.AccessChain' op index must be an " - "integer spirv.Constant to access " - "element of spirv.struct"); - return nullptr; - } - - // TODO: this should be relaxed to allow - // integer literals of other bitwidths. 
- if (failed(spirv::extractValueFromConstOp(op, index))) { - emitError( - baseLoc, - "'spirv.AccessChain' index must be an integer spirv.Constant to " - "access element of spirv.struct, but provided ") - << op->getName(); - return nullptr; - } - if (index < 0 || static_cast(index) >= cType.getNumElements()) { - emitError(baseLoc, "'spirv.AccessChain' op index ") - << index << " out of bounds for " << resultType; - return nullptr; - } - } - resultType = cType.getElementType(index); - } - return spirv::PointerType::get(resultType, resultStorageClass); -} - -void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state, - Value basePtr, ValueRange indices) { - auto type = getElementPtrType(basePtr.getType(), indices, state.location); - assert(type && "Unable to deduce return type based on basePtr and indices"); - build(builder, state, type, basePtr, indices); -} - -ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand ptrInfo; - SmallVector indicesInfo; - Type type; - auto loc = parser.getCurrentLocation(); - SmallVector indicesTypes; - - if (parser.parseOperand(ptrInfo) || - parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) || - parser.parseColonType(type) || - parser.resolveOperand(ptrInfo, type, result.operands)) { - return failure(); - } - - // Check that the provided indices list is not empty before parsing their - // type list. - if (indicesInfo.empty()) { - return mlir::emitError(result.location, - "'spirv.AccessChain' op expected at " - "least one index "); - } - - if (parser.parseComma() || parser.parseTypeList(indicesTypes)) - return failure(); - - // Check that the indices types list is not empty and that it has a one-to-one - // mapping to the provided indices. 
- if (indicesTypes.size() != indicesInfo.size()) { - return mlir::emitError( - result.location, "'spirv.AccessChain' op indices types' count must be " - "equal to indices info count"); - } - - if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands)) - return failure(); - - auto resultType = getElementPtrType( - type, llvm::ArrayRef(result.operands).drop_front(), result.location); - if (!resultType) { - return failure(); - } - - result.addTypes(resultType); - return success(); -} - -template -static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) { - printer << ' ' << op.getBasePtr() << '[' << indices - << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes(); -} - -void spirv::AccessChainOp::print(OpAsmPrinter &printer) { - printAccessChain(*this, getIndices(), printer); -} - -template -static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) { - auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(), - indices, accessChainOp.getLoc()); - if (!resultType) - return failure(); - - auto providedResultType = - llvm::dyn_cast(accessChainOp.getType()); - if (!providedResultType) - return accessChainOp.emitOpError( - "result type must be a pointer, but provided") - << providedResultType; - - if (resultType != providedResultType) - return accessChainOp.emitOpError("invalid result type: expected ") - << resultType << ", but provided " << providedResultType; - - return success(); -} - -LogicalResult spirv::AccessChainOp::verify() { - return verifyAccessChain(*this, getIndices()); -} - //===----------------------------------------------------------------------===// // spirv.mlir.addressof //===----------------------------------------------------------------------===// @@ -805,109 +355,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// spirv.BranchOp 
-//===----------------------------------------------------------------------===// - -SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) { - assert(index == 0 && "invalid successor index"); - return SuccessorOperands(0, getTargetOperandsMutable()); -} - -//===----------------------------------------------------------------------===// -// spirv.BranchConditionalOp -//===----------------------------------------------------------------------===// - -SuccessorOperands -spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) { - assert(index < 2 && "invalid successor index"); - return SuccessorOperands(index == kTrueIndex - ? getTrueTargetOperandsMutable() - : getFalseTargetOperandsMutable()); -} - -ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser, - OperationState &result) { - auto &builder = parser.getBuilder(); - OpAsmParser::UnresolvedOperand condInfo; - Block *dest; - - // Parse the condition. - Type boolTy = builder.getI1Type(); - if (parser.parseOperand(condInfo) || - parser.resolveOperand(condInfo, boolTy, result.operands)) - return failure(); - - // Parse the optional branch weights. - if (succeeded(parser.parseOptionalLSquare())) { - IntegerAttr trueWeight, falseWeight; - NamedAttrList weights; - - auto i32Type = builder.getIntegerType(32); - if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) || - parser.parseComma() || - parser.parseAttribute(falseWeight, i32Type, "weight", weights) || - parser.parseRSquare()) - return failure(); - - result.addAttribute(kBranchWeightAttrName, - builder.getArrayAttr({trueWeight, falseWeight})); - } - - // Parse the true branch. - SmallVector trueOperands; - if (parser.parseComma() || - parser.parseSuccessorAndUseList(dest, trueOperands)) - return failure(); - result.addSuccessors(dest); - result.addOperands(trueOperands); - - // Parse the false branch. 
- SmallVector falseOperands; - if (parser.parseComma() || - parser.parseSuccessorAndUseList(dest, falseOperands)) - return failure(); - result.addSuccessors(dest); - result.addOperands(falseOperands); - result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr( - {1, static_cast(trueOperands.size()), - static_cast(falseOperands.size())})); - - return success(); -} - -void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) { - printer << ' ' << getCondition(); - - if (auto weights = getBranchWeights()) { - printer << " ["; - llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) { - printer << llvm::cast(a).getInt(); - }); - printer << "]"; - } - - printer << ", "; - printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments()); - printer << ", "; - printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments()); -} - -LogicalResult spirv::BranchConditionalOp::verify() { - if (auto weights = getBranchWeights()) { - if (weights->getValue().size() != 2) { - return emitOpError("must have exactly two branch weights"); - } - if (llvm::all_of(*weights, [](Attribute attr) { - return llvm::cast(attr).getValue().isZero(); - })) - return emitOpError("branch weights cannot both be zero"); - } - - return success(); -} - //===----------------------------------------------------------------------===// // spirv.CompositeConstruct //===----------------------------------------------------------------------===// @@ -1584,72 +1031,6 @@ return getResAttrs().value_or(nullptr); } -//===----------------------------------------------------------------------===// -// spirv.FunctionCall -//===----------------------------------------------------------------------===// - -LogicalResult spirv::FunctionCallOp::verify() { - auto fnName = getCalleeAttr(); - - auto funcOp = dyn_cast_or_null( - SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName)); - if (!funcOp) { - return 
emitOpError("callee function '") - << fnName.getValue() << "' not found in nearest symbol table"; - } - - auto functionType = funcOp.getFunctionType(); - - if (getNumResults() > 1) { - return emitOpError( - "expected callee function to have 0 or 1 result, but provided ") - << getNumResults(); - } - - if (functionType.getNumInputs() != getNumOperands()) { - return emitOpError("has incorrect number of operands for callee: expected ") - << functionType.getNumInputs() << ", but provided " - << getNumOperands(); - } - - for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { - if (getOperand(i).getType() != functionType.getInput(i)) { - return emitOpError("operand type mismatch: expected operand type ") - << functionType.getInput(i) << ", but provided " - << getOperand(i).getType() << " for operand number " << i; - } - } - - if (functionType.getNumResults() != getNumResults()) { - return emitOpError( - "has incorrect number of results has for callee: expected ") - << functionType.getNumResults() << ", but provided " - << getNumResults(); - } - - if (getNumResults() && - (getResult(0).getType() != functionType.getResult(0))) { - return emitOpError("result type mismatch: expected ") - << functionType.getResult(0) << ", but provided " - << getResult(0).getType(); - } - - return success(); -} - -CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() { - return (*this)->getAttrOfType(kCallee); -} - -void spirv::FunctionCallOp::setCalleeFromCallable( - CallInterfaceCallable callee) { - (*this)->setAttr(kCallee, callee.get()); -} - -Operation::operand_range spirv::FunctionCallOp::getArgOperands() { - return getArguments(); -} - //===----------------------------------------------------------------------===// // spirv.GLFClampOp //===----------------------------------------------------------------------===// @@ -1768,7 +1149,7 @@ } elidedAttrs.push_back(kTypeAttrName); - printVariableDecorations(*this, printer, elidedAttrs); + 
spirv::printVariableDecorations(*this, printer, elidedAttrs); printer << " : " << getType(); } @@ -1933,232 +1314,21 @@ ::printArithmeticExtendedBinaryOp(*this, printer); } -//===----------------------------------------------------------------------===// -// spirv.UMulExtended -//===----------------------------------------------------------------------===// - -LogicalResult spirv::UMulExtendedOp::verify() { - return ::verifyArithmeticExtendedBinaryOp(*this); -} - -ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseArithmeticExtendedBinaryOp(parser, result); -} - -void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) { - ::printArithmeticExtendedBinaryOp(*this, printer); -} - -//===----------------------------------------------------------------------===// -// spirv.LoadOp -//===----------------------------------------------------------------------===// - -void spirv::LoadOp::build(OpBuilder &builder, OperationState &state, - Value basePtr, MemoryAccessAttr memoryAccess, - IntegerAttr alignment) { - auto ptrType = llvm::cast(basePtr.getType()); - build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess, - alignment); -} - -ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &result) { - // Parse the storage class specification - spirv::StorageClass storageClass; - OpAsmParser::UnresolvedOperand ptrInfo; - Type elementType; - if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) || - parseMemoryAccessAttributes(parser, result) || - parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || - parser.parseType(elementType)) { - return failure(); - } - - auto ptrType = spirv::PointerType::get(elementType, storageClass); - if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) { - return failure(); - } - - result.addTypes(elementType); - return success(); -} - -void spirv::LoadOp::print(OpAsmPrinter &printer) { - SmallVector elidedAttrs; - 
StringRef sc = stringifyStorageClass( - llvm::cast(getPtr().getType()).getStorageClass()); - printer << " \"" << sc << "\" " << getPtr(); - - printMemoryAccessAttribute(*this, printer, elidedAttrs); - - printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); - printer << " : " << getType(); -} - -LogicalResult spirv::LoadOp::verify() { - // SPIR-V spec : "Result Type is the type of the loaded object. It must be a - // type with fixed size; i.e., it cannot be, nor include, any - // OpTypeRuntimeArray types." - if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) { - return failure(); - } - return verifyMemoryAccessAttribute(*this); -} - -//===----------------------------------------------------------------------===// -// spirv.mlir.loop -//===----------------------------------------------------------------------===// - -void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) { - state.addAttribute("loop_control", builder.getAttr( - spirv::LoopControl::None)); - state.addRegion(); -} - -ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &result) { - if (parseControlAttribute(parser, - result)) - return failure(); - return parser.parseRegion(*result.addRegion(), /*arguments=*/{}); -} - -void spirv::LoopOp::print(OpAsmPrinter &printer) { - auto control = getLoopControl(); - if (control != spirv::LoopControl::None) - printer << " control(" << spirv::stringifyLoopControl(control) << ")"; - printer << ' '; - printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/true); -} - -/// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the -/// given `dstBlock`. -static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) { - // Check that there is only one op in the `srcBlock`. 
- if (!llvm::hasSingleElement(srcBlock)) - return false; - - auto branchOp = dyn_cast(srcBlock.back()); - return branchOp && branchOp.getSuccessor() == &dstBlock; -} - -LogicalResult spirv::LoopOp::verifyRegions() { - auto *op = getOperation(); - - // We need to verify that the blocks follow the following layout: - // - // +-------------+ - // | entry block | - // +-------------+ - // | - // v - // +-------------+ - // | loop header | <-----+ - // +-------------+ | - // | - // ... | - // \ | / | - // v | - // +---------------+ | - // | loop continue | -----+ - // +---------------+ - // - // ... - // \ | / - // v - // +-------------+ - // | merge block | - // +-------------+ - - auto ®ion = op->getRegion(0); - // Allow empty region as a degenerated case, which can come from - // optimizations. - if (region.empty()) - return success(); - - // The last block is the merge block. - Block &merge = region.back(); - if (!isMergeBlock(merge)) - return emitOpError("last block must be the merge block with only one " - "'spirv.mlir.merge' op"); - - if (std::next(region.begin()) == region.end()) - return emitOpError( - "must have an entry block branching to the loop header block"); - // The first block is the entry block. - Block &entry = region.front(); - - if (std::next(region.begin(), 2) == region.end()) - return emitOpError( - "must have a loop header block branched from the entry block"); - // The second block is the loop header block. - Block &header = *std::next(region.begin(), 1); - - if (!hasOneBranchOpTo(entry, header)) - return emitOpError( - "entry block must only have one 'spirv.Branch' op to the second block"); - - if (std::next(region.begin(), 3) == region.end()) - return emitOpError( - "requires a loop continue block branching to the loop header block"); - // The second to last block is the loop continue block. - Block &cont = *std::prev(region.end(), 2); - - // Make sure that we have a branch from the loop continue block to the loop - // header block. 
- if (llvm::none_of( - llvm::seq(0, cont.getNumSuccessors()), - [&](unsigned index) { return cont.getSuccessor(index) == &header; })) - return emitOpError("second to last block must be the loop continue " - "block that branches to the loop header block"); - - // Make sure that no other blocks (except the entry and loop continue block) - // branches to the loop header block. - for (auto &block : llvm::make_range(std::next(region.begin(), 2), - std::prev(region.end(), 2))) { - for (auto i : llvm::seq(0, block.getNumSuccessors())) { - if (block.getSuccessor(i) == &header) { - return emitOpError("can only have the entry and loop continue " - "block branching to the loop header block"); - } - } - } - - return success(); -} - -Block *spirv::LoopOp::getEntryBlock() { - assert(!getBody().empty() && "op region should not be empty!"); - return &getBody().front(); -} - -Block *spirv::LoopOp::getHeaderBlock() { - assert(!getBody().empty() && "op region should not be empty!"); - // The second block is the loop header block. - return &*std::next(getBody().begin()); -} +//===----------------------------------------------------------------------===// +// spirv.UMulExtended +//===----------------------------------------------------------------------===// -Block *spirv::LoopOp::getContinueBlock() { - assert(!getBody().empty() && "op region should not be empty!"); - // The second to last block is the loop continue block. - return &*std::prev(getBody().end(), 2); +LogicalResult spirv::UMulExtendedOp::verify() { + return ::verifyArithmeticExtendedBinaryOp(*this); } -Block *spirv::LoopOp::getMergeBlock() { - assert(!getBody().empty() && "op region should not be empty!"); - // The last block is the loop merge block. 
- return &getBody().back(); +ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseArithmeticExtendedBinaryOp(parser, result); } -void spirv::LoopOp::addEntryAndMergeBlock() { - assert(getBody().empty() && "entry and merge block already exist"); - getBody().push_back(new Block()); - auto *mergeBlock = new Block(); - getBody().push_back(mergeBlock); - OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock); - - // Add a spirv.mlir.merge op into the merge block. - builder.create(getLoc()); +void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) { + ::printArithmeticExtendedBinaryOp(*this, printer); } //===----------------------------------------------------------------------===// @@ -2169,24 +1339,6 @@ return verifyMemorySemantics(getOperation(), getMemorySemantics()); } -//===----------------------------------------------------------------------===// -// spirv.mlir.merge -//===----------------------------------------------------------------------===// - -LogicalResult spirv::MergeOp::verify() { - auto *parentOp = (*this)->getParentOp(); - if (!parentOp || !isa(parentOp)) - return emitOpError( - "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'"); - - // TODO: This check should be done in `verifyRegions` of parent op. 
- Block &parentLastBlock = (*this)->getParentRegion()->back(); - if (getOperation() != parentLastBlock.getTerminator()) - return emitOpError("can only be used in the last block of " - "'spirv.mlir.selection' or 'spirv.mlir.loop'"); - return success(); -} - //===----------------------------------------------------------------------===// // spirv.module //===----------------------------------------------------------------------===// @@ -2382,158 +1534,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// spirv.Return -//===----------------------------------------------------------------------===// - -LogicalResult spirv::ReturnOp::verify() { - // Verification is performed in spirv.func op. - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.ReturnValue -//===----------------------------------------------------------------------===// - -LogicalResult spirv::ReturnValueOp::verify() { - // Verification is performed in spirv.func op. 
- return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.Select -//===----------------------------------------------------------------------===// - -LogicalResult spirv::SelectOp::verify() { - if (auto conditionTy = llvm::dyn_cast(getCondition().getType())) { - auto resultVectorTy = llvm::dyn_cast(getResult().getType()); - if (!resultVectorTy) { - return emitOpError("result expected to be of vector type when " - "condition is of vector type"); - } - if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) { - return emitOpError("result should have the same number of elements as " - "the condition when condition is of vector type"); - } - } - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.mlir.selection -//===----------------------------------------------------------------------===// - -ParseResult spirv::SelectionOp::parse(OpAsmParser &parser, - OperationState &result) { - if (parseControlAttribute(parser, result)) - return failure(); - return parser.parseRegion(*result.addRegion(), /*arguments=*/{}); -} - -void spirv::SelectionOp::print(OpAsmPrinter &printer) { - auto control = getSelectionControl(); - if (control != spirv::SelectionControl::None) - printer << " control(" << spirv::stringifySelectionControl(control) << ")"; - printer << ' '; - printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/true); -} - -LogicalResult spirv::SelectionOp::verifyRegions() { - auto *op = getOperation(); - - // We need to verify that the blocks follow the following layout: - // - // +--------------+ - // | header block | - // +--------------+ - // / | \ - // ... - // - // - // +---------+ +---------+ +---------+ - // | case #0 | | case #1 | | case #2 | ... - // +---------+ +---------+ +---------+ - // - // - // ... 
- // \ | / - // v - // +-------------+ - // | merge block | - // +-------------+ - - auto ®ion = op->getRegion(0); - // Allow empty region as a degenerated case, which can come from - // optimizations. - if (region.empty()) - return success(); - - // The last block is the merge block. - if (!isMergeBlock(region.back())) - return emitOpError("last block must be the merge block with only one " - "'spirv.mlir.merge' op"); - - if (std::next(region.begin()) == region.end()) - return emitOpError("must have a selection header block"); - - return success(); -} - -Block *spirv::SelectionOp::getHeaderBlock() { - assert(!getBody().empty() && "op region should not be empty!"); - // The first block is the loop header block. - return &getBody().front(); -} - -Block *spirv::SelectionOp::getMergeBlock() { - assert(!getBody().empty() && "op region should not be empty!"); - // The last block is the loop merge block. - return &getBody().back(); -} - -void spirv::SelectionOp::addMergeBlock() { - assert(getBody().empty() && "entry and merge block already exist"); - auto *mergeBlock = new Block(); - getBody().push_back(mergeBlock); - OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock); - - // Add a spirv.mlir.merge op into the merge block. - builder.create(getLoc()); -} - -spirv::SelectionOp spirv::SelectionOp::createIfThen( - Location loc, Value condition, - function_ref thenBody, OpBuilder &builder) { - auto selectionOp = - builder.create(loc, spirv::SelectionControl::None); - - selectionOp.addMergeBlock(); - Block *mergeBlock = selectionOp.getMergeBlock(); - Block *thenBlock = nullptr; - - // Build the "then" block. - { - OpBuilder::InsertionGuard guard(builder); - thenBlock = builder.createBlock(mergeBlock); - thenBody(builder); - builder.create(loc, mergeBlock); - } - - // Build the header block. 
- { - OpBuilder::InsertionGuard guard(builder); - builder.createBlock(thenBlock); - builder.create( - loc, condition, thenBlock, - /*trueArguments=*/ArrayRef(), mergeBlock, - /*falseArguments=*/ArrayRef()); - } - - return selectionOp; -} - //===----------------------------------------------------------------------===// // spirv.SpecConstant //===----------------------------------------------------------------------===// @@ -2588,171 +1588,6 @@ "default value can only be a bool, integer, or float scalar"); } -//===----------------------------------------------------------------------===// -// spirv.StoreOp -//===----------------------------------------------------------------------===// - -ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &result) { - // Parse the storage class specification - spirv::StorageClass storageClass; - SmallVector operandInfo; - auto loc = parser.getCurrentLocation(); - Type elementType; - if (parseEnumStrAttr(storageClass, parser) || - parser.parseOperandList(operandInfo, 2) || - parseMemoryAccessAttributes(parser, result) || parser.parseColon() || - parser.parseType(elementType)) { - return failure(); - } - - auto ptrType = spirv::PointerType::get(elementType, storageClass); - if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, - result.operands)) { - return failure(); - } - return success(); -} - -void spirv::StoreOp::print(OpAsmPrinter &printer) { - SmallVector elidedAttrs; - StringRef sc = stringifyStorageClass( - llvm::cast(getPtr().getType()).getStorageClass()); - printer << " \"" << sc << "\" " << getPtr() << ", " << getValue(); - - printMemoryAccessAttribute(*this, printer, elidedAttrs); - - printer << " : " << getValue().getType(); - printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); -} - -LogicalResult spirv::StoreOp::verify() { - // SPIR-V spec : "Pointer is the pointer to store through. 
Its type must be an - // OpTypePointer whose Type operand is the same as the type of Object." - if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) - return failure(); - return verifyMemoryAccessAttribute(*this); -} - -//===----------------------------------------------------------------------===// -// spirv.Unreachable -//===----------------------------------------------------------------------===// - -LogicalResult spirv::UnreachableOp::verify() { - auto *block = (*this)->getBlock(); - // Fast track: if this is in entry block, its invalid. Otherwise, if no - // predecessors, it's valid. - if (block->isEntryBlock()) - return emitOpError("cannot be used in reachable block"); - if (block->hasNoPredecessors()) - return success(); - - // TODO: further verification needs to analyze reachability from - // the entry block. - - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.Variable -//===----------------------------------------------------------------------===// - -ParseResult spirv::VariableOp::parse(OpAsmParser &parser, - OperationState &result) { - // Parse optional initializer - std::optional initInfo; - if (succeeded(parser.parseOptionalKeyword("init"))) { - initInfo = OpAsmParser::UnresolvedOperand(); - if (parser.parseLParen() || parser.parseOperand(*initInfo) || - parser.parseRParen()) - return failure(); - } - - if (parseVariableDecorations(parser, result)) { - return failure(); - } - - // Parse result pointer type - Type type; - if (parser.parseColon()) - return failure(); - auto loc = parser.getCurrentLocation(); - if (parser.parseType(type)) - return failure(); - - auto ptrType = llvm::dyn_cast(type); - if (!ptrType) - return parser.emitError(loc, "expected spirv.ptr type"); - result.addTypes(ptrType); - - // Resolve the initializer operand - if (initInfo) { - if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(), - result.operands)) - return failure(); - } - - 
auto attr = parser.getBuilder().getAttr( - ptrType.getStorageClass()); - result.addAttribute(spirv::attributeName(), attr); - - return success(); -} - -void spirv::VariableOp::print(OpAsmPrinter &printer) { - SmallVector elidedAttrs{ - spirv::attributeName()}; - // Print optional initializer - if (getNumOperands() != 0) - printer << " init(" << getInitializer() << ")"; - - printVariableDecorations(*this, printer, elidedAttrs); - printer << " : " << getType(); -} - -LogicalResult spirv::VariableOp::verify() { - // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the - // object. It cannot be Generic. It must be the same as the Storage Class - // operand of the Result Type." - if (getStorageClass() != spirv::StorageClass::Function) { - return emitOpError( - "can only be used to model function-level variables. Use " - "spirv.GlobalVariable for module-level variables."); - } - - auto pointerType = llvm::cast(getPointer().getType()); - if (getStorageClass() != pointerType.getStorageClass()) - return emitOpError( - "storage class must match result pointer's storage class"); - - if (getNumOperands() != 0) { - // SPIR-V spec: "Initializer must be an from a constant instruction or - // a global (module scope) OpVariable instruction". - auto *initOp = getOperand(0).getDefiningOp(); - if (!initOp || !isa(initOp)) - return emitOpError("initializer must be the result of a " - "constant or spirv.GlobalVariable op"); - } - - // TODO: generate these strings using ODS. 
- auto *op = getOperation(); - auto descriptorSetName = llvm::convertToSnakeFromCamelCase( - stringifyDecoration(spirv::Decoration::DescriptorSet)); - auto bindingName = llvm::convertToSnakeFromCamelCase( - stringifyDecoration(spirv::Decoration::Binding)); - auto builtInName = llvm::convertToSnakeFromCamelCase( - stringifyDecoration(spirv::Decoration::BuiltIn)); - - for (const auto &attr : {descriptorSetName, bindingName, builtInName}) { - if (op->getAttr(attr)) - return emitOpError("cannot have '") - << attr << "' attribute (only allowed in spirv.GlobalVariable)"; - } - - return success(); -} - //===----------------------------------------------------------------------===// // spirv.VectorShuffle //===----------------------------------------------------------------------===// @@ -2804,100 +1639,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// spirv.CopyMemory -//===----------------------------------------------------------------------===// - -void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) { - printer << ' '; - - StringRef targetStorageClass = stringifyStorageClass( - llvm::cast(getTarget().getType()).getStorageClass()); - printer << " \"" << targetStorageClass << "\" " << getTarget() << ", "; - - StringRef sourceStorageClass = stringifyStorageClass( - llvm::cast(getSource().getType()).getStorageClass()); - printer << " \"" << sourceStorageClass << "\" " << getSource(); - - SmallVector elidedAttrs; - printMemoryAccessAttribute(*this, printer, elidedAttrs); - printSourceMemoryAccessAttribute(*this, printer, elidedAttrs, - getSourceMemoryAccess(), - getSourceAlignment()); - - printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); - - Type pointeeType = - llvm::cast(getTarget().getType()).getPointeeType(); - printer << " : " << pointeeType; -} - -ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser, - OperationState &result) { - spirv::StorageClass targetStorageClass; - 
OpAsmParser::UnresolvedOperand targetPtrInfo; - - spirv::StorageClass sourceStorageClass; - OpAsmParser::UnresolvedOperand sourcePtrInfo; - - Type elementType; - - if (parseEnumStrAttr(targetStorageClass, parser) || - parser.parseOperand(targetPtrInfo) || parser.parseComma() || - parseEnumStrAttr(sourceStorageClass, parser) || - parser.parseOperand(sourcePtrInfo) || - parseMemoryAccessAttributes(parser, result)) { - return failure(); - } - - if (!parser.parseOptionalComma()) { - // Parse 2nd memory access attributes. - if (parseSourceMemoryAccessAttributes(parser, result)) { - return failure(); - } - } - - if (parser.parseColon() || parser.parseType(elementType)) - return failure(); - - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass); - auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass); - - if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) || - parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) { - return failure(); - } - - return success(); -} - -LogicalResult spirv::CopyMemoryOp::verify() { - Type targetType = - llvm::cast(getTarget().getType()).getPointeeType(); - - Type sourceType = - llvm::cast(getSource().getType()).getPointeeType(); - - if (targetType != sourceType) - return emitOpError("both operands must be pointers to the same type"); - - if (failed(verifyMemoryAccessAttribute(*this))) - return failure(); - - // TODO - According to the spec: - // - // If two masks are present, the first applies to Target and cannot include - // MakePointerVisible, and the second applies to Source and cannot include - // MakePointerAvailable. - // - // Add such verification here. 
- - return verifySourceMemoryAccessAttribute(*this); -} - //===----------------------------------------------------------------------===// // spirv.Transpose //===----------------------------------------------------------------------===// @@ -3305,109 +2046,6 @@ return success(); } -static ParseResult parsePtrAccessChainOpImpl(StringRef opName, - OpAsmParser &parser, - OperationState &state) { - OpAsmParser::UnresolvedOperand ptrInfo; - SmallVector indicesInfo; - Type type; - auto loc = parser.getCurrentLocation(); - SmallVector indicesTypes; - - if (parser.parseOperand(ptrInfo) || - parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) || - parser.parseColonType(type) || - parser.resolveOperand(ptrInfo, type, state.operands)) - return failure(); - - // Check that the provided indices list is not empty before parsing their - // type list. - if (indicesInfo.empty()) - return emitError(state.location) << opName << " expected element"; - - if (parser.parseComma() || parser.parseTypeList(indicesTypes)) - return failure(); - - // Check that the indices types list is not empty and that it has a one-to-one - // mapping to the provided indices. 
- if (indicesTypes.size() != indicesInfo.size()) - return emitError(state.location) - << opName - << " indices types' count must be equal to indices info count"; - - if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands)) - return failure(); - - auto resultType = getElementPtrType( - type, llvm::ArrayRef(state.operands).drop_front(2), state.location); - if (!resultType) - return failure(); - - state.addTypes(resultType); - return success(); -} - -template -static auto concatElemAndIndices(Op op) { - SmallVector ret(op.getIndices().size() + 1); - ret[0] = op.getElement(); - llvm::copy(op.getIndices(), ret.begin() + 1); - return ret; -} - -//===----------------------------------------------------------------------===// -// spirv.InBoundsPtrAccessChainOp -//===----------------------------------------------------------------------===// - -void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder, - OperationState &state, - Value basePtr, Value element, - ValueRange indices) { - auto type = getElementPtrType(basePtr.getType(), indices, state.location); - assert(type && "Unable to deduce return type based on basePtr and indices"); - build(builder, state, type, basePtr, element, indices); -} - -ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser, - OperationState &result) { - return parsePtrAccessChainOpImpl( - spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result); -} - -void spirv::InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) { - printAccessChain(*this, concatElemAndIndices(*this), printer); -} - -LogicalResult spirv::InBoundsPtrAccessChainOp::verify() { - return verifyAccessChain(*this, getIndices()); -} - -//===----------------------------------------------------------------------===// -// spirv.PtrAccessChainOp -//===----------------------------------------------------------------------===// - -void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state, - Value basePtr, Value 
element, - ValueRange indices) { - auto type = getElementPtrType(basePtr.getType(), indices, state.location); - assert(type && "Unable to deduce return type based on basePtr and indices"); - build(builder, state, type, basePtr, element, indices); -} - -ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser, - OperationState &result) { - return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(), - parser, result); -} - -void spirv::PtrAccessChainOp::print(OpAsmPrinter &printer) { - printAccessChain(*this, concatElemAndIndices(*this), printer); -} - -LogicalResult spirv::PtrAccessChainOp::verify() { - return verifyAccessChain(*this, getIndices()); -} - //===----------------------------------------------------------------------===// // spirv.VectorTimesScalarOp //===----------------------------------------------------------------------===// @@ -3420,18 +2058,3 @@ return emitOpError("scalar operand and result element type match"); return success(); } - -// TableGen'erated operation interfaces for querying versions, extensions, and -// capabilities. -#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc" - -// TablenGen'erated operation definitions. -#define GET_OP_CLASSES -#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc" - -namespace mlir { -namespace spirv { -// TableGen'erated operation availability interface implementations. 
-#include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc" -} // namespace spirv -} // namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h @@ -153,4 +153,7 @@ OpAsmParser &parser, OperationState &state, StringRef attrName = AttrNames::kMemoryAccessAttrName); +ParseResult parseVariableDecorations(OpAsmParser &parser, + OperationState &state); + } // namespace mlir::spirv diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp @@ -12,6 +12,8 @@ #include "SPIRVParsingUtils.h" +#include "llvm/ADT/StringExtras.h" + using namespace mlir::spirv::AttrNames; namespace mlir::spirv { @@ -45,4 +47,41 @@ return parser.parseRSquare(); } +ParseResult parseVariableDecorations(OpAsmParser &parser, + OperationState &state) { + auto builtInName = llvm::convertToSnakeFromCamelCase( + stringifyDecoration(spirv::Decoration::BuiltIn)); + if (succeeded(parser.parseOptionalKeyword("bind"))) { + Attribute set, binding; + // Parse optional descriptor binding + auto descriptorSetName = llvm::convertToSnakeFromCamelCase( + stringifyDecoration(spirv::Decoration::DescriptorSet)); + auto bindingName = llvm::convertToSnakeFromCamelCase( + stringifyDecoration(spirv::Decoration::Binding)); + Type i32Type = parser.getBuilder().getIntegerType(32); + if (parser.parseLParen() || + parser.parseAttribute(set, i32Type, descriptorSetName, + state.attributes) || + parser.parseComma() || + parser.parseAttribute(binding, i32Type, bindingName, + state.attributes) || + parser.parseRParen()) { + return failure(); + } + } else if (succeeded(parser.parseOptionalKeyword(builtInName))) { + StringAttr builtIn; + if (parser.parseLParen() || + parser.parseAttribute(builtIn, builtInName, 
state.attributes) || + parser.parseRParen()) { + return failure(); + } + } + + // Parse other attributes + if (parser.parseOptionalAttrDict(state.attributes)) + return failure(); + + return success(); +} + } // namespace mlir::spirv