diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td @@ -51,7 +51,7 @@ If Result Type has a different number of components than Operand, the total number of bits in Result Type must equal the total number of bits - in Operand. Let L be the type, either Result Type or Operand’s type, + in Operand. Let L be the type, either Result Type or Operand's type, that has the larger number of components. Let S be the other type, with the smaller number of components. The number of components in L must be an integer multiple of the number of components in S. The first @@ -335,17 +335,17 @@ def SPIRV_ConvertPtrToUOp : SPIRV_Op<"ConvertPtrToU", []> { let summary = [{ - Bit pattern-preserving conversion of a pointer to + Bit pattern-preserving conversion of a pointer to an unsigned scalar integer of possibly different bit width. }]; let description = [{ Result Type must be a scalar of integer type, whose Signedness operand is 0. - Pointer must be a physical pointer type. If the bit width of Pointer is - smaller than that of Result Type, the conversion zero extends Pointer. - If the bit width of Pointer is larger than that of Result Type, - the conversion truncates Pointer. + Pointer must be a physical pointer type. If the bit width of Pointer is + smaller than that of Result Type, the conversion zero extends Pointer. + If the bit width of Pointer is larger than that of Result Type, + the conversion truncates Pointer. For same bit width Pointer and Result Type, this is the same as OpBitcast. @@ -359,7 +359,7 @@ #### Example: ```mlir - %1 = spirv.ConvertPtrToU %0 : !spirv.ptr to i32 + %1 = spirv.ConvertPtrToU %0 : !spirv.ptr to i32 ``` }]; @@ -390,18 +390,18 @@ def SPIRV_ConvertUToPtrOp : SPIRV_Op<"ConvertUToPtr", [UnsignedOp]> { let summary = [{ - Bit pattern-preserving conversion of an unsigned scalar integer + Bit pattern-preserving conversion of an unsigned scalar integer to a pointer. }]; let description = [{ Result Type must be a physical pointer type. - Integer Value must be a scalar of integer type, whose Signedness - operand is 0. If the bit width of Integer Value is smaller + Integer Value must be a scalar of integer type, whose Signedness + operand is 0. If the bit width of Integer Value is smaller than that of Result Type, the conversion zero extends Integer Value. - If the bit width of Integer Value is larger than that of Result Type, - the conversion truncates Integer Value. + If the bit width of Integer Value is larger than that of Result Type, + the conversion truncates Integer Value. For same-width Integer Value and Result Type, this is the same as OpBitcast. @@ -415,7 +415,7 @@ #### Example: ```mlir - %1 = spirv.ConvertUToPtr %0 : i32 to !spirv.ptr + %1 = spirv.ConvertUToPtr %0 : i32 to !spirv.ptr ``` }]; diff --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp @@ -0,0 +1,441 @@ +//===- AtomicOps.cpp - MLIR SPIR-V Atomic Ops ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the atomic operations in the SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" + +#include "SPIRVOpUtils.h" +#include "SPIRVParsingUtils.h" + +using namespace mlir::spirv::AttrNames; + +namespace mlir::spirv { + +// Parses an atomic update op. If the update op does not take a value (like +// AtomicIIncrement) `hasValue` must be false. +static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, + OperationState &state, bool hasValue) { + spirv::Scope scope; + spirv::MemorySemantics memoryScope; + SmallVector operandInfo; + OpAsmParser::UnresolvedOperand ptrInfo, valueInfo; + Type type; + SMLoc loc; + if (parseEnumStrAttr(scope, parser, state, + kMemoryScopeAttrName) || + parseEnumStrAttr(memoryScope, parser, state, + kSemanticsAttrName) || + parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) || + parser.getCurrentLocation(&loc) || parser.parseColonType(type)) + return failure(); + + auto ptrType = llvm::dyn_cast(type); + if (!ptrType) + return parser.emitError(loc, "expected pointer type"); + + SmallVector operandTypes; + operandTypes.push_back(ptrType); + if (hasValue) + operandTypes.push_back(ptrType.getPointeeType()); + if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(), + state.operands)) + return failure(); + return parser.addTypeToList(ptrType.getPointeeType(), state.types); +} + +// Prints an atomic update op. +static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) { + printer << " \""; + auto scopeAttr = op->getAttrOfType(kMemoryScopeAttrName); + printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \""; + auto memorySemanticsAttr = + op->getAttrOfType(kSemanticsAttrName); + printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue()) + << "\" " << op->getOperands() << " : " << op->getOperand(0).getType(); +} + +template +static StringRef stringifyTypeName(); + +template <> +StringRef stringifyTypeName() { + return "integer"; +} + +template <> +StringRef stringifyTypeName() { + return "float"; +} + +// Verifies an atomic update op. +template +static LogicalResult verifyAtomicUpdateOp(Operation *op) { + auto ptrType = llvm::cast(op->getOperand(0).getType()); + auto elementType = ptrType.getPointeeType(); + if (!llvm::isa(elementType)) + return op->emitOpError() << "pointer operand must point to an " + << stringifyTypeName() + << " value, found " << elementType; + + if (op->getNumOperands() > 1) { + auto valueType = op->getOperand(1).getType(); + if (valueType != elementType) + return op->emitOpError("expected value to have the same type as the " + "pointer operand's pointee type ") + << elementType << ", but found " << valueType; + } + auto memorySemantics = + op->getAttrOfType(kSemanticsAttrName) + .getValue(); + if (failed(verifyMemorySemantics(op, memorySemantics))) { + return failure(); + } + return success(); +} + +template +static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) { + printer << " \"" << stringifyScope(atomOp.getMemoryScope()) << "\" \"" + << stringifyMemorySemantics(atomOp.getEqualSemantics()) << "\" \"" + << stringifyMemorySemantics(atomOp.getUnequalSemantics()) << "\" " + << atomOp.getOperands() << " : " << atomOp.getPointer().getType(); +} + +static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser, + OperationState &state) { + spirv::Scope memoryScope; + spirv::MemorySemantics equalSemantics, unequalSemantics; + SmallVector operandInfo; + Type type; + if (parseEnumStrAttr(memoryScope, parser, state, + kMemoryScopeAttrName) || + parseEnumStrAttr( + equalSemantics, parser, state, kEqualSemanticsAttrName) || + parseEnumStrAttr( + unequalSemantics, parser, state, kUnequalSemanticsAttrName) || + parser.parseOperandList(operandInfo, 3)) + return failure(); + + auto loc = parser.getCurrentLocation(); + if (parser.parseColonType(type)) + return failure(); + + auto ptrType = llvm::dyn_cast(type); + if (!ptrType) + return parser.emitError(loc, "expected pointer type"); + + if (parser.resolveOperands( + operandInfo, + {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()}, + parser.getNameLoc(), state.operands)) + return failure(); + + return parser.addTypeToList(ptrType.getPointeeType(), state.types); +} + +template +static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) { + // According to the spec: + // "The type of Value must be the same as Result Type. The type of the value + // pointed to by Pointer must be the same as Result Type. This type must also + // match the type of Comparator." + if (atomOp.getType() != atomOp.getValue().getType()) + return atomOp.emitOpError("value operand must have the same type as the op " + "result, but found ") + << atomOp.getValue().getType() << " vs " << atomOp.getType(); + + if (atomOp.getType() != atomOp.getComparator().getType()) + return atomOp.emitOpError( + "comparator operand must have the same type as the op " + "result, but found ") + << atomOp.getComparator().getType() << " vs " << atomOp.getType(); + + Type pointeeType = + llvm::cast(atomOp.getPointer().getType()) + .getPointeeType(); + if (atomOp.getType() != pointeeType) + return atomOp.emitOpError( + "pointer operand's pointee type must have the same " + "as the op result type, but found ") + << pointeeType << " vs " << atomOp.getType(); + + // TODO: Unequal cannot be set to Release or Acquire and Release. + // In addition, Unequal cannot be set to a stronger memory-order then Equal. + + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.AtomicAndOp +//===----------------------------------------------------------------------===// + +LogicalResult AtomicAndOp::verify() { + return verifyAtomicUpdateOp(getOperation()); +} + +ParseResult AtomicAndOp::parse(OpAsmParser &parser, OperationState &result) { + return parseAtomicUpdateOp(parser, result, true); +} + +void AtomicAndOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); } + +//===----------------------------------------------------------------------===// +// spirv.AtomicCompareExchangeOp +//===----------------------------------------------------------------------===// + +LogicalResult AtomicCompareExchangeOp::verify() { + return verifyAtomicCompareExchangeImpl(*this); +} + +ParseResult AtomicCompareExchangeOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseAtomicCompareExchangeImpl(parser, result); +} + +void AtomicCompareExchangeOp::print(OpAsmPrinter &p) { + printAtomicCompareExchangeImpl(*this, p); +} + +//===----------------------------------------------------------------------===// +// spirv.AtomicCompareExchangeWeakOp +//===----------------------------------------------------------------------===// + +LogicalResult AtomicCompareExchangeWeakOp::verify() { + return verifyAtomicCompareExchangeImpl(*this); +} + +ParseResult AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseAtomicCompareExchangeImpl(parser, result); +} + +void AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) { + printAtomicCompareExchangeImpl(*this, p); +} + +//===----------------------------------------------------------------------===// +// spirv.AtomicExchange +//===----------------------------------------------------------------------===// + +void AtomicExchangeOp::print(OpAsmPrinter &printer) { + printer << " \"" << stringifyScope(getMemoryScope()) << "\" \"" + << stringifyMemorySemantics(getSemantics()) << "\" " << getOperands() + << " : " << getPointer().getType(); +} + +ParseResult AtomicExchangeOp::parse(OpAsmParser &parser, + OperationState &result) { + spirv::Scope memoryScope; + spirv::MemorySemantics semantics; + SmallVector operandInfo; + Type type; + if (parseEnumStrAttr(memoryScope, parser, result, + kMemoryScopeAttrName) || + parseEnumStrAttr(semantics, parser, result, + kSemanticsAttrName) || + parser.parseOperandList(operandInfo, 2)) + return failure(); + + auto loc = parser.getCurrentLocation(); + if (parser.parseColonType(type)) + return failure(); + + auto ptrType = llvm::dyn_cast(type); + if (!ptrType) + return parser.emitError(loc, "expected pointer type"); + + if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()}, + parser.getNameLoc(), result.operands)) + return failure(); + + return parser.addTypeToList(ptrType.getPointeeType(), result.types); +} + +LogicalResult AtomicExchangeOp::verify() { + if (getType() != getValue().getType()) + return emitOpError("value operand must have the same type as the op " + "result, but found ") + << getValue().getType() << " vs " << getType(); + + Type pointeeType = + llvm::cast(getPointer().getType()).getPointeeType(); + if (getType() != pointeeType) + return emitOpError("pointer operand's pointee type must have the same " + "as the op result type, but found ") + << pointeeType << " vs " << getType(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.AtomicIAddOp +//===----------------------------------------------------------------------===// + +LogicalResult AtomicIAddOp::verify() { + return verifyAtomicUpdateOp(getOperation()); +} + +ParseResult AtomicIAddOp::parse(OpAsmParser &parser, OperationState &result) { + return parseAtomicUpdateOp(parser, result, true); +} + +void AtomicIAddOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); } + +//===----------------------------------------------------------------------===// +// spirv.EXT.AtomicFAddOp +//===----------------------------------------------------------------------===// + +LogicalResult EXTAtomicFAddOp::verify() { + return verifyAtomicUpdateOp(getOperation()); +} + +ParseResult EXTAtomicFAddOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseAtomicUpdateOp(parser, result, true); +} + +void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) { + printAtomicUpdateOp(*this, p); +} + +//===----------------------------------------------------------------------===// +// spirv.AtomicIDecrementOp +//===----------------------------------------------------------------------===// + +LogicalResult AtomicIDecrementOp::verify() { + return verifyAtomicUpdateOp(getOperation()); +} + +ParseResult AtomicIDecrementOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseAtomicUpdateOp(parser, result, false); +} + +void AtomicIDecrementOp::print(OpAsmPrinter &p) { + printAtomicUpdateOp(*this, p); +} + +//===----------------------------------------------------------------------===// +// spirv.AtomicIIncrementOp +//===----------------------------------------------------------------------===// + +LogicalResult AtomicIIncrementOp::verify() { + return verifyAtomicUpdateOp(getOperation()); +} + +ParseResult AtomicIIncrementOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseAtomicUpdateOp(parser, result, false); +} + +void AtomicIIncrementOp::print(OpAsmPrinter &p) { + printAtomicUpdateOp(*this, p); +} + +//===----------------------------------------------------------------------===// +// spirv.AtomicISubOp +//===----------------------------------------------------------------------===// + +LogicalResult AtomicISubOp::verify() { + return verifyAtomicUpdateOp(getOperation()); +} + +ParseResult AtomicISubOp::parse(OpAsmParser &parser, OperationState &result) { + return parseAtomicUpdateOp(parser, result, true); +} + +void AtomicISubOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); } + +//===----------------------------------------------------------------------===// +// spirv.AtomicOrOp +//===----------------------------------------------------------------------===// + +LogicalResult AtomicOrOp::verify() { + return verifyAtomicUpdateOp(getOperation()); +} + +ParseResult AtomicOrOp::parse(OpAsmParser &parser, OperationState &result) { + return parseAtomicUpdateOp(parser, result, true); +} + +void AtomicOrOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); } + +//===----------------------------------------------------------------------===// +// spirv.AtomicSMaxOp +//===----------------------------------------------------------------------===// + +LogicalResult AtomicSMaxOp::verify() { + return verifyAtomicUpdateOp(getOperation()); +} + +ParseResult AtomicSMaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseAtomicUpdateOp(parser, result, true); +} + +void AtomicSMaxOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); } + +//===----------------------------------------------------------------------===// +// spirv.AtomicSMinOp +//===----------------------------------------------------------------------===// + +LogicalResult AtomicSMinOp::verify() { + return verifyAtomicUpdateOp(getOperation()); +} + +ParseResult AtomicSMinOp::parse(OpAsmParser &parser, OperationState &result) { + return parseAtomicUpdateOp(parser, result, true); +} + +void AtomicSMinOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); } + +//===----------------------------------------------------------------------===// +// spirv.AtomicUMaxOp +//===----------------------------------------------------------------------===// + +LogicalResult AtomicUMaxOp::verify() { + return verifyAtomicUpdateOp(getOperation()); +} + +ParseResult AtomicUMaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseAtomicUpdateOp(parser, result, true); +} + +void AtomicUMaxOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); } + +//===----------------------------------------------------------------------===// +// spirv.AtomicUMinOp +//===----------------------------------------------------------------------===// + +LogicalResult AtomicUMinOp::verify() { + return verifyAtomicUpdateOp(getOperation()); +} + +ParseResult AtomicUMinOp::parse(OpAsmParser &parser, OperationState &result) { + return parseAtomicUpdateOp(parser, result, true); +} + +void AtomicUMinOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); } + +//===----------------------------------------------------------------------===// +// spirv.AtomicXorOp +//===----------------------------------------------------------------------===// + +LogicalResult AtomicXorOp::verify() { + return verifyAtomicUpdateOp(getOperation()); +} + +ParseResult AtomicXorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseAtomicUpdateOp(parser, result, true); +} + +void AtomicXorOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); } + +} // namespace mlir::spirv diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt --- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt @@ -3,7 +3,10 @@ add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen) add_mlir_dialect_library(MLIRSPIRVDialect + AtomicOps.cpp + CastOps.cpp CooperativeMatrixOps.cpp + GroupOps.cpp IntegerDotProductOps.cpp JointMatrixOps.cpp SPIRVAttributes.cpp diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp @@ -0,0 +1,339 @@ +//===- CastOps.cpp - MLIR SPIR-V Cast Ops --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the cast and conversion operations in the SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" + +#include "SPIRVOpUtils.h" +#include "SPIRVParsingUtils.h" + +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir::spirv::AttrNames; + +namespace mlir::spirv { + +static LogicalResult verifyCastOp(Operation *op, + bool requireSameBitWidth = true, + bool skipBitWidthCheck = false) { + // Some CastOps have no limit on bit widths for result and operand type. + if (skipBitWidthCheck) + return success(); + + Type operandType = op->getOperand(0).getType(); + Type resultType = op->getResult(0).getType(); + + // ODS checks that result type and operand type have the same shape. Check + // that composite types match and extract the element types, if any. + using TypePair = std::pair; + auto [operandElemTy, resultElemTy] = + TypeSwitch(operandType) + .Case( + [resultType](auto concreteOperandTy) -> TypePair { + if (auto concreteResultTy = + dyn_cast(resultType)) { + return {concreteOperandTy.getElementType(), + concreteResultTy.getElementType()}; + } + return {}; + }) + .Default([resultType](Type operandType) -> TypePair { + return {operandType, resultType}; + }); + + if (!operandElemTy || !resultElemTy) + return op->emitOpError("incompatible operand and result types"); + + unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth(); + unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth(); + bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth; + + if (requireSameBitWidth) { + if (!isSameBitWidth) { + return op->emitOpError( + "expected the same bit widths for operand type and result " + "type, but provided ") + << operandElemTy << " and " << resultElemTy; + } + return success(); + } + + if (isSameBitWidth) { + return op->emitOpError( + "expected the different bit widths for operand type and result " + "type, but provided ") + << operandElemTy << " and " << resultElemTy; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.BitcastOp +//===----------------------------------------------------------------------===// + +LogicalResult BitcastOp::verify() { + // TODO: The SPIR-V spec validation rules are different for different + // versions. + auto operandType = getOperand().getType(); + auto resultType = getResult().getType(); + if (operandType == resultType) { + return emitError("result type must be different from operand type"); + } + if (llvm::isa(operandType) && + !llvm::isa(resultType)) { + return emitError( + "unhandled bit cast conversion from pointer type to non-pointer type"); + } + if (!llvm::isa(operandType) && + llvm::isa(resultType)) { + return emitError( + "unhandled bit cast conversion from non-pointer type to pointer type"); + } + auto operandBitWidth = getBitWidth(operandType); + auto resultBitWidth = getBitWidth(resultType); + if (operandBitWidth != resultBitWidth) { + return emitOpError("mismatch in result type bitwidth ") + << resultBitWidth << " and operand type bitwidth " + << operandBitWidth; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.ConvertPtrToUOp +//===----------------------------------------------------------------------===// + +LogicalResult ConvertPtrToUOp::verify() { + auto operandType = llvm::cast(getPointer().getType()); + auto resultType = llvm::cast(getResult().getType()); + if (!resultType || !resultType.isSignlessInteger()) + return emitError("result must be a scalar type of unsigned integer"); + auto spirvModule = (*this)->getParentOfType(); + if (!spirvModule) + return success(); + auto addressingModel = spirvModule.getAddressingModel(); + if ((addressingModel == spirv::AddressingModel::Logical) || + (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 && + operandType.getStorageClass() != + spirv::StorageClass::PhysicalStorageBuffer)) + return emitError("operand must be a physical pointer"); + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.ConvertUToPtrOp +//===----------------------------------------------------------------------===// + +LogicalResult ConvertUToPtrOp::verify() { + auto operandType = llvm::cast(getOperand().getType()); + auto resultType = llvm::cast(getResult().getType()); + if (!operandType || !operandType.isSignlessInteger()) + return emitError("result must be a scalar type of unsigned integer"); + auto spirvModule = (*this)->getParentOfType(); + if (!spirvModule) + return success(); + auto addressingModel = spirvModule.getAddressingModel(); + if ((addressingModel == spirv::AddressingModel::Logical) || + (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 && + resultType.getStorageClass() != + spirv::StorageClass::PhysicalStorageBuffer)) + return emitError("result must be a physical pointer"); + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.PtrCastToGenericOp +//===----------------------------------------------------------------------===// + +LogicalResult PtrCastToGenericOp::verify() { + auto operandType = llvm::cast(getPointer().getType()); + auto resultType = llvm::cast(getResult().getType()); + + spirv::StorageClass operandStorage = operandType.getStorageClass(); + if (operandStorage != spirv::StorageClass::Workgroup && + operandStorage != spirv::StorageClass::CrossWorkgroup && + operandStorage != spirv::StorageClass::Function) + return emitError("pointer must point to the Workgroup, CrossWorkgroup" + ", or Function Storage Class"); + + spirv::StorageClass resultStorage = resultType.getStorageClass(); + if (resultStorage != spirv::StorageClass::Generic) + return emitError("result type must be of storage class Generic"); + + Type operandPointeeType = operandType.getPointeeType(); + Type resultPointeeType = resultType.getPointeeType(); + if (operandPointeeType != resultPointeeType) + return emitOpError("pointer operand's pointee type must have the same " + "as the op result type, but found ") + << operandPointeeType << " vs " << resultPointeeType; + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.GenericCastToPtrOp +//===----------------------------------------------------------------------===// + +LogicalResult GenericCastToPtrOp::verify() { + auto operandType = llvm::cast(getPointer().getType()); + auto resultType = llvm::cast(getResult().getType()); + + spirv::StorageClass operandStorage = operandType.getStorageClass(); + if (operandStorage != spirv::StorageClass::Generic) + return emitError("pointer type must be of storage class Generic"); + + spirv::StorageClass resultStorage = resultType.getStorageClass(); + if (resultStorage != spirv::StorageClass::Workgroup && + resultStorage != spirv::StorageClass::CrossWorkgroup && + resultStorage != spirv::StorageClass::Function) + return emitError("result must point to the Workgroup, CrossWorkgroup, " + "or Function Storage Class"); + + Type operandPointeeType = operandType.getPointeeType(); + Type resultPointeeType = resultType.getPointeeType(); + if (operandPointeeType != resultPointeeType) + return emitOpError("pointer operand's pointee type must have the same " + "as the op result type, but found ") + << operandPointeeType << " vs " << resultPointeeType; + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.GenericCastToPtrExplicitOp +//===----------------------------------------------------------------------===// + +LogicalResult GenericCastToPtrExplicitOp::verify() { + auto operandType = llvm::cast(getPointer().getType()); + auto resultType = llvm::cast(getResult().getType()); + + spirv::StorageClass operandStorage = operandType.getStorageClass(); + if (operandStorage != spirv::StorageClass::Generic) + return emitError("pointer type must be of storage class Generic"); + + spirv::StorageClass resultStorage = resultType.getStorageClass(); + if (resultStorage != spirv::StorageClass::Workgroup && + resultStorage != spirv::StorageClass::CrossWorkgroup && + resultStorage != spirv::StorageClass::Function) + return emitError("result must point to the Workgroup, CrossWorkgroup, " + "or Function Storage Class"); + + Type operandPointeeType = operandType.getPointeeType(); + Type resultPointeeType = resultType.getPointeeType(); + if (operandPointeeType != resultPointeeType) + return emitOpError("pointer operand's pointee type must have the same " + "as the op result type, but found ") + << operandPointeeType << " vs " << resultPointeeType; + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.ConvertFToSOp +//===----------------------------------------------------------------------===// + +LogicalResult ConvertFToSOp::verify() { + return verifyCastOp(*this, /*requireSameBitWidth=*/false, + /*skipBitWidthCheck=*/true); +} + +//===----------------------------------------------------------------------===// +// spirv.ConvertFToUOp +//===----------------------------------------------------------------------===// + +LogicalResult ConvertFToUOp::verify() { + return verifyCastOp(*this, /*requireSameBitWidth=*/false, + /*skipBitWidthCheck=*/true); +} + +//===----------------------------------------------------------------------===// +// spirv.ConvertSToFOp +//===----------------------------------------------------------------------===// + +LogicalResult ConvertSToFOp::verify() { + return verifyCastOp(*this, /*requireSameBitWidth=*/false, + /*skipBitWidthCheck=*/true); +} + +//===----------------------------------------------------------------------===// +// spirv.ConvertUToFOp +//===----------------------------------------------------------------------===// + +LogicalResult ConvertUToFOp::verify() { + return verifyCastOp(*this, /*requireSameBitWidth=*/false, + /*skipBitWidthCheck=*/true); +} + +//===----------------------------------------------------------------------===// +// spirv.INTELConvertBF16ToFOp +//===----------------------------------------------------------------------===// + +LogicalResult INTELConvertBF16ToFOp::verify() { + auto operandType = getOperand().getType(); + auto resultType = getResult().getType(); + // ODS checks that vector result type and vector operand type have the same + // shape. + if (auto vectorType = llvm::dyn_cast(operandType)) { + unsigned operandNumElements = vectorType.getNumElements(); + unsigned resultNumElements = + llvm::cast(resultType).getNumElements(); + if (operandNumElements != resultNumElements) { + return emitOpError( + "operand and result must have same number of elements"); + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.INTELConvertFToBF16Op +//===----------------------------------------------------------------------===// + +LogicalResult INTELConvertFToBF16Op::verify() { + auto operandType = getOperand().getType(); + auto resultType = getResult().getType(); + // ODS checks that vector result type and vector operand type have the same + // shape. + if (auto vectorType = llvm::dyn_cast(operandType)) { + unsigned operandNumElements = vectorType.getNumElements(); + unsigned resultNumElements = + llvm::cast(resultType).getNumElements(); + if (operandNumElements != resultNumElements) { + return emitOpError( + "operand and result must have same number of elements"); + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.FConvertOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::FConvertOp::verify() { + return verifyCastOp(*this, /*requireSameBitWidth=*/false); +} + +//===----------------------------------------------------------------------===// +// spirv.SConvertOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::SConvertOp::verify() { + return verifyCastOp(*this, /*requireSameBitWidth=*/false); +} + +//===----------------------------------------------------------------------===// +// spirv.UConvertOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::UConvertOp::verify() { + return verifyCastOp(*this, /*requireSameBitWidth=*/false); +} + +} // namespace mlir::spirv diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp @@ -0,0 +1,407 @@ +//===- GroupOps.cpp - MLIR SPIR-V Group Ops ------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the group operations in the SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" + +#include "SPIRVOpUtils.h" +#include "SPIRVParsingUtils.h" + +using namespace mlir::spirv::AttrNames; + +namespace mlir::spirv { + +static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser, + OperationState &state) { + spirv::Scope executionScope; + GroupOperation groupOperation; + OpAsmParser::UnresolvedOperand valueInfo; + if (spirv::parseEnumStrAttr(executionScope, parser, state, + kExecutionScopeAttrName) || + spirv::parseEnumStrAttr(groupOperation, parser, state, + kGroupOperationAttrName) || + parser.parseOperand(valueInfo)) + return failure(); + + std::optional clusterSizeInfo; + if (succeeded(parser.parseOptionalKeyword(kClusterSize))) { + clusterSizeInfo = OpAsmParser::UnresolvedOperand(); + if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) || + parser.parseRParen()) + return failure(); + } + + Type resultType; + if (parser.parseColonType(resultType)) + return failure(); + + if (parser.resolveOperand(valueInfo, resultType, state.operands)) + return failure(); + + if (clusterSizeInfo) { + Type i32Type = parser.getBuilder().getIntegerType(32); + if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands)) + return failure(); + } + + return parser.addTypeToList(resultType, state.types); +} + +static void printGroupNonUniformArithmeticOp(Operation *groupOp, + OpAsmPrinter &printer) { + printer + << " \"" + << stringifyScope( + groupOp->getAttrOfType(kExecutionScopeAttrName) + .getValue()) + << "\" \"" + << stringifyGroupOperation( + groupOp->getAttrOfType(kGroupOperationAttrName) + .getValue()) + << "\" " << groupOp->getOperand(0); + + if (groupOp->getNumOperands() > 1) + printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')'; + printer << " : " << groupOp->getResult(0).getType(); +} + +static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) { + spirv::Scope scope = + groupOp->getAttrOfType(kExecutionScopeAttrName) + .getValue(); + if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) + return groupOp->emitOpError( + "execution scope must be 'Workgroup' or 'Subgroup'"); + + GroupOperation operation = + groupOp->getAttrOfType(kGroupOperationAttrName) + .getValue(); + if (operation == GroupOperation::ClusteredReduce && + groupOp->getNumOperands() == 1) + return groupOp->emitOpError("cluster size operand must be provided for " + "'ClusteredReduce' group operation"); + if (groupOp->getNumOperands() > 1) { + Operation *sizeOp = groupOp->getOperand(1).getDefiningOp(); + int32_t clusterSize = 0; + + // TODO: support specialization constant here. + if (failed(extractValueFromConstOp(sizeOp, clusterSize))) + return groupOp->emitOpError( + "cluster size operand must come from a constant op"); + + if (!llvm::isPowerOf2_32(clusterSize)) + return groupOp->emitOpError( + "cluster size operand must be a power of two"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupBroadcast +//===----------------------------------------------------------------------===// + +LogicalResult GroupBroadcastOp::verify() { + spirv::Scope scope = getExecutionScope(); + if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) + return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); + + if (auto localIdTy = llvm::dyn_cast(getLocalid().getType())) + if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3) + return emitOpError("localid is a vector and can be with only " + " 2 or 3 components, actual number is ") + << localIdTy.getNumElements(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformBallotOp +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformBallotOp::verify() { + spirv::Scope scope = getExecutionScope(); + if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) + return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformBroadcast +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformBroadcastOp::verify() { + spirv::Scope scope = getExecutionScope(); + if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) + return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); + + // SPIR-V spec: "Before version 1.5, Id must come from a + // constant instruction. + auto targetEnv = spirv::getDefaultTargetEnv(getContext()); + if (auto spirvModule = (*this)->getParentOfType()) + targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule); + + if (targetEnv.getVersion() < spirv::Version::V_1_5) { + auto *idOp = getId().getDefiningOp(); + if (!idOp || !isa(idOp)) // for spec constant + return emitOpError("id must be the result of a constant op"); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformShuffle* +//===----------------------------------------------------------------------===// + +template +static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) { + spirv::Scope scope = op.getExecutionScope(); + if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) + return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); + + if (op.getOperands().back().getType().isSignedInteger()) + return op.emitOpError("second operand must be a singless/unsigned integer"); + + return success(); +} + +LogicalResult GroupNonUniformShuffleOp::verify() { + return verifyGroupNonUniformShuffleOp(*this); +} +LogicalResult GroupNonUniformShuffleDownOp::verify() { + return verifyGroupNonUniformShuffleOp(*this); +} +LogicalResult GroupNonUniformShuffleUpOp::verify() { + return verifyGroupNonUniformShuffleOp(*this); +} +LogicalResult GroupNonUniformShuffleXorOp::verify() { + return verifyGroupNonUniformShuffleOp(*this); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformElectOp +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformElectOp::verify() { + spirv::Scope scope = getExecutionScope(); + if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) + return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformFAddOp +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformFAddOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +ParseResult GroupNonUniformFAddOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} + +void GroupNonUniformFAddOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformFMaxOp +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformFMaxOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +ParseResult GroupNonUniformFMaxOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} + +void GroupNonUniformFMaxOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformFMinOp +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformFMinOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +ParseResult GroupNonUniformFMinOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} + +void GroupNonUniformFMinOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformFMulOp +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformFMulOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +ParseResult GroupNonUniformFMulOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} + +void GroupNonUniformFMulOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformIAddOp +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformIAddOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +ParseResult GroupNonUniformIAddOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} + +void GroupNonUniformIAddOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformIMulOp +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformIMulOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +ParseResult GroupNonUniformIMulOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} + +void GroupNonUniformIMulOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformSMaxOp +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformSMaxOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +ParseResult GroupNonUniformSMaxOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} + +void GroupNonUniformSMaxOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformSMinOp +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformSMinOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +ParseResult GroupNonUniformSMinOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} + +void GroupNonUniformSMinOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformUMaxOp +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformUMaxOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +ParseResult GroupNonUniformUMaxOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} + +void GroupNonUniformUMaxOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + +//===----------------------------------------------------------------------===// +// spirv.GroupNonUniformUMinOp +//===----------------------------------------------------------------------===// + +LogicalResult GroupNonUniformUMinOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +ParseResult GroupNonUniformUMinOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} + +void GroupNonUniformUMinOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + +//===----------------------------------------------------------------------===// +// Group op verification +//===----------------------------------------------------------------------===// + +template +static LogicalResult verifyGroupOp(Op op) { + spirv::Scope scope = op.getExecutionScope(); + if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) + return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); + + return success(); +} + +LogicalResult GroupIAddOp::verify() { return verifyGroupOp(*this); } + +LogicalResult GroupFAddOp::verify() { return verifyGroupOp(*this); } + +LogicalResult GroupFMinOp::verify() { return verifyGroupOp(*this); } + +LogicalResult GroupUMinOp::verify() { return verifyGroupOp(*this); } + +LogicalResult GroupSMinOp::verify() { return verifyGroupOp(*this); } + +LogicalResult GroupFMaxOp::verify() { return verifyGroupOp(*this); } + +LogicalResult GroupUMaxOp::verify() { return verifyGroupOp(*this); } + +LogicalResult GroupSMaxOp::verify() { return verifyGroupOp(*this); } + +LogicalResult GroupIMulKHROp::verify() { return verifyGroupOp(*this); } + +LogicalResult GroupFMulKHROp::verify() { return verifyGroupOp(*this); } + +} // namespace mlir::spirv diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h @@ -29,4 +29,9 @@ llvm_unreachable("unhandled bit width computation for type"); } +LogicalResult extractValueFromConstOp(Operation *op, int32_t &value); + +LogicalResult verifyMemorySemantics(Operation *op, + spirv::MemorySemantics memorySemantics); + } // namespace mlir::spirv diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -34,7 +34,6 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/TypeSwitch.h" #include #include #include @@ -112,7 +111,7 @@ return op && op->hasTrait(); } -static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) { +LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) { auto constOp = dyn_cast_or_null(op); if (!constOp) { return failure(); @@ -293,61 +292,6 @@ return success(); } -static LogicalResult verifyCastOp(Operation *op, - bool requireSameBitWidth = true, - bool skipBitWidthCheck = false) { - // Some CastOps have no limit on bit widths for result and operand type. - if (skipBitWidthCheck) - return success(); - - Type operandType = op->getOperand(0).getType(); - Type resultType = op->getResult(0).getType(); - - // ODS checks that result type and operand type have the same shape. Check - // that composite types match and extract the element types, if any. - using TypePair = std::pair; - auto [operandElemTy, resultElemTy] = - TypeSwitch(operandType) - .Case( - [resultType](auto concreteOperandTy) -> TypePair { - if (auto concreteResultTy = - dyn_cast(resultType)) { - return {concreteOperandTy.getElementType(), - concreteResultTy.getElementType()}; - } - return {}; - }) - .Default([resultType](Type operandType) -> TypePair { - return {operandType, resultType}; - }); - - if (!operandElemTy || !resultElemTy) - return op->emitOpError("incompatible operand and result types"); - - unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth(); - unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth(); - bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth; - - if (requireSameBitWidth) { - if (!isSameBitWidth) { - return op->emitOpError( - "expected the same bit widths for operand type and result " - "type, but provided ") - << operandElemTy << " and " << resultElemTy; - } - return success(); - } - - if (isSameBitWidth) { - return op->emitOpError( - "expected the different bit widths for operand type and result " - "type, but provided ") - << operandElemTy << " and " << resultElemTy; - } - return success(); -} - template static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) { // ODS checks for attributes values. Just need to verify that if the @@ -432,8 +376,9 @@ return success(); } -static LogicalResult -verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) { +LogicalResult +spirv::verifyMemorySemantics(Operation *op, + spirv::MemorySemantics memorySemantics) { // According to the SPIR-V specification: // "Despite being a mask and allowing multiple bits to be combined, it is // invalid for more than one of these four bits to be set: Acquire, Release, @@ -672,178 +617,6 @@ printer << " : " << op->getResultTypes().front(); } -//===----------------------------------------------------------------------===// -// Common parsers and printers -//===----------------------------------------------------------------------===// - -// Parses an atomic update op. If the update op does not take a value (like -// AtomicIIncrement) `hasValue` must be false. -static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, - OperationState &state, bool hasValue) { - spirv::Scope scope; - spirv::MemorySemantics memoryScope; - SmallVector operandInfo; - OpAsmParser::UnresolvedOperand ptrInfo, valueInfo; - Type type; - SMLoc loc; - if (spirv::parseEnumStrAttr(scope, parser, state, - kMemoryScopeAttrName) || - spirv::parseEnumStrAttr( - memoryScope, parser, state, kSemanticsAttrName) || - parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) || - parser.getCurrentLocation(&loc) || parser.parseColonType(type)) - return failure(); - - auto ptrType = llvm::dyn_cast(type); - if (!ptrType) - return parser.emitError(loc, "expected pointer type"); - - SmallVector operandTypes; - operandTypes.push_back(ptrType); - if (hasValue) - operandTypes.push_back(ptrType.getPointeeType()); - if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(), - state.operands)) - return failure(); - return parser.addTypeToList(ptrType.getPointeeType(), state.types); -} - -// Prints an atomic update op. -static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) { - printer << " \""; - auto scopeAttr = op->getAttrOfType(kMemoryScopeAttrName); - printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \""; - auto memorySemanticsAttr = - op->getAttrOfType(kSemanticsAttrName); - printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue()) - << "\" " << op->getOperands() << " : " << op->getOperand(0).getType(); -} - -template -static StringRef stringifyTypeName(); - -template <> -StringRef stringifyTypeName() { - return "integer"; -} - -template <> -StringRef stringifyTypeName() { - return "float"; -} - -// Verifies an atomic update op. -template -static LogicalResult verifyAtomicUpdateOp(Operation *op) { - auto ptrType = llvm::cast(op->getOperand(0).getType()); - auto elementType = ptrType.getPointeeType(); - if (!llvm::isa(elementType)) - return op->emitOpError() << "pointer operand must point to an " - << stringifyTypeName() - << " value, found " << elementType; - - if (op->getNumOperands() > 1) { - auto valueType = op->getOperand(1).getType(); - if (valueType != elementType) - return op->emitOpError("expected value to have the same type as the " - "pointer operand's pointee type ") - << elementType << ", but found " << valueType; - } - auto memorySemantics = - op->getAttrOfType(kSemanticsAttrName) - .getValue(); - if (failed(verifyMemorySemantics(op, memorySemantics))) { - return failure(); - } - return success(); -} - -static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser, - OperationState &state) { - spirv::Scope executionScope; - spirv::GroupOperation groupOperation; - OpAsmParser::UnresolvedOperand valueInfo; - if (spirv::parseEnumStrAttr(executionScope, parser, state, - kExecutionScopeAttrName) || - spirv::parseEnumStrAttr( - groupOperation, parser, state, kGroupOperationAttrName) || - parser.parseOperand(valueInfo)) - return failure(); - - std::optional clusterSizeInfo; - if (succeeded(parser.parseOptionalKeyword(kClusterSize))) { - clusterSizeInfo = OpAsmParser::UnresolvedOperand(); - if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) || - parser.parseRParen()) - return failure(); - } - - Type resultType; - if (parser.parseColonType(resultType)) - return failure(); - - if (parser.resolveOperand(valueInfo, resultType, state.operands)) - return failure(); - - if (clusterSizeInfo) { - Type i32Type = parser.getBuilder().getIntegerType(32); - if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands)) - return failure(); - } - - return parser.addTypeToList(resultType, state.types); -} - -static void printGroupNonUniformArithmeticOp(Operation *groupOp, - OpAsmPrinter &printer) { - printer - << " \"" - << stringifyScope( - groupOp->getAttrOfType(kExecutionScopeAttrName) - .getValue()) - << "\" \"" - << stringifyGroupOperation(groupOp - ->getAttrOfType( - kGroupOperationAttrName) - .getValue()) - << "\" " << groupOp->getOperand(0); - - if (groupOp->getNumOperands() > 1) - printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')'; - printer << " : " << groupOp->getResult(0).getType(); -} - -static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) { - spirv::Scope scope = - groupOp->getAttrOfType(kExecutionScopeAttrName) - .getValue(); - if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) - return groupOp->emitOpError( - "execution scope must be 'Workgroup' or 'Subgroup'"); - - spirv::GroupOperation operation = - groupOp->getAttrOfType(kGroupOperationAttrName) - .getValue(); - if (operation == spirv::GroupOperation::ClusteredReduce && - groupOp->getNumOperands() == 1) - return groupOp->emitOpError("cluster size operand must be provided for " - "'ClusteredReduce' group operation"); - if (groupOp->getNumOperands() > 1) { - Operation *sizeOp = groupOp->getOperand(1).getDefiningOp(); - int32_t clusterSize = 0; - - // TODO: support specialization constant here. - if (failed(extractValueFromConstOp(sizeOp, clusterSize))) - return groupOp->emitOpError( - "cluster size operand must come from a constant op"); - - if (!llvm::isPowerOf2_32(clusterSize)) - return groupOp->emitOpError( - "cluster size operand must be a power of two"); - } - return success(); -} - /// Result of a logical op must be a scalar or vector of boolean type. static Type getUnaryOpResultType(Type operandType) { Builder builder(operandType.getContext()); @@ -901,7 +674,7 @@ // TODO: this should be relaxed to allow // integer literals of other bitwidths. - if (failed(extractValueFromConstOp(op, index))) { + if (failed(spirv::extractValueFromConstOp(op, index))) { emitError( baseLoc, "'spirv.AccessChain' index must be an integer spirv.Constant to " @@ -1032,514 +805,6 @@ return success(); } -template -static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) { - printer << " \"" << stringifyScope(atomOp.getMemoryScope()) << "\" \"" - << stringifyMemorySemantics(atomOp.getEqualSemantics()) << "\" \"" - << stringifyMemorySemantics(atomOp.getUnequalSemantics()) << "\" " - << atomOp.getOperands() << " : " << atomOp.getPointer().getType(); -} - -static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser, - OperationState &state) { - spirv::Scope memoryScope; - spirv::MemorySemantics equalSemantics, unequalSemantics; - SmallVector operandInfo; - Type type; - if (spirv::parseEnumStrAttr(memoryScope, parser, state, - kMemoryScopeAttrName) || - spirv::parseEnumStrAttr( - equalSemantics, parser, state, kEqualSemanticsAttrName) || - spirv::parseEnumStrAttr( - unequalSemantics, parser, state, kUnequalSemanticsAttrName) || - parser.parseOperandList(operandInfo, 3)) - return failure(); - - auto loc = parser.getCurrentLocation(); - if (parser.parseColonType(type)) - return failure(); - - auto ptrType = llvm::dyn_cast(type); - if (!ptrType) - return parser.emitError(loc, "expected pointer type"); - - if (parser.resolveOperands( - operandInfo, - {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()}, - parser.getNameLoc(), state.operands)) - return failure(); - - return parser.addTypeToList(ptrType.getPointeeType(), state.types); -} - -template -static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) { - // According to the spec: - // "The type of Value must be the same as Result Type. The type of the value - // pointed to by Pointer must be the same as Result Type. This type must also - // match the type of Comparator." - if (atomOp.getType() != atomOp.getValue().getType()) - return atomOp.emitOpError("value operand must have the same type as the op " - "result, but found ") - << atomOp.getValue().getType() << " vs " << atomOp.getType(); - - if (atomOp.getType() != atomOp.getComparator().getType()) - return atomOp.emitOpError( - "comparator operand must have the same type as the op " - "result, but found ") - << atomOp.getComparator().getType() << " vs " << atomOp.getType(); - - Type pointeeType = - llvm::cast(atomOp.getPointer().getType()) - .getPointeeType(); - if (atomOp.getType() != pointeeType) - return atomOp.emitOpError( - "pointer operand's pointee type must have the same " - "as the op result type, but found ") - << pointeeType << " vs " << atomOp.getType(); - - // TODO: Unequal cannot be set to Release or Acquire and Release. - // In addition, Unequal cannot be set to a stronger memory-order then Equal. - - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.AtomicAndOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::AtomicAndOp::verify() { - return ::verifyAtomicUpdateOp(getOperation()); -} - -ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseAtomicUpdateOp(parser, result, true); -} -void spirv::AtomicAndOp::print(OpAsmPrinter &p) { - ::printAtomicUpdateOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.AtomicCompareExchangeOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::AtomicCompareExchangeOp::verify() { - return ::verifyAtomicCompareExchangeImpl(*this); -} - -ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseAtomicCompareExchangeImpl(parser, result); -} -void spirv::AtomicCompareExchangeOp::print(OpAsmPrinter &p) { - ::printAtomicCompareExchangeImpl(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.AtomicCompareExchangeWeakOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::AtomicCompareExchangeWeakOp::verify() { - return ::verifyAtomicCompareExchangeImpl(*this); -} - -ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseAtomicCompareExchangeImpl(parser, result); -} -void spirv::AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) { - ::printAtomicCompareExchangeImpl(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.AtomicExchange -//===----------------------------------------------------------------------===// - -void spirv::AtomicExchangeOp::print(OpAsmPrinter &printer) { - printer << " \"" << stringifyScope(getMemoryScope()) << "\" \"" - << stringifyMemorySemantics(getSemantics()) << "\" " << getOperands() - << " : " << getPointer().getType(); -} - -ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser, - OperationState &result) { - spirv::Scope memoryScope; - spirv::MemorySemantics semantics; - SmallVector operandInfo; - Type type; - if (parseEnumStrAttr(memoryScope, parser, result, - kMemoryScopeAttrName) || - parseEnumStrAttr(semantics, parser, result, - kSemanticsAttrName) || - parser.parseOperandList(operandInfo, 2)) - return failure(); - - auto loc = parser.getCurrentLocation(); - if (parser.parseColonType(type)) - return failure(); - - auto ptrType = llvm::dyn_cast(type); - if (!ptrType) - return parser.emitError(loc, "expected pointer type"); - - if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()}, - parser.getNameLoc(), result.operands)) - return failure(); - - return parser.addTypeToList(ptrType.getPointeeType(), result.types); -} - -LogicalResult spirv::AtomicExchangeOp::verify() { - if (getType() != getValue().getType()) - return emitOpError("value operand must have the same type as the op " - "result, but found ") - << getValue().getType() << " vs " << getType(); - - Type pointeeType = - llvm::cast(getPointer().getType()).getPointeeType(); - if (getType() != pointeeType) - return emitOpError("pointer operand's pointee type must have the same " - "as the op result type, but found ") - << pointeeType << " vs " << getType(); - - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.AtomicIAddOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::AtomicIAddOp::verify() { - return ::verifyAtomicUpdateOp(getOperation()); -} - -ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseAtomicUpdateOp(parser, result, true); -} -void spirv::AtomicIAddOp::print(OpAsmPrinter &p) { - ::printAtomicUpdateOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.EXT.AtomicFAddOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::EXTAtomicFAddOp::verify() { - return ::verifyAtomicUpdateOp(getOperation()); -} - -ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseAtomicUpdateOp(parser, result, true); -} -void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) { - ::printAtomicUpdateOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.AtomicIDecrementOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::AtomicIDecrementOp::verify() { - return ::verifyAtomicUpdateOp(getOperation()); -} - -ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseAtomicUpdateOp(parser, result, false); -} -void spirv::AtomicIDecrementOp::print(OpAsmPrinter &p) { - ::printAtomicUpdateOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.AtomicIIncrementOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::AtomicIIncrementOp::verify() { - return ::verifyAtomicUpdateOp(getOperation()); -} - -ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseAtomicUpdateOp(parser, result, false); -} -void spirv::AtomicIIncrementOp::print(OpAsmPrinter &p) { - ::printAtomicUpdateOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.AtomicISubOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::AtomicISubOp::verify() { - return ::verifyAtomicUpdateOp(getOperation()); -} - -ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseAtomicUpdateOp(parser, result, true); -} -void spirv::AtomicISubOp::print(OpAsmPrinter &p) { - ::printAtomicUpdateOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.AtomicOrOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::AtomicOrOp::verify() { - return ::verifyAtomicUpdateOp(getOperation()); -} - -ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseAtomicUpdateOp(parser, result, true); -} -void spirv::AtomicOrOp::print(OpAsmPrinter &p) { - ::printAtomicUpdateOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.AtomicSMaxOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::AtomicSMaxOp::verify() { - return ::verifyAtomicUpdateOp(getOperation()); -} - -ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseAtomicUpdateOp(parser, result, true); -} -void spirv::AtomicSMaxOp::print(OpAsmPrinter &p) { - ::printAtomicUpdateOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.AtomicSMinOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::AtomicSMinOp::verify() { - return ::verifyAtomicUpdateOp(getOperation()); -} - -ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseAtomicUpdateOp(parser, result, true); -} -void spirv::AtomicSMinOp::print(OpAsmPrinter &p) { - ::printAtomicUpdateOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.AtomicUMaxOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::AtomicUMaxOp::verify() { - return ::verifyAtomicUpdateOp(getOperation()); -} - -ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseAtomicUpdateOp(parser, result, true); -} -void spirv::AtomicUMaxOp::print(OpAsmPrinter &p) { - ::printAtomicUpdateOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.AtomicUMinOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::AtomicUMinOp::verify() { - return ::verifyAtomicUpdateOp(getOperation()); -} - -ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseAtomicUpdateOp(parser, result, true); -} -void spirv::AtomicUMinOp::print(OpAsmPrinter &p) { - ::printAtomicUpdateOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.AtomicXorOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::AtomicXorOp::verify() { - return ::verifyAtomicUpdateOp(getOperation()); -} - -ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser, - OperationState &result) { - return ::parseAtomicUpdateOp(parser, result, true); -} -void spirv::AtomicXorOp::print(OpAsmPrinter &p) { - ::printAtomicUpdateOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.BitcastOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::BitcastOp::verify() { - // TODO: The SPIR-V spec validation rules are different for different - // versions. - auto operandType = getOperand().getType(); - auto resultType = getResult().getType(); - if (operandType == resultType) { - return emitError("result type must be different from operand type"); - } - if (llvm::isa(operandType) && - !llvm::isa(resultType)) { - return emitError( - "unhandled bit cast conversion from pointer type to non-pointer type"); - } - if (!llvm::isa(operandType) && - llvm::isa(resultType)) { - return emitError( - "unhandled bit cast conversion from non-pointer type to pointer type"); - } - auto operandBitWidth = getBitWidth(operandType); - auto resultBitWidth = getBitWidth(resultType); - if (operandBitWidth != resultBitWidth) { - return emitOpError("mismatch in result type bitwidth ") - << resultBitWidth << " and operand type bitwidth " - << operandBitWidth; - } - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.ConvertPtrToUOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::ConvertPtrToUOp::verify() { - auto operandType = llvm::cast(getPointer().getType()); - auto resultType = llvm::cast(getResult().getType()); - if (!resultType || !resultType.isSignlessInteger()) - return emitError("result must be a scalar type of unsigned integer"); - auto spirvModule = (*this)->getParentOfType(); - if (!spirvModule) - return success(); - auto addressingModel = spirvModule.getAddressingModel(); - if ((addressingModel == spirv::AddressingModel::Logical) || - (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 && - operandType.getStorageClass() != - spirv::StorageClass::PhysicalStorageBuffer)) - return emitError("operand must be a physical pointer"); - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.ConvertUToPtrOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::ConvertUToPtrOp::verify() { - auto operandType = llvm::cast(getOperand().getType()); - auto resultType = llvm::cast(getResult().getType()); - if (!operandType || !operandType.isSignlessInteger()) - return emitError("result must be a scalar type of unsigned integer"); - auto spirvModule = (*this)->getParentOfType(); - if (!spirvModule) - return success(); - auto addressingModel = spirvModule.getAddressingModel(); - if ((addressingModel == spirv::AddressingModel::Logical) || - (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 && - resultType.getStorageClass() != - spirv::StorageClass::PhysicalStorageBuffer)) - return emitError("result must be a physical pointer"); - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.PtrCastToGenericOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::PtrCastToGenericOp::verify() { - auto operandType = llvm::cast(getPointer().getType()); - auto resultType = llvm::cast(getResult().getType()); - - spirv::StorageClass operandStorage = operandType.getStorageClass(); - if (operandStorage != spirv::StorageClass::Workgroup && - operandStorage != spirv::StorageClass::CrossWorkgroup && - operandStorage != spirv::StorageClass::Function) - return emitError("pointer must point to the Workgroup, CrossWorkgroup" - ", or Function Storage Class"); - - spirv::StorageClass resultStorage = resultType.getStorageClass(); - if (resultStorage != spirv::StorageClass::Generic) - return emitError("result type must be of storage class Generic"); - - Type operandPointeeType = operandType.getPointeeType(); - Type resultPointeeType = resultType.getPointeeType(); - if (operandPointeeType != resultPointeeType) - return emitOpError("pointer operand's pointee type must have the same " - "as the op result type, but found ") - << operandPointeeType << " vs " << resultPointeeType; - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.GenericCastToPtrOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GenericCastToPtrOp::verify() { - auto operandType = llvm::cast(getPointer().getType()); - auto resultType = llvm::cast(getResult().getType()); - - spirv::StorageClass operandStorage = operandType.getStorageClass(); - if (operandStorage != spirv::StorageClass::Generic) - return emitError("pointer type must be of storage class Generic"); - - spirv::StorageClass resultStorage = resultType.getStorageClass(); - if (resultStorage != spirv::StorageClass::Workgroup && - resultStorage != spirv::StorageClass::CrossWorkgroup && - resultStorage != spirv::StorageClass::Function) - return emitError("result must point to the Workgroup, CrossWorkgroup, " - "or Function Storage Class"); - - Type operandPointeeType = operandType.getPointeeType(); - Type resultPointeeType = resultType.getPointeeType(); - if (operandPointeeType != resultPointeeType) - return emitOpError("pointer operand's pointee type must have the same " - "as the op result type, but found ") - << operandPointeeType << " vs " << resultPointeeType; - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.GenericCastToPtrExplicitOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GenericCastToPtrExplicitOp::verify() { - auto operandType = llvm::cast(getPointer().getType()); - auto resultType = llvm::cast(getResult().getType()); - - spirv::StorageClass operandStorage = operandType.getStorageClass(); - if (operandStorage != spirv::StorageClass::Generic) - return emitError("pointer type must be of storage class Generic"); - - spirv::StorageClass resultStorage = resultType.getStorageClass(); - if (resultStorage != spirv::StorageClass::Workgroup && - resultStorage != spirv::StorageClass::CrossWorkgroup && - resultStorage != spirv::StorageClass::Function) - return emitError("result must point to the Workgroup, CrossWorkgroup, " - "or Function Storage Class"); - - Type operandPointeeType = operandType.getPointeeType(); - Type resultPointeeType = resultType.getPointeeType(); - if (operandPointeeType != resultPointeeType) - return emitOpError("pointer operand's pointee type must have the same " - "as the op result type, but found ") - << operandPointeeType << " vs " << resultPointeeType; - return success(); -} - //===----------------------------------------------------------------------===// // spirv.BranchOp //===----------------------------------------------------------------------===// @@ -2065,84 +1330,6 @@ return verifyMemorySemantics(getOperation(), getMemorySemantics()); } -//===----------------------------------------------------------------------===// -// spirv.ConvertFToSOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::ConvertFToSOp::verify() { - return verifyCastOp(*this, /*requireSameBitWidth=*/false, - /*skipBitWidthCheck=*/true); -} - -//===----------------------------------------------------------------------===// -// spirv.ConvertFToUOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::ConvertFToUOp::verify() { - return verifyCastOp(*this, /*requireSameBitWidth=*/false, - /*skipBitWidthCheck=*/true); -} - -//===----------------------------------------------------------------------===// -// spirv.ConvertSToFOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::ConvertSToFOp::verify() { - return verifyCastOp(*this, /*requireSameBitWidth=*/false, - /*skipBitWidthCheck=*/true); -} - -//===----------------------------------------------------------------------===// -// spirv.ConvertUToFOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::ConvertUToFOp::verify() { - return verifyCastOp(*this, /*requireSameBitWidth=*/false, - /*skipBitWidthCheck=*/true); -} - -//===----------------------------------------------------------------------===// -// spirv.INTELConvertBF16ToFOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::INTELConvertBF16ToFOp::verify() { - auto operandType = getOperand().getType(); - auto resultType = getResult().getType(); - // ODS checks that vector result type and vector operand type have the same - // shape. - if (auto vectorType = llvm::dyn_cast(operandType)) { - unsigned operandNumElements = vectorType.getNumElements(); - unsigned resultNumElements = - llvm::cast(resultType).getNumElements(); - if (operandNumElements != resultNumElements) { - return emitOpError( - "operand and result must have same number of elements"); - } - } - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.INTELConvertFToBF16Op -//===----------------------------------------------------------------------===// - -LogicalResult spirv::INTELConvertFToBF16Op::verify() { - auto operandType = getOperand().getType(); - auto resultType = getResult().getType(); - // ODS checks that vector result type and vector operand type have the same - // shape. - if (auto vectorType = llvm::dyn_cast(operandType)) { - unsigned operandNumElements = vectorType.getNumElements(); - unsigned resultNumElements = - llvm::cast(resultType).getNumElements(); - if (operandNumElements != resultNumElements) { - return emitOpError( - "operand and result must have same number of elements"); - } - } - return success(); -} - //===----------------------------------------------------------------------===// // spirv.EntryPoint //===----------------------------------------------------------------------===// @@ -2253,30 +1440,6 @@ }); } -//===----------------------------------------------------------------------===// -// spirv.FConvertOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::FConvertOp::verify() { - return verifyCastOp(*this, /*requireSameBitWidth=*/false); -} - -//===----------------------------------------------------------------------===// -// spirv.SConvertOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::SConvertOp::verify() { - return verifyCastOp(*this, /*requireSameBitWidth=*/false); -} - -//===----------------------------------------------------------------------===// -// spirv.UConvertOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::UConvertOp::verify() { - return verifyCastOp(*this, /*requireSameBitWidth=*/false); -} - //===----------------------------------------------------------------------===// // spirv.func //===----------------------------------------------------------------------===// @@ -2641,90 +1804,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// spirv.GroupBroadcast -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GroupBroadcastOp::verify() { - spirv::Scope scope = getExecutionScope(); - if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) - return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); - - if (auto localIdTy = llvm::dyn_cast(getLocalid().getType())) - if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3) - return emitOpError("localid is a vector and can be with only " - " 2 or 3 components, actual number is ") - << localIdTy.getNumElements(); - - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformBallotOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GroupNonUniformBallotOp::verify() { - spirv::Scope scope = getExecutionScope(); - if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) - return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformBroadcast -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GroupNonUniformBroadcastOp::verify() { - spirv::Scope scope = getExecutionScope(); - if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) - return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); - - // SPIR-V spec: "Before version 1.5, Id must come from a - // constant instruction. - auto targetEnv = spirv::getDefaultTargetEnv(getContext()); - if (auto spirvModule = (*this)->getParentOfType()) - targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule); - - if (targetEnv.getVersion() < spirv::Version::V_1_5) { - auto *idOp = getId().getDefiningOp(); - if (!idOp || !isa(idOp)) // for spec constant - return emitOpError("id must be the result of a constant op"); - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformShuffle* -//===----------------------------------------------------------------------===// - -template -static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) { - spirv::Scope scope = op.getExecutionScope(); - if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) - return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); - - if (op.getOperands().back().getType().isSignedInteger()) - return op.emitOpError("second operand must be a singless/unsigned integer"); - - return success(); -} - -LogicalResult spirv::GroupNonUniformShuffleOp::verify() { - return verifyGroupNonUniformShuffleOp(*this); -} -LogicalResult spirv::GroupNonUniformShuffleDownOp::verify() { - return verifyGroupNonUniformShuffleOp(*this); -} -LogicalResult spirv::GroupNonUniformShuffleUpOp::verify() { - return verifyGroupNonUniformShuffleOp(*this); -} -LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() { - return verifyGroupNonUniformShuffleOp(*this); -} - //===----------------------------------------------------------------------===// // spirv.INTEL.SubgroupBlockRead //===----------------------------------------------------------------------===// @@ -2803,178 +1882,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformElectOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GroupNonUniformElectOp::verify() { - spirv::Scope scope = getExecutionScope(); - if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) - return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformFAddOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GroupNonUniformFAddOp::verify() { - return verifyGroupNonUniformArithmeticOp(*this); -} - -ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseGroupNonUniformArithmeticOp(parser, result); -} -void spirv::GroupNonUniformFAddOp::print(OpAsmPrinter &p) { - printGroupNonUniformArithmeticOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformFMaxOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GroupNonUniformFMaxOp::verify() { - return verifyGroupNonUniformArithmeticOp(*this); -} - -ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseGroupNonUniformArithmeticOp(parser, result); -} -void spirv::GroupNonUniformFMaxOp::print(OpAsmPrinter &p) { - printGroupNonUniformArithmeticOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformFMinOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GroupNonUniformFMinOp::verify() { - return verifyGroupNonUniformArithmeticOp(*this); -} - -ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseGroupNonUniformArithmeticOp(parser, result); -} -void spirv::GroupNonUniformFMinOp::print(OpAsmPrinter &p) { - printGroupNonUniformArithmeticOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformFMulOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GroupNonUniformFMulOp::verify() { - return verifyGroupNonUniformArithmeticOp(*this); -} - -ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseGroupNonUniformArithmeticOp(parser, result); -} -void spirv::GroupNonUniformFMulOp::print(OpAsmPrinter &p) { - printGroupNonUniformArithmeticOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformIAddOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GroupNonUniformIAddOp::verify() { - return verifyGroupNonUniformArithmeticOp(*this); -} - -ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseGroupNonUniformArithmeticOp(parser, result); -} -void spirv::GroupNonUniformIAddOp::print(OpAsmPrinter &p) { - printGroupNonUniformArithmeticOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformIMulOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GroupNonUniformIMulOp::verify() { - return verifyGroupNonUniformArithmeticOp(*this); -} - -ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseGroupNonUniformArithmeticOp(parser, result); -} -void spirv::GroupNonUniformIMulOp::print(OpAsmPrinter &p) { - printGroupNonUniformArithmeticOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformSMaxOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GroupNonUniformSMaxOp::verify() { - return verifyGroupNonUniformArithmeticOp(*this); -} - -ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseGroupNonUniformArithmeticOp(parser, result); -} -void spirv::GroupNonUniformSMaxOp::print(OpAsmPrinter &p) { - printGroupNonUniformArithmeticOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformSMinOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GroupNonUniformSMinOp::verify() { - return verifyGroupNonUniformArithmeticOp(*this); -} - -ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseGroupNonUniformArithmeticOp(parser, result); -} -void spirv::GroupNonUniformSMinOp::print(OpAsmPrinter &p) { - printGroupNonUniformArithmeticOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformUMaxOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GroupNonUniformUMaxOp::verify() { - return verifyGroupNonUniformArithmeticOp(*this); -} - -ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseGroupNonUniformArithmeticOp(parser, result); -} -void spirv::GroupNonUniformUMaxOp::print(OpAsmPrinter &p) { - printGroupNonUniformArithmeticOp(*this, p); -} - -//===----------------------------------------------------------------------===// -// spirv.GroupNonUniformUMinOp -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GroupNonUniformUMinOp::verify() { - return verifyGroupNonUniformArithmeticOp(*this); -} - -ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseGroupNonUniformArithmeticOp(parser, result); -} -void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) { - printGroupNonUniformArithmeticOp(*this, p); -} - //===----------------------------------------------------------------------===// // spirv.IAddCarryOp //===----------------------------------------------------------------------===// @@ -4514,39 +3421,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// Group ops -//===----------------------------------------------------------------------===// - -template -static LogicalResult verifyGroupOp(Op op) { - spirv::Scope scope = op.getExecutionScope(); - if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) - return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); - - return success(); -} - -LogicalResult spirv::GroupIAddOp::verify() { return verifyGroupOp(*this); } - -LogicalResult spirv::GroupFAddOp::verify() { return verifyGroupOp(*this); } - -LogicalResult spirv::GroupFMinOp::verify() { return verifyGroupOp(*this); } - -LogicalResult spirv::GroupUMinOp::verify() { return verifyGroupOp(*this); } - -LogicalResult spirv::GroupSMinOp::verify() { return verifyGroupOp(*this); } - -LogicalResult spirv::GroupFMaxOp::verify() { return verifyGroupOp(*this); } - -LogicalResult spirv::GroupUMaxOp::verify() { return verifyGroupOp(*this); } - -LogicalResult spirv::GroupSMaxOp::verify() { return verifyGroupOp(*this); } - -LogicalResult spirv::GroupIMulKHROp::verify() { return verifyGroupOp(*this); } - -LogicalResult spirv::GroupFMulKHROp::verify() { return verifyGroupOp(*this); } - // TableGen'erated operation interfaces for querying versions, extensions, and // capabilities. #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"