diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/ParserUtils.h b/mlir/include/mlir/Dialect/SPIRV/IR/ParserUtils.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/ParserUtils.h +++ /dev/null @@ -1,45 +0,0 @@ -//===------------ ParserUtils.h - Parse text to SPIR-V ops ----------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines utilities used for parsing types and ops for SPIR-V -// dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_SPIRV_IR_PARSERUTILS_H_ -#define MLIR_DIALECT_SPIRV_IR_PARSERUTILS_H_ - -#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" - -namespace mlir { - -/// Parses the next keyword in `parser` as an enumerant of the given -/// `EnumClass`. 
-template -static ParseResult -parseEnumKeywordAttr(EnumClass &value, ParserType &parser, - StringRef attrName = spirv::attributeName()) { - StringRef keyword; - SmallVector attr; - auto loc = parser.getCurrentLocation(); - if (parser.parseKeyword(&keyword)) - return failure(); - if (std::optional attr = - spirv::symbolizeEnum(keyword)) { - value = *attr; - return success(); - } - return parser.emitError(loc, "invalid ") - << attrName << " attribute specification: " << keyword; -} - -} // namespace mlir - -#endif // MLIR_DIALECT_SPIRV_IR_PARSERUTILS_H_ diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt --- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt @@ -3,12 +3,16 @@ add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen) add_mlir_dialect_library(MLIRSPIRVDialect + CooperativeMatrixOps.cpp + IntegerDotProductOps.cpp + JointMatrixOps.cpp SPIRVAttributes.cpp SPIRVCanonicalization.cpp SPIRVGLCanonicalization.cpp SPIRVDialect.cpp SPIRVEnums.cpp SPIRVOps.cpp + SPIRVParsingUtils.cpp SPIRVTypes.cpp TargetAndABI.cpp diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp @@ -0,0 +1,306 @@ +//===- CooperativeMatrixOps.cpp - MLIR SPIR-V Cooperative Matrix Ops -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the Cooperative Matrix operations in the SPIR-V dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "SPIRVParsingUtils.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" + +using namespace mlir::spirv::AttrNames; + +namespace mlir::spirv { +//===----------------------------------------------------------------------===// +// spirv.KHR.CooperativeMatrixLength +//===----------------------------------------------------------------------===// + +LogicalResult KHRCooperativeMatrixLengthOp::verify() { + if (!isa(getCooperativeMatrixType())) { + return emitOpError( + "type attribute must be a '!spirv.coopmatrix' type, found ") + << getCooperativeMatrixType() << " instead"; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.KHR.CooperativeMatrixLoad +//===----------------------------------------------------------------------===// + +ParseResult KHRCooperativeMatrixLoadOp::parse(OpAsmParser &parser, + OperationState &result) { + std::array operandInfo = {}; + if (parser.parseOperand(operandInfo[0]) || parser.parseComma()) + return failure(); + if (parser.parseOperand(operandInfo[1]) || parser.parseComma()) + return failure(); + + CooperativeMatrixLayoutKHR layout; + if (parseEnumKeywordAttr( + layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) { + return failure(); + } + + if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName)) + return failure(); + + Type ptrType; + Type elementType; + if (parser.parseColon() || parser.parseType(ptrType) || + parser.parseKeywordType("as", elementType)) { + return failure(); + } + result.addTypes(elementType); + + Type strideType = parser.getBuilder().getIntegerType(32); + if (parser.resolveOperands(operandInfo, {ptrType, strideType}, + parser.getNameLoc(), result.operands)) { + return failure(); + } + + return success(); +} + +void KHRCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) { + printer << " " << getPointer() << ", " << getStride() << 
", " + << getMatrixLayout(); + // Print optional memory operand attribute. + if (auto memOperand = getMemoryOperand()) + printer << " [\"" << memOperand << "\"]"; + printer << " : " << getPointer().getType() << " as " << getType(); +} + +static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, + Type coopMatrix) { + auto pointerType = cast(pointer); + Type pointeeType = pointerType.getPointeeType(); + if (!isa(pointeeType)) { + return op->emitError( + "Pointer must point to a scalar or vector type but provided ") + << pointeeType; + } + + // TODO: Verify the memory object behind the pointer: + // > If the Shader capability was declared, Pointer must point into an array + // > and any ArrayStride decoration on Pointer is ignored. + + return success(); +} + +LogicalResult KHRCooperativeMatrixLoadOp::verify() { + return verifyPointerAndCoopMatrixType(*this, getPointer().getType(), + getResult().getType()); +} + +//===----------------------------------------------------------------------===// +// spirv.KHR.CooperativeMatrixStore +//===----------------------------------------------------------------------===// + +ParseResult KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser, + OperationState &result) { + std::array operandInfo = {}; + for (auto &op : operandInfo) { + if (parser.parseOperand(op) || parser.parseComma()) + return failure(); + } + + CooperativeMatrixLayoutKHR layout; + if (parseEnumKeywordAttr( + layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) { + return failure(); + } + + if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName)) + return failure(); + + Type ptrType; + Type objectType; + if (parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() || + parser.parseType(objectType)) { + return failure(); + } + + Type strideType = parser.getBuilder().getIntegerType(32); + if (parser.resolveOperands(operandInfo, {ptrType, objectType, strideType}, + parser.getNameLoc(), result.operands)) { 
+ return failure(); + } + + return success(); +} + +void KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) { + printer << " " << getPointer() << ", " << getObject() << ", " << getStride() + << ", " << getMatrixLayout(); + + // Print optional memory operand attribute. + if (auto memOperand = getMemoryOperand()) + printer << " [\"" << *memOperand << "\"]"; + printer << " : " << getPointer().getType() << ", " << getObject().getType(); +} + +LogicalResult KHRCooperativeMatrixStoreOp::verify() { + return verifyPointerAndCoopMatrixType(*this, getPointer().getType(), + getObject().getType()); +} + +//===----------------------------------------------------------------------===// +// spirv.NV.CooperativeMatrixLength +//===----------------------------------------------------------------------===// + +LogicalResult NVCooperativeMatrixLengthOp::verify() { + if (!isa(getCooperativeMatrixType())) { + return emitOpError( + "type attribute must be a '!spirv.NV.coopmatrix' type, found ") + << getCooperativeMatrixType() << " instead"; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.NV.CooperativeMatrixLoad +//===----------------------------------------------------------------------===// + +ParseResult NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser, + OperationState &result) { + SmallVector operandInfo; + Type strideType = parser.getBuilder().getIntegerType(32); + Type columnMajorType = parser.getBuilder().getIntegerType(1); + Type ptrType; + Type elementType; + if (parser.parseOperandList(operandInfo, 3) || + parseMemoryAccessAttributes(parser, result) || parser.parseColon() || + parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) { + return failure(); + } + if (parser.resolveOperands(operandInfo, + {ptrType, strideType, columnMajorType}, + parser.getNameLoc(), result.operands)) { + return failure(); + } + + result.addTypes(elementType); + return success(); +} + +void 
NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) { + printer << " " << getPointer() << ", " << getStride() << ", " + << getColumnmajor(); + // Print optional memory access attribute. + if (auto memAccess = getMemoryAccess()) + printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; + printer << " : " << getPointer().getType() << " as " << getType(); +} + +static LogicalResult +verifyPointerAndCoopMatrixNVType(Operation *op, Type pointer, Type coopMatrix) { + Type pointeeType = llvm::cast(pointer).getPointeeType(); + if (!llvm::isa(pointeeType) && + !llvm::isa(pointeeType)) + return op->emitError( + "Pointer must point to a scalar or vector type but provided ") + << pointeeType; + StorageClass storage = llvm::cast(pointer).getStorageClass(); + if (storage != StorageClass::Workgroup && + storage != StorageClass::StorageBuffer && + storage != StorageClass::PhysicalStorageBuffer) + return op->emitError( + "Pointer storage class must be Workgroup, StorageBuffer or " + "PhysicalStorageBufferEXT but provided ") + << stringifyStorageClass(storage); + return success(); +} + +LogicalResult NVCooperativeMatrixLoadOp::verify() { + return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(), + getResult().getType()); +} + +//===----------------------------------------------------------------------===// +// spirv.NV.CooperativeMatrixStore +//===----------------------------------------------------------------------===// + +ParseResult NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser, + OperationState &result) { + SmallVector operandInfo; + Type strideType = parser.getBuilder().getIntegerType(32); + Type columnMajorType = parser.getBuilder().getIntegerType(1); + Type ptrType; + Type elementType; + if (parser.parseOperandList(operandInfo, 4) || + parseMemoryAccessAttributes(parser, result) || parser.parseColon() || + parser.parseType(ptrType) || parser.parseComma() || + parser.parseType(elementType)) { + return failure(); + } + if 
(parser.resolveOperands( + operandInfo, {ptrType, elementType, strideType, columnMajorType}, + parser.getNameLoc(), result.operands)) { + return failure(); + } + + return success(); +} + +void NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) { + printer << " " << getPointer() << ", " << getObject() << ", " << getStride() + << ", " << getColumnmajor(); + // Print optional memory access attribute. + if (auto memAccess = getMemoryAccess()) + printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; + printer << " : " << getPointer().getType() << ", " << getOperand(1).getType(); +} + +LogicalResult NVCooperativeMatrixStoreOp::verify() { + return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(), + getObject().getType()); +} + +//===----------------------------------------------------------------------===// +// spirv.NV.CooperativeMatrixMulAdd +//===----------------------------------------------------------------------===// + +static LogicalResult verifyCoopMatrixMulAddNV(NVCooperativeMatrixMulAddOp op) { + if (op.getC().getType() != op.getResult().getType()) + return op.emitOpError("result and third operand must have the same type"); + auto typeA = llvm::cast(op.getA().getType()); + auto typeB = llvm::cast(op.getB().getType()); + auto typeC = llvm::cast(op.getC().getType()); + auto typeR = llvm::cast(op.getResult().getType()); + if (typeA.getRows() != typeR.getRows() || + typeA.getColumns() != typeB.getRows() || + typeB.getColumns() != typeR.getColumns()) + return op.emitOpError("matrix size must match"); + if (typeR.getScope() != typeA.getScope() || + typeR.getScope() != typeB.getScope() || + typeR.getScope() != typeC.getScope()) + return op.emitOpError("matrix scope must match"); + auto elementTypeA = typeA.getElementType(); + auto elementTypeB = typeB.getElementType(); + if (isa(elementTypeA) && isa(elementTypeB)) { + if (llvm::cast(elementTypeA).getWidth() != + llvm::cast(elementTypeB).getWidth()) + return op.emitOpError( + "matrix A and 
B integer element types must be the same bit width"); + } else if (elementTypeA != elementTypeB) { + return op.emitOpError( + "matrix A and B non-integer element types must match"); + } + if (typeR.getElementType() != typeC.getElementType()) + return op.emitOpError("matrix accumulator element type must match"); + return success(); +} + +LogicalResult NVCooperativeMatrixMulAddOp::verify() { + return verifyCoopMatrixMulAddNV(*this); +} + +} // namespace mlir::spirv diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp @@ -0,0 +1,158 @@ +//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the Integer Dot Product operations in the SPIR-V dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" + +#include "SPIRVOpUtils.h" +#include "SPIRVParsingUtils.h" + +#include "llvm/Support/FormatVariadic.h" + +using namespace mlir::spirv::AttrNames; + +namespace mlir::spirv { + +//===----------------------------------------------------------------------===// +// Integer Dot Product ops +//===----------------------------------------------------------------------===// + +static LogicalResult verifyIntegerDotProduct(Operation *op) { + assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) && + "Not an integer dot product op?"); + assert(op->getNumResults() == 1 && "Expected a single result"); + + Type factorTy = op->getOperand(0).getType(); + if (op->getOperand(1).getType() != factorTy) + return op->emitOpError("requires the same type for both vector operands"); + + unsigned expectedNumAttrs = 0; + if (auto intTy = llvm::dyn_cast(factorTy)) { + ++expectedNumAttrs; + auto packedVectorFormat = + llvm::dyn_cast_or_null( + op->getAttr(kPackedVectorFormatAttrName)); + if (!packedVectorFormat) + return op->emitOpError("requires Packed Vector Format attribute for " + "integer vector operands"); + + assert(packedVectorFormat.getValue() == + spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && + "Unknown Packed Vector Format"); + if (intTy.getWidth() != 32) + return op->emitOpError( + llvm::formatv("with specified Packed Vector Format ({0}) requires " + "integer vector operands to be 32-bits wide", + packedVectorFormat.getValue())); + } else { + if (op->hasAttr(kPackedVectorFormatAttrName)) + return op->emitOpError(llvm::formatv( + "with invalid format attribute for vector operands of type '{0}'", + factorTy)); + } + + if (op->getAttrs().size() > expectedNumAttrs) + return op->emitError( + "op only supports the 'format' #spirv.packed_vector_format attribute"); + + Type resultTy = op->getResultTypes().front(); + bool hasAccumulator = 
op->getNumOperands() == 3; + if (hasAccumulator && op->getOperand(2).getType() != resultTy) + return op->emitOpError( + "requires the same accumulator operand and result types"); + + unsigned factorBitWidth = getBitWidth(factorTy); + unsigned resultBitWidth = getBitWidth(resultTy); + if (factorBitWidth > resultBitWidth) + return op->emitOpError( + llvm::formatv("result type has insufficient bit-width ({0} bits) " + "for the specified vector operand type ({1} bits)", + resultBitWidth, factorBitWidth)); + + return success(); +} + +static std::optional getIntegerDotProductMinVersion() { + return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0. +} + +static std::optional getIntegerDotProductMaxVersion() { + return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6. +} + +static SmallVector, 1> +getIntegerDotProductExtensions() { + // Requires the SPV_KHR_integer_dot_product extension, specified either + // explicitly or implied by target env's SPIR-V version >= 1.6. + static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product; + return {extension}; +} + +static SmallVector, 1> +getIntegerDotProductCapabilities(Operation *op) { + // Requires the DotProduct capability and capabilities that depend on + // exact op types. 
+ static const auto dotProductCap = spirv::Capability::DotProduct; + static const auto dotProductInput4x8BitPackedCap = + spirv::Capability::DotProductInput4x8BitPacked; + static const auto dotProductInput4x8BitCap = + spirv::Capability::DotProductInput4x8Bit; + static const auto dotProductInputAllCap = + spirv::Capability::DotProductInputAll; + + SmallVector, 1> capabilities = {dotProductCap}; + + Type factorTy = op->getOperand(0).getType(); + if (auto intTy = llvm::dyn_cast(factorTy)) { + auto formatAttr = llvm::cast( + op->getAttr(kPackedVectorFormatAttrName)); + if (formatAttr.getValue() == + spirv::PackedVectorFormat::PackedVectorFormat4x8Bit) + capabilities.push_back(dotProductInput4x8BitPackedCap); + + return capabilities; + } + + auto vecTy = llvm::cast(factorTy); + if (vecTy.getElementTypeBitWidth() == 8) { + capabilities.push_back(dotProductInput4x8BitCap); + return capabilities; + } + + capabilities.push_back(dotProductInputAllCap); + return capabilities; +} + +#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \ + LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); } \ + SmallVector, 1> OpName::getExtensions() { \ + return getIntegerDotProductExtensions(); \ + } \ + SmallVector, 1> OpName::getCapabilities() { \ + return getIntegerDotProductCapabilities(*this); \ + } \ + std::optional OpName::getMinVersion() { \ + return getIntegerDotProductMinVersion(); \ + } \ + std::optional OpName::getMaxVersion() { \ + return getIntegerDotProductMaxVersion(); \ + } + +SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp) +SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotOp) +SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotOp) +SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotAccSatOp) +SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotAccSatOp) +SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotAccSatOp) + +#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP + +} // namespace mlir::spirv diff --git a/mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp new file mode 100644 --- 
/dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp @@ -0,0 +1,84 @@ +//===- JointMatrixOps.cpp - MLIR SPIR-V Intel Joint Matrix Ops -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the Intel Joint Matrix operations in the SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" + +namespace mlir { +//===----------------------------------------------------------------------===// +// spirv.INTEL.JointMatrixLoad +//===----------------------------------------------------------------------===// + +static LogicalResult +verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) { + Type pointeeType = llvm::cast(pointer).getPointeeType(); + if (!llvm::isa(pointeeType) && + !llvm::isa(pointeeType)) + return op->emitError( + "Pointer must point to a scalar or vector type but provided ") + << pointeeType; + spirv::StorageClass storage = + llvm::cast(pointer).getStorageClass(); + if (storage != spirv::StorageClass::Workgroup && + storage != spirv::StorageClass::CrossWorkgroup && + storage != spirv::StorageClass::UniformConstant && + storage != spirv::StorageClass::Generic) + return op->emitError("Pointer storage class must be Workgroup or " + "CrossWorkgroup but provided ") + << stringifyStorageClass(storage); + return success(); +} + +LogicalResult spirv::INTELJointMatrixLoadOp::verify() { + return verifyPointerAndJointMatrixType(*this, getPointer().getType(), + getResult().getType()); +} + +//===----------------------------------------------------------------------===// +// spirv.INTEL.JointMatrixStore 
+//===----------------------------------------------------------------------===// + +LogicalResult spirv::INTELJointMatrixStoreOp::verify() { + return verifyPointerAndJointMatrixType(*this, getPointer().getType(), + getObject().getType()); +} + +//===----------------------------------------------------------------------===// +// spirv.INTEL.JointMatrixMad +//===----------------------------------------------------------------------===// + +static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) { + if (op.getC().getType() != op.getResult().getType()) + return op.emitOpError("result and third operand must have the same type"); + auto typeA = llvm::cast(op.getA().getType()); + auto typeB = llvm::cast(op.getB().getType()); + auto typeC = llvm::cast(op.getC().getType()); + auto typeR = + llvm::cast(op.getResult().getType()); + if (typeA.getRows() != typeR.getRows() || + typeA.getColumns() != typeB.getRows() || + typeB.getColumns() != typeR.getColumns()) + return op.emitOpError("matrix size must match"); + if (typeR.getScope() != typeA.getScope() || + typeR.getScope() != typeB.getScope() || + typeR.getScope() != typeC.getScope()) + return op.emitOpError("matrix scope must match"); + if (typeA.getElementType() != typeB.getElementType() || + typeR.getElementType() != typeC.getElementType()) + return op.emitOpError("matrix element type must match"); + return success(); +} + +LogicalResult spirv::INTELJointMatrixMadOp::verify() { + return verifyJointMatrixMad(*this); +} + +} // namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -11,7 +11,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" -#include "mlir/Dialect/SPIRV/IR/ParserUtils.h" + +#include "SPIRVParsingUtils.h" + #include 
"mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" @@ -341,11 +343,13 @@ return {}; Scope scope; - if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope ")) + if (parser.parseComma() || + spirv::parseEnumKeywordAttr(scope, parser, "scope ")) return {}; CooperativeMatrixUseKHR use; - if (parser.parseComma() || parseEnumKeywordAttr(use, parser, "use ")) + if (parser.parseComma() || + spirv::parseEnumKeywordAttr(use, parser, "use ")) return {}; if (parser.parseGreater()) @@ -376,7 +380,8 @@ return Type(); Scope scope; - if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope ")) + if (parser.parseComma() || + spirv::parseEnumKeywordAttr(scope, parser, "scope ")) return Type(); if (parser.parseGreater()) @@ -407,10 +412,11 @@ return Type(); MatrixLayout matrixLayout; if (parser.parseComma() || - parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout ")) + spirv::parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout ")) return Type(); Scope scope; - if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope ")) + if (parser.parseComma() || + spirv::parseEnumKeywordAttr(scope, parser, "scope ")) return Type(); if (parser.parseGreater()) return Type(); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h @@ -0,0 +1,32 @@ +//===- SPIRVOpUtils.h - MLIR SPIR-V Dialect Op Definition Utilities -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" + +namespace mlir::spirv { + +/// Returns the bit width of the `type`. 
+inline unsigned getBitWidth(Type type) { + if (isa(type)) { + // Just return 64 bits for pointer types for now. + // TODO: Make sure no caller relies on the actual pointer width value. + return 64; + } + + if (type.isIntOrFloat()) + return type.getIntOrFloatBitWidth(); + + if (auto vectorType = dyn_cast(type)) { + assert(vectorType.getElementType().isIntOrFloat()); + return vectorType.getNumElements() * + vectorType.getElementType().getIntOrFloatBitWidth(); + } + llvm_unreachable("unhandled bit width computation for type"); +} + +} // namespace mlir::spirv diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -12,7 +12,9 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" -#include "mlir/Dialect/SPIRV/IR/ParserUtils.h" +#include "SPIRVOpUtils.h" +#include "SPIRVParsingUtils.h" + #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" @@ -33,41 +35,12 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/FormatVariadic.h" #include #include #include using namespace mlir; - -// TODO: generate these strings using ODS. 
-constexpr char kAlignmentAttrName[] = "alignment"; -constexpr char kBranchWeightAttrName[] = "branch_weights"; -constexpr char kCallee[] = "callee"; -constexpr char kClusterSize[] = "cluster_size"; -constexpr char kControl[] = "control"; -constexpr char kDefaultValueAttrName[] = "default_value"; -constexpr char kEqualSemanticsAttrName[] = "equal_semantics"; -constexpr char kExecutionScopeAttrName[] = "execution_scope"; -constexpr char kFnNameAttrName[] = "fn"; -constexpr char kGroupOperationAttrName[] = "group_operation"; -constexpr char kIndicesAttrName[] = "indices"; -constexpr char kInitializerAttrName[] = "initializer"; -constexpr char kInterfaceAttrName[] = "interface"; -constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout"; -constexpr char kMemoryAccessAttrName[] = "memory_access"; -constexpr char kMemoryOperandAttrName[] = "memory_operand"; -constexpr char kMemoryScopeAttrName[] = "memory_scope"; -constexpr char kPackedVectorFormatAttrName[] = "format"; -constexpr char kSemanticsAttrName[] = "semantics"; -constexpr char kSourceAlignmentAttrName[] = "source_alignment"; -constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access"; -constexpr char kSpecIdAttrName[] = "spec_id"; -constexpr char kTypeAttrName[] = "type"; -constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics"; -constexpr char kValueAttrName[] = "value"; -constexpr char kValuesAttrName[] = "values"; -constexpr char kCompositeSpecConstituentsName[] = "constituents"; +using namespace mlir::spirv::AttrNames; //===----------------------------------------------------------------------===// // Common utility functions @@ -158,79 +131,6 @@ return success(); } -template -static ArrayAttr -getStrArrayAttrForEnumList(Builder &builder, ArrayRef enumValues, - function_ref stringifyFn) { - if (enumValues.empty()) { - return nullptr; - } - SmallVector enumValStrs; - enumValStrs.reserve(enumValues.size()); - for (auto val : enumValues) { - 
enumValStrs.emplace_back(stringifyFn(val)); - } - return builder.getStrArrayAttr(enumValStrs); -} - -/// Parses the next string attribute in `parser` as an enumerant of the given -/// `EnumClass`. -template -static ParseResult -parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, - StringRef attrName = spirv::attributeName()) { - static_assert(std::is_enum_v); - Attribute attrVal; - NamedAttrList attr; - auto loc = parser.getCurrentLocation(); - if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), - attrName, attr)) - return failure(); - if (!llvm::isa(attrVal)) - return parser.emitError(loc, "expected ") - << attrName << " attribute specified as string"; - auto attrOptional = spirv::symbolizeEnum( - llvm::cast(attrVal).getValue()); - if (!attrOptional) - return parser.emitError(loc, "invalid ") - << attrName << " attribute specification: " << attrVal; - value = *attrOptional; - return success(); -} - -/// Parses the next string attribute in `parser` as an enumerant of the given -/// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer -/// attribute with the enum class's name as attribute name. -template -static ParseResult -parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state, - StringRef attrName = spirv::attributeName()) { - static_assert(std::is_enum_v); - if (parseEnumStrAttr(value, parser)) - return failure(); - state.addAttribute(attrName, - parser.getBuilder().getAttr(value)); - return success(); -} - -/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass` -/// and inserts the enumerant into `state` as an 32-bit integer attribute with -/// the enum class's name as attribute name. 
-template -static ParseResult -parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser, - OperationState &state, - StringRef attrName = spirv::attributeName()) { - static_assert(std::is_enum_v); - if (parseEnumKeywordAttr(value, parser)) - return failure(); - state.addAttribute(attrName, - parser.getBuilder().getAttr(value)); - return success(); -} - /// Parses Function, Selection and Loop control attributes. If no control is /// specified, "None" is used as a default. template @@ -240,7 +140,7 @@ if (succeeded(parser.parseOptionalKeyword(kControl))) { EnumClass control; if (parser.parseLParen() || - parseEnumKeywordAttr(control, parser, state) || + spirv::parseEnumKeywordAttr(control, parser, state) || parser.parseRParen()) return failure(); return success(); @@ -252,40 +152,6 @@ return success(); } -/// Parses optional memory access (a.k.a. memory operand) attributes attached to -/// a memory access operand/pointer. Specifically, parses the following syntax: -/// (`[` memory-access `]`)? -/// where: -/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` -/// integer-literal | `"NonTemporal"` -static ParseResult -parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state, - StringRef attrName = kMemoryAccessAttrName) { - // Parse an optional list of attributes staring with '[' - if (parser.parseOptionalLSquare()) { - // Nothing to do - return success(); - } - - spirv::MemoryAccess memoryAccessAttr; - if (parseEnumStrAttr(memoryAccessAttr, parser, state, - attrName)) - return failure(); - - if (spirv::bitEnumContainsAll(memoryAccessAttr, - spirv::MemoryAccess::Aligned)) { - // Parse integer attribute for alignment. 
- Attribute alignmentAttr; - Type i32Type = parser.getBuilder().getIntegerType(32); - if (parser.parseComma() || - parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName, - state.attributes)) { - return failure(); - } - } - return parser.parseRSquare(); -} - // TODO Make sure to merge this and the previous function into one template // parameterized by memory access attribute name and alignment. Doing so now // results in VS2017 in producing an internal error (at the call site) that's @@ -299,8 +165,8 @@ } spirv::MemoryAccess memoryAccessAttr; - if (parseEnumStrAttr(memoryAccessAttr, parser, state, - kSourceMemoryAccessAttrName)) + if (spirv::parseEnumStrAttr( + memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName)) return failure(); if (spirv::bitEnumContainsAll(memoryAccessAttr, @@ -683,25 +549,6 @@ printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); } -// Get bit width of types. -static unsigned getBitWidth(Type type) { - if (llvm::isa(type)) { - // Just return 64 bits for pointer types for now. - // TODO: Make sure not caller relies on the actual pointer width value. - return 64; - } - - if (type.isIntOrFloat()) - return type.getIntOrFloatBitWidth(); - - if (auto vectorType = llvm::dyn_cast(type)) { - assert(vectorType.getElementType().isIntOrFloat()); - return vectorType.getNumElements() * - vectorType.getElementType().getIntOrFloatBitWidth(); - } - llvm_unreachable("unhandled bit width computation for type"); -} - /// Walks the given type hierarchy with the given indices, potentially down /// to component granularity, to select an element type. Returns null type and /// emits errors with the given loc on failure. 
@@ -839,10 +686,10 @@ OpAsmParser::UnresolvedOperand ptrInfo, valueInfo; Type type; SMLoc loc; - if (parseEnumStrAttr(scope, parser, state, - kMemoryScopeAttrName) || - parseEnumStrAttr(memoryScope, parser, state, - kSemanticsAttrName) || + if (spirv::parseEnumStrAttr(scope, parser, state, + kMemoryScopeAttrName) || + spirv::parseEnumStrAttr( + memoryScope, parser, state, kSemanticsAttrName) || parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) || parser.getCurrentLocation(&loc) || parser.parseColonType(type)) return failure(); @@ -916,10 +763,10 @@ spirv::Scope executionScope; spirv::GroupOperation groupOperation; OpAsmParser::UnresolvedOperand valueInfo; - if (parseEnumStrAttr(executionScope, parser, state, - kExecutionScopeAttrName) || - parseEnumStrAttr(groupOperation, parser, state, - kGroupOperationAttrName) || + if (spirv::parseEnumStrAttr(executionScope, parser, state, + kExecutionScopeAttrName) || + spirv::parseEnumStrAttr( + groupOperation, parser, state, kGroupOperationAttrName) || parser.parseOperand(valueInfo)) return failure(); @@ -1199,11 +1046,11 @@ spirv::MemorySemantics equalSemantics, unequalSemantics; SmallVector operandInfo; Type type; - if (parseEnumStrAttr(memoryScope, parser, state, - kMemoryScopeAttrName) || - parseEnumStrAttr( + if (spirv::parseEnumStrAttr(memoryScope, parser, state, + kMemoryScopeAttrName) || + spirv::parseEnumStrAttr( equalSemantics, parser, state, kEqualSemanticsAttrName) || - parseEnumStrAttr( + spirv::parseEnumStrAttr( unequalSemantics, parser, state, kUnequalSemanticsAttrName) || parser.parseOperandList(operandInfo, 3)) return failure(); @@ -3478,10 +3325,10 @@ // Parse attributes spirv::AddressingModel addrModel; spirv::MemoryModel memoryModel; - if (::parseEnumKeywordAttr(addrModel, parser, - result) || - ::parseEnumKeywordAttr(memoryModel, parser, - result)) + if (spirv::parseEnumKeywordAttr(addrModel, parser, + result) || + spirv::parseEnumKeywordAttr(memoryModel, parser, + result)) return failure(); if 
(succeeded(parser.parseOptionalKeyword("requires"))) { @@ -4028,364 +3875,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// spirv.KHR.CooperativeMatrixLength -//===----------------------------------------------------------------------===// - -LogicalResult spirv::KHRCooperativeMatrixLengthOp::verify() { - if (!isa(getCooperativeMatrixType())) { - return emitOpError( - "type attribute must be a '!spirv.coopmatrix' type, found ") - << getCooperativeMatrixType() << " instead"; - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.KHR.CooperativeMatrixLoad -//===----------------------------------------------------------------------===// - -ParseResult spirv::KHRCooperativeMatrixLoadOp::parse(OpAsmParser &parser, - OperationState &result) { - std::array operandInfo = {}; - if (parser.parseOperand(operandInfo[0]) || parser.parseComma()) - return failure(); - if (parser.parseOperand(operandInfo[1]) || parser.parseComma()) - return failure(); - - spirv::CooperativeMatrixLayoutKHR layout; - if (::parseEnumKeywordAttr( - layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) { - return failure(); - } - - if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName)) - return failure(); - - Type ptrType; - Type elementType; - if (parser.parseColon() || parser.parseType(ptrType) || - parser.parseKeywordType("as", elementType)) { - return failure(); - } - result.addTypes(elementType); - - Type strideType = parser.getBuilder().getIntegerType(32); - if (parser.resolveOperands(operandInfo, {ptrType, strideType}, - parser.getNameLoc(), result.operands)) { - return failure(); - } - - return success(); -} - -void spirv::KHRCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) { - printer << " " << getPointer() << ", " << getStride() << ", " - << getMatrixLayout(); - // Print optional memory operand attribute. 
- if (auto memOperand = getMemoryOperand()) - printer << " [\"" << memOperand << "\"]"; - printer << " : " << getPointer().getType() << " as " << getType(); -} - -static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, - Type coopMatrix) { - auto pointerType = cast(pointer); - Type pointeeType = pointerType.getPointeeType(); - if (!isa(pointeeType)) { - return op->emitError( - "Pointer must point to a scalar or vector type but provided ") - << pointeeType; - } - - // TODO: Verify the memory object behind the pointer: - // > If the Shader capability was declared, Pointer must point into an array - // > and any ArrayStride decoration on Pointer is ignored. - - return success(); -} - -LogicalResult spirv::KHRCooperativeMatrixLoadOp::verify() { - return verifyPointerAndCoopMatrixType(*this, getPointer().getType(), - getResult().getType()); -} - -//===----------------------------------------------------------------------===// -// spirv.KHR.CooperativeMatrixStore -//===----------------------------------------------------------------------===// - -ParseResult spirv::KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser, - OperationState &result) { - std::array operandInfo = {}; - for (auto &op : operandInfo) { - if (parser.parseOperand(op) || parser.parseComma()) - return failure(); - } - - spirv::CooperativeMatrixLayoutKHR layout; - if (::parseEnumKeywordAttr( - layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) { - return failure(); - } - - if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName)) - return failure(); - - Type ptrType; - Type objectType; - if (parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() || - parser.parseType(objectType)) { - return failure(); - } - - Type strideType = parser.getBuilder().getIntegerType(32); - if (parser.resolveOperands(operandInfo, {ptrType, objectType, strideType}, - parser.getNameLoc(), result.operands)) { - return failure(); - } - - return success(); -} - 
-void spirv::KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) { - printer << " " << getPointer() << ", " << getObject() << ", " << getStride() - << ", " << getMatrixLayout(); - - // Print optional memory operand attribute. - if (auto memOperand = getMemoryOperand()) - printer << " [\"" << *memOperand << "\"]"; - printer << " : " << getPointer().getType() << ", " << getObject().getType(); -} - -LogicalResult spirv::KHRCooperativeMatrixStoreOp::verify() { - return verifyPointerAndCoopMatrixType(*this, getPointer().getType(), - getObject().getType()); -} - -//===----------------------------------------------------------------------===// -// spirv.NV.CooperativeMatrixLength -//===----------------------------------------------------------------------===// - -LogicalResult spirv::NVCooperativeMatrixLengthOp::verify() { - if (!isa(getCooperativeMatrixType())) { - return emitOpError( - "type attribute must be a '!spirv.NV.coopmatrix' type, found ") - << getCooperativeMatrixType() << " instead"; - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.NV.CooperativeMatrixLoad -//===----------------------------------------------------------------------===// - -ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser, - OperationState &result) { - SmallVector operandInfo; - Type strideType = parser.getBuilder().getIntegerType(32); - Type columnMajorType = parser.getBuilder().getIntegerType(1); - Type ptrType; - Type elementType; - if (parser.parseOperandList(operandInfo, 3) || - parseMemoryAccessAttributes(parser, result) || parser.parseColon() || - parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) { - return failure(); - } - if (parser.resolveOperands(operandInfo, - {ptrType, strideType, columnMajorType}, - parser.getNameLoc(), result.operands)) { - return failure(); - } - - result.addTypes(elementType); - return success(); -} - -void 
spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) { - printer << " " << getPointer() << ", " << getStride() << ", " - << getColumnmajor(); - // Print optional memory access attribute. - if (auto memAccess = getMemoryAccess()) - printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; - printer << " : " << getPointer().getType() << " as " << getType(); -} - -static LogicalResult -verifyPointerAndCoopMatrixNVType(Operation *op, Type pointer, Type coopMatrix) { - Type pointeeType = llvm::cast(pointer).getPointeeType(); - if (!llvm::isa(pointeeType) && - !llvm::isa(pointeeType)) - return op->emitError( - "Pointer must point to a scalar or vector type but provided ") - << pointeeType; - spirv::StorageClass storage = - llvm::cast(pointer).getStorageClass(); - if (storage != spirv::StorageClass::Workgroup && - storage != spirv::StorageClass::StorageBuffer && - storage != spirv::StorageClass::PhysicalStorageBuffer) - return op->emitError( - "Pointer storage class must be Workgroup, StorageBuffer or " - "PhysicalStorageBufferEXT but provided ") - << stringifyStorageClass(storage); - return success(); -} - -LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() { - return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(), - getResult().getType()); -} - -//===----------------------------------------------------------------------===// -// spirv.NV.CooperativeMatrixStore -//===----------------------------------------------------------------------===// - -ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser, - OperationState &result) { - SmallVector operandInfo; - Type strideType = parser.getBuilder().getIntegerType(32); - Type columnMajorType = parser.getBuilder().getIntegerType(1); - Type ptrType; - Type elementType; - if (parser.parseOperandList(operandInfo, 4) || - parseMemoryAccessAttributes(parser, result) || parser.parseColon() || - parser.parseType(ptrType) || parser.parseComma() || - parser.parseType(elementType)) 
{ - return failure(); - } - if (parser.resolveOperands( - operandInfo, {ptrType, elementType, strideType, columnMajorType}, - parser.getNameLoc(), result.operands)) { - return failure(); - } - - return success(); -} - -void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) { - printer << " " << getPointer() << ", " << getObject() << ", " << getStride() - << ", " << getColumnmajor(); - // Print optional memory access attribute. - if (auto memAccess = getMemoryAccess()) - printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; - printer << " : " << getPointer().getType() << ", " << getOperand(1).getType(); -} - -LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() { - return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(), - getObject().getType()); -} - -//===----------------------------------------------------------------------===// -// spirv.NV.CooperativeMatrixMulAdd -//===----------------------------------------------------------------------===// - -static LogicalResult -verifyCoopMatrixMulAddNV(spirv::NVCooperativeMatrixMulAddOp op) { - if (op.getC().getType() != op.getResult().getType()) - return op.emitOpError("result and third operand must have the same type"); - auto typeA = llvm::cast(op.getA().getType()); - auto typeB = llvm::cast(op.getB().getType()); - auto typeC = llvm::cast(op.getC().getType()); - auto typeR = - llvm::cast(op.getResult().getType()); - if (typeA.getRows() != typeR.getRows() || - typeA.getColumns() != typeB.getRows() || - typeB.getColumns() != typeR.getColumns()) - return op.emitOpError("matrix size must match"); - if (typeR.getScope() != typeA.getScope() || - typeR.getScope() != typeB.getScope() || - typeR.getScope() != typeC.getScope()) - return op.emitOpError("matrix scope must match"); - auto elementTypeA = typeA.getElementType(); - auto elementTypeB = typeB.getElementType(); - if (isa(elementTypeA) && isa(elementTypeB)) { - if (llvm::cast(elementTypeA).getWidth() != - 
llvm::cast(elementTypeB).getWidth()) - return op.emitOpError( - "matrix A and B integer element types must be the same bit width"); - } else if (elementTypeA != elementTypeB) { - return op.emitOpError( - "matrix A and B non-integer element types must match"); - } - if (typeR.getElementType() != typeC.getElementType()) - return op.emitOpError("matrix accumulator element type must match"); - return success(); -} - -LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() { - return verifyCoopMatrixMulAddNV(*this); -} - -//===----------------------------------------------------------------------===// -// spirv.INTEL.JointMatrixLoad -//===----------------------------------------------------------------------===// - -static LogicalResult -verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) { - Type pointeeType = llvm::cast(pointer).getPointeeType(); - if (!llvm::isa(pointeeType) && - !llvm::isa(pointeeType)) - return op->emitError( - "Pointer must point to a scalar or vector type but provided ") - << pointeeType; - spirv::StorageClass storage = - llvm::cast(pointer).getStorageClass(); - if (storage != spirv::StorageClass::Workgroup && - storage != spirv::StorageClass::CrossWorkgroup && - storage != spirv::StorageClass::UniformConstant && - storage != spirv::StorageClass::Generic) - return op->emitError("Pointer storage class must be Workgroup or " - "CrossWorkgroup but provided ") - << stringifyStorageClass(storage); - return success(); -} - -LogicalResult spirv::INTELJointMatrixLoadOp::verify() { - return verifyPointerAndJointMatrixType(*this, getPointer().getType(), - getResult().getType()); -} - -//===----------------------------------------------------------------------===// -// spirv.INTEL.JointMatrixStore -//===----------------------------------------------------------------------===// - -LogicalResult spirv::INTELJointMatrixStoreOp::verify() { - return verifyPointerAndJointMatrixType(*this, getPointer().getType(), - 
getObject().getType()); -} - -//===----------------------------------------------------------------------===// -// spirv.INTEL.JointMatrixMad -//===----------------------------------------------------------------------===// - -static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) { - if (op.getC().getType() != op.getResult().getType()) - return op.emitOpError("result and third operand must have the same type"); - auto typeA = llvm::cast(op.getA().getType()); - auto typeB = llvm::cast(op.getB().getType()); - auto typeC = llvm::cast(op.getC().getType()); - auto typeR = - llvm::cast(op.getResult().getType()); - if (typeA.getRows() != typeR.getRows() || - typeA.getColumns() != typeB.getRows() || - typeB.getColumns() != typeR.getColumns()) - return op.emitOpError("matrix size must match"); - if (typeR.getScope() != typeA.getScope() || - typeR.getScope() != typeB.getScope() || - typeR.getScope() != typeC.getScope()) - return op.emitOpError("matrix scope must match"); - if (typeA.getElementType() != typeB.getElementType() || - typeR.getElementType() != typeC.getElementType()) - return op.emitOpError("matrix element type must match"); - return success(); -} - -LogicalResult spirv::INTELJointMatrixMadOp::verify() { - return verifyJointMatrixMad(*this); -} - //===----------------------------------------------------------------------===// // spirv.MatrixTimesScalar //===----------------------------------------------------------------------===// @@ -5058,140 +4547,6 @@ LogicalResult spirv::GroupFMulKHROp::verify() { return verifyGroupOp(*this); } -//===----------------------------------------------------------------------===// -// Integer Dot Product ops -//===----------------------------------------------------------------------===// - -static LogicalResult verifyIntegerDotProduct(Operation *op) { - assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) && - "Not an integer dot product op?"); - assert(op->getNumResults() == 1 && "Expected a single 
result"); - - Type factorTy = op->getOperand(0).getType(); - if (op->getOperand(1).getType() != factorTy) - return op->emitOpError("requires the same type for both vector operands"); - - unsigned expectedNumAttrs = 0; - if (auto intTy = llvm::dyn_cast(factorTy)) { - ++expectedNumAttrs; - auto packedVectorFormat = - llvm::dyn_cast_or_null( - op->getAttr(kPackedVectorFormatAttrName)); - if (!packedVectorFormat) - return op->emitOpError("requires Packed Vector Format attribute for " - "integer vector operands"); - - assert(packedVectorFormat.getValue() == - spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && - "Unknown Packed Vector Format"); - if (intTy.getWidth() != 32) - return op->emitOpError( - llvm::formatv("with specified Packed Vector Format ({0}) requires " - "integer vector operands to be 32-bits wide", - packedVectorFormat.getValue())); - } else { - if (op->hasAttr(kPackedVectorFormatAttrName)) - return op->emitOpError(llvm::formatv( - "with invalid format attribute for vector operands of type '{0}'", - factorTy)); - } - - if (op->getAttrs().size() > expectedNumAttrs) - return op->emitError( - "op only supports the 'format' #spirv.packed_vector_format attribute"); - - Type resultTy = op->getResultTypes().front(); - bool hasAccumulator = op->getNumOperands() == 3; - if (hasAccumulator && op->getOperand(2).getType() != resultTy) - return op->emitOpError( - "requires the same accumulator operand and result types"); - - unsigned factorBitWidth = getBitWidth(factorTy); - unsigned resultBitWidth = getBitWidth(resultTy); - if (factorBitWidth > resultBitWidth) - return op->emitOpError( - llvm::formatv("result type has insufficient bit-width ({0} bits) " - "for the specified vector operand type ({1} bits)", - resultBitWidth, factorBitWidth)); - - return success(); -} - -static std::optional getIntegerDotProductMinVersion() { - return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0. 
-} - -static std::optional getIntegerDotProductMaxVersion() { - return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6. -} - -static SmallVector, 1> -getIntegerDotProductExtensions() { - // Requires the SPV_KHR_integer_dot_product extension, specified either - // explicitly or implied by target env's SPIR-V version >= 1.6. - static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product; - return {extension}; -} - -static SmallVector, 1> -getIntegerDotProductCapabilities(Operation *op) { - // Requires the the DotProduct capability and capabilities that depend on - // exact op types. - static const auto dotProductCap = spirv::Capability::DotProduct; - static const auto dotProductInput4x8BitPackedCap = - spirv::Capability::DotProductInput4x8BitPacked; - static const auto dotProductInput4x8BitCap = - spirv::Capability::DotProductInput4x8Bit; - static const auto dotProductInputAllCap = - spirv::Capability::DotProductInputAll; - - SmallVector, 1> capabilities = {dotProductCap}; - - Type factorTy = op->getOperand(0).getType(); - if (auto intTy = llvm::dyn_cast(factorTy)) { - auto formatAttr = llvm::cast( - op->getAttr(kPackedVectorFormatAttrName)); - if (formatAttr.getValue() == - spirv::PackedVectorFormat::PackedVectorFormat4x8Bit) - capabilities.push_back(dotProductInput4x8BitPackedCap); - - return capabilities; - } - - auto vecTy = llvm::cast(factorTy); - if (vecTy.getElementTypeBitWidth() == 8) { - capabilities.push_back(dotProductInput4x8BitCap); - return capabilities; - } - - capabilities.push_back(dotProductInputAllCap); - return capabilities; -} - -#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \ - LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); } \ - SmallVector, 1> OpName::getExtensions() { \ - return getIntegerDotProductExtensions(); \ - } \ - SmallVector, 1> OpName::getCapabilities() { \ - return getIntegerDotProductCapabilities(*this); \ - } \ - std::optional OpName::getMinVersion() { \ - return 
getIntegerDotProductMinVersion(); \ - } \ - std::optional OpName::getMaxVersion() { \ - return getIntegerDotProductMaxVersion(); \ - } - -SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotOp) -SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotOp) -SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotOp) -SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotAccSatOp) -SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotAccSatOp) -SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotAccSatOp) - -#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP - // TableGen'erated operation interfaces for querying versions, extensions, and // capabilities. #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc" diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h @@ -0,0 +1,156 @@ +//===- SPIRVParsingUtils.h - MLIR SPIR-V Dialect Parsing Utilities --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/FunctionExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +#include + +namespace mlir::spirv { +namespace AttrNames { +// TODO: generate these strings using ODS. 
+inline constexpr char kAlignmentAttrName[] = "alignment"; +inline constexpr char kBranchWeightAttrName[] = "branch_weights"; +inline constexpr char kCallee[] = "callee"; +inline constexpr char kClusterSize[] = "cluster_size"; +inline constexpr char kControl[] = "control"; +inline constexpr char kDefaultValueAttrName[] = "default_value"; +inline constexpr char kEqualSemanticsAttrName[] = "equal_semantics"; +inline constexpr char kExecutionScopeAttrName[] = "execution_scope"; +inline constexpr char kFnNameAttrName[] = "fn"; +inline constexpr char kGroupOperationAttrName[] = "group_operation"; +inline constexpr char kIndicesAttrName[] = "indices"; +inline constexpr char kInitializerAttrName[] = "initializer"; +inline constexpr char kInterfaceAttrName[] = "interface"; +inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout"; +inline constexpr char kMemoryAccessAttrName[] = "memory_access"; +inline constexpr char kMemoryOperandAttrName[] = "memory_operand"; +inline constexpr char kMemoryScopeAttrName[] = "memory_scope"; +inline constexpr char kPackedVectorFormatAttrName[] = "format"; +inline constexpr char kSemanticsAttrName[] = "semantics"; +inline constexpr char kSourceAlignmentAttrName[] = "source_alignment"; +inline constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access"; +inline constexpr char kSpecIdAttrName[] = "spec_id"; +inline constexpr char kTypeAttrName[] = "type"; +inline constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics"; +inline constexpr char kValueAttrName[] = "value"; +inline constexpr char kValuesAttrName[] = "values"; +inline constexpr char kCompositeSpecConstituentsName[] = "constituents"; +} // namespace AttrNames + +template +ArrayAttr getStrArrayAttrForEnumList(Builder &builder, ArrayRef enumValues, + function_ref stringifyFn) { + if (enumValues.empty()) { + return nullptr; + } + SmallVector enumValStrs; + enumValStrs.reserve(enumValues.size()); + for (auto val : enumValues) { + 
enumValStrs.emplace_back(stringifyFn(val)); + } + return builder.getStrArrayAttr(enumValStrs); +} + +/// Parses the next keyword in `parser` as an enumerant of the given +/// `EnumClass`. +template +ParseResult +parseEnumKeywordAttr(EnumClass &value, ParserType &parser, + StringRef attrName = spirv::attributeName()) { + StringRef keyword; + SmallVector attr; + auto loc = parser.getCurrentLocation(); + if (parser.parseKeyword(&keyword)) + return failure(); + + if (std::optional attr = + spirv::symbolizeEnum(keyword)) { + value = *attr; + return success(); + } + return parser.emitError(loc, "invalid ") + << attrName << " attribute specification: " << keyword; +} + +/// Parses the next string attribute in `parser` as an enumerant of the given +/// `EnumClass`. +template +ParseResult +parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, + StringRef attrName = spirv::attributeName()) { + static_assert(std::is_enum_v); + Attribute attrVal; + NamedAttrList attr; + auto loc = parser.getCurrentLocation(); + if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), + attrName, attr)) + return failure(); + if (!llvm::isa(attrVal)) + return parser.emitError(loc, "expected ") + << attrName << " attribute specified as string"; + auto attrOptional = spirv::symbolizeEnum( + llvm::cast(attrVal).getValue()); + if (!attrOptional) + return parser.emitError(loc, "invalid ") + << attrName << " attribute specification: " << attrVal; + value = *attrOptional; + return success(); +} + +/// Parses the next string attribute in `parser` as an enumerant of the given +/// `EnumClass` and inserts the enumerant into `state` as a 32-bit integer +/// attribute with the enum class's name as attribute name.
+template +ParseResult +parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state, + StringRef attrName = spirv::attributeName()) { + static_assert(std::is_enum_v); + if (parseEnumStrAttr(value, parser, attrName)) + return failure(); + state.addAttribute(attrName, + parser.getBuilder().getAttr(value)); + return success(); +} + +/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass` +/// and inserts the enumerant into `state` as a 32-bit integer attribute with +/// the enum class's name as attribute name. +template +ParseResult +parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser, + OperationState &state, + StringRef attrName = spirv::attributeName()) { + static_assert(std::is_enum_v); + if (parseEnumKeywordAttr(value, parser)) + return failure(); + state.addAttribute(attrName, + parser.getBuilder().getAttr(value)); + return success(); +} + +/// Parses optional memory access (a.k.a. memory operand) attributes attached to +/// a memory access operand/pointer. Specifically, parses the following syntax: +/// (`[` memory-access `]`)? +/// where: +/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` +/// integer-literal | `"NonTemporal"` +ParseResult parseMemoryAccessAttributes( + OpAsmParser &parser, OperationState &state, + StringRef attrName = AttrNames::kMemoryAccessAttrName); + +} // namespace mlir::spirv diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp @@ -0,0 +1,48 @@ +//===- SPIRVParsingUtils.cpp - MLIR SPIR-V Dialect Parsing Utilities ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implements common SPIR-V dialect parsing functions. +// +//===----------------------------------------------------------------------===// + +#include "SPIRVParsingUtils.h" + +using namespace mlir::spirv::AttrNames; + +namespace mlir::spirv { + +ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, + OperationState &state, + StringRef attrName) { + // Parse an optional list of attributes starting with '[' + if (parser.parseOptionalLSquare()) { + // Nothing to do + return success(); + } + + spirv::MemoryAccess memoryAccessAttr; + if (spirv::parseEnumStrAttr(memoryAccessAttr, parser, + state, attrName)) + return failure(); + + if (spirv::bitEnumContainsAll(memoryAccessAttr, + spirv::MemoryAccess::Aligned)) { + // Parse integer attribute for alignment. + Attribute alignmentAttr; + Type i32Type = parser.getBuilder().getIntegerType(32); + if (parser.parseComma() || + parser.parseAttribute(alignmentAttr, i32Type, + AttrNames::kAlignmentAttrName, + state.attributes)) { + return failure(); + } + } + return parser.parseRSquare(); +} + +} // namespace mlir::spirv