Please use GitHub pull requests for new patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/ArmSME/IR/ArmSMEDialect.cpp
- This file was added.
//===- ArmSMEDialect.cpp - MLIR ArmSME dialect implementation -------------===// | |||||
// | |||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||||
// See https://llvm.org/LICENSE.txt for license information. | |||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||||
// | |||||
//===----------------------------------------------------------------------===// | |||||
// | |||||
// This file implements the ArmSME dialect and its operations. | |||||
// | |||||
//===----------------------------------------------------------------------===// | |||||
#include "mlir/Dialect/ArmSME/ArmSMEDialect.h" | |||||
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" | |||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h" | |||||
#include "mlir/IR/Builders.h" | |||||
#include "mlir/IR/DialectImplementation.h" | |||||
#include "mlir/IR/OpImplementation.h" | |||||
#include "mlir/IR/TypeUtilities.h" | |||||
#include "llvm/ADT/STLExtras.h" | |||||
#include "llvm/ADT/TypeSwitch.h" | |||||
using namespace mlir; | |||||
using namespace mlir::arm_sme; | |||||
//===----------------------------------------------------------------------===// | |||||
// Custom printer/parser for list of SME Tile enums | |||||
//===----------------------------------------------------------------------===// | |||||
namespace { | |||||
void printTileEnumList(OpAsmPrinter &printer, Operation *op, ArrayAttr tiles) { | |||||
(void)op; | |||||
llvm::interleaveComma(tiles, printer, [&](Attribute elem) { | |||||
auto tile = elem.cast<TileEnumAttr>().getValue(); | |||||
printer << stringifyTileEnum(tile); | |||||
}); | |||||
} | |||||
ParseResult parseTileEnumList(OpAsmParser &parser, ArrayAttr &tiles) { | |||||
SmallVector<Attribute> tileStorage; | |||||
auto parseTileEnumAttr = [&]() -> ParseResult { | |||||
StringRef keyword; | |||||
if (parser.parseKeyword(&keyword)) | |||||
return failure(); | |||||
Optional<TileEnum> maybeTile = symbolizeTileEnum(keyword); | |||||
if (!maybeTile) | |||||
return parser.emitError(parser.getCurrentLocation(), | |||||
"invalid SME tile name"); | |||||
auto tileAttr = TileEnumAttr::get(parser.getContext(), *maybeTile); | |||||
tileStorage.push_back(tileAttr); | |||||
return success(); | |||||
}; | |||||
auto loc = parser.getCurrentLocation(); | |||||
if (parser.parseCommaSeparatedList(parseTileEnumAttr)) | |||||
return parser.emitError(loc, "expected list of SME tiles"); | |||||
tiles = ArrayAttr::get(parser.getContext(), tileStorage); | |||||
return success(); | |||||
} | |||||
} // namespace | |||||
//===----------------------------------------------------------------------===// | |||||
// Tablegen Definitions | |||||
//===----------------------------------------------------------------------===// | |||||
#include "mlir/Dialect/ArmSME/ArmSMEDialect.cpp.inc" | |||||
#include "mlir/Dialect/ArmSME/ArmSMEEnums.cpp.inc" | |||||
#define GET_OP_CLASSES | |||||
#include "mlir/Dialect/ArmSME/ArmSME.cpp.inc" | |||||
#define GET_TYPEDEF_CLASSES | |||||
#include "mlir/Dialect/ArmSME/ArmSMETypes.cpp.inc" | |||||
void ArmSMEDialect::initialize() { | |||||
addOperations< | |||||
#define GET_OP_LIST | |||||
#include "mlir/Dialect/ArmSME/ArmSME.cpp.inc" | |||||
>(); | |||||
} | |||||
//===----------------------------------------------------------------------===// | |||||
// Custom Verifier | |||||
//===----------------------------------------------------------------------===// | |||||
/// Additional verification of MOP ops | |||||
static LogicalResult verifyMOP(TileEnum tile, Type lhsTy, Type rhsTy, | |||||
bool isWidening, Operation *op) { | |||||
auto lhsVecTy = lhsTy.cast<VectorType>(); | |||||
auto rhsVecTy = lhsTy.cast<VectorType>(); | |||||
if (lhsVecTy.getNumScalableDims() != lhsVecTy.getRank() || | |||||
rhsVecTy.getNumScalableDims() != rhsVecTy.getRank()) | |||||
return op->emitOpError("expecting all dimensions to be scalable"); | |||||
Type lhsElTy = lhsVecTy.getElementType(); | |||||
Type rhsElTy = rhsVecTy.getElementType(); | |||||
const llvm::DenseSet<TileEnum> b32Tiles( | |||||
{TileEnum::za0s, TileEnum::za1s, TileEnum::za2s, TileEnum::za3s}); | |||||
const llvm::DenseSet<TileEnum> b64Tiles( | |||||
{TileEnum::za0d, TileEnum::za1d, TileEnum::za2d, TileEnum::za3d, | |||||
TileEnum::za4d, TileEnum::za5d, TileEnum::za6d, TileEnum::za7d}); | |||||
// Verify element type width | |||||
unsigned elWidth = lhsElTy.getIntOrFloatBitWidth(); | |||||
if (elWidth != rhsElTy.getIntOrFloatBitWidth()) | |||||
return op->emitOpError("invalid vector element type"); | |||||
// Verify valid vector unit length: | |||||
constexpr unsigned sveUnitVecWidth = 128; | |||||
if (elWidth * lhsVecTy.getNumElements() != sveUnitVecWidth) | |||||
return op->emitOpError( | |||||
"expected operand vector length to be multiples of 128 bits"); | |||||
if (isWidening) { | |||||
// Check element types - integer types can be either signed or unsigned for | |||||
// both operands, otherwise the types must match. | |||||
if (lhsVecTy.getRank() != 2) | |||||
return op->emitOpError( | |||||
"expecting widening MOP ops to have 2D vector operands"); | |||||
auto lhsShape = lhsVecTy.getShape(); | |||||
if (lhsElTy.isBF16() || lhsElTy.isF16()) { | |||||
// widening fmop*/bfmop* | |||||
if (!b32Tiles.contains(tile)) | |||||
return op->emitOpError( | |||||
"expecting 16b float types to accumulate into 32b tiles"); | |||||
if (rhsElTy != lhsElTy) | |||||
return op->emitOpError("mismatching lhs and rhs vector element types"); | |||||
if (lhsShape[0] != 4) | |||||
return op->emitOpError("invalid vector shape for widening MOP"); | |||||
} else if (lhsElTy.isInteger(8)) { | |||||
// 8->32-bit smop*/umop*/sumop*/usmop* | |||||
if (!b32Tiles.contains(tile)) | |||||
return op->emitOpError( | |||||
"expecting 8b int types to accumulate into 32b tiles"); | |||||
if (!rhsElTy.isInteger(8)) | |||||
return op->emitOpError( | |||||
"expecting lhs and rhs element types to be of same integer width"); | |||||
if (lhsShape[0] != 4) | |||||
return op->emitOpError("invalid vector shape for widening MOP"); | |||||
} else if (lhsElTy.isInteger(16)) { | |||||
// 16->64-bit smop*/umop*/sumop*/usmop* | |||||
if (!b64Tiles.contains(tile)) | |||||
return op->emitOpError( | |||||
"expecting 16b int types to accumulate into 64b tiles"); | |||||
if (!rhsElTy.isInteger(16)) | |||||
return op->emitOpError( | |||||
"expecting lhs and rhs element types to be of same integer width"); | |||||
if (lhsShape[0] != 2) | |||||
return op->emitOpError("invalid vector shape for widening MOP"); | |||||
} | |||||
return success(); | |||||
} | |||||
// non-widening fmop* | |||||
if (lhsVecTy != rhsVecTy) | |||||
return op->emitOpError("expecting lhs and rhs operands to have the same " | |||||
"type for non-widening MOP"); | |||||
if (lhsVecTy.getRank() != 1) | |||||
return op->emitOpError("expecting 1D vector operands for non-widening MOP"); | |||||
if (lhsVecTy.isF32() && !b32Tiles.contains(tile)) | |||||
return op->emitOpError("expecting f32 MOP to accumulate into 32b tiles"); | |||||
if (lhsVecTy.isF64() && !b64Tiles.contains(tile)) | |||||
return op->emitOpError("expecting f64 MOP to accumulate into 64b tiles"); | |||||
return success(); | |||||
} | |||||
LogicalResult MopaOp::verify() { | |||||
return verifyMOP(getTile(), getLhs().getType(), getRhs().getType(), | |||||
isWidening(), getOperation()); | |||||
} | |||||
LogicalResult MopsOp::verify() { | |||||
return verifyMOP(getTile(), getLhs().getType(), getRhs().getType(), | |||||
isWidening(), getOperation()); | |||||
} |