Please use GitHub pull requests for new patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
- This file was added.
//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===// | |||||
// | |||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||||
// See https://llvm.org/LICENSE.txt for license information. | |||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||||
// | |||||
//===----------------------------------------------------------------------===// | |||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" | |||||
#include "mlir/Conversion/LLVMCommon/Pattern.h" | |||||
#include "mlir/Dialect/Arith/IR/Arith.h" | |||||
#include "mlir/Dialect/ArmSME/ArmSMEDialect.h" | |||||
#include "mlir/Dialect/ArmSME/Transforms.h" | |||||
#include "mlir/Dialect/Func/IR/FuncOps.h" | |||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | |||||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | |||||
#include "mlir/IR/BuiltinOps.h" | |||||
#include "mlir/IR/PatternMatch.h" | |||||
#include "llvm/ADT/SmallVector.h" | |||||
using namespace mlir; | |||||
using namespace mlir::arm_sme; | |||||
template <typename MOPTy> | |||||
class MOPLowering : public OpConversionPattern<MOPTy> { | |||||
using OpConversionPattern<MOPTy>::OpConversionPattern; | |||||
LogicalResult | |||||
matchAndRewrite(MOPTy op, typename MOPTy::Adaptor adaptor, | |||||
ConversionPatternRewriter &rewriter) const final { | |||||
(void)adaptor; | |||||
Location loc = op.getLoc(); | |||||
SmallVector<Value, 5> operands; | |||||
auto tile = static_cast<uint32_t>(op.getTile()); | |||||
// Operands: | |||||
// Tile number | |||||
operands.push_back( | |||||
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), tile) | |||||
.getResult()); | |||||
if (op.isWidening()) { | |||||
return op.emitOpError("lowering of widening SME outer product " | |||||
"instructions not yet supported"); | |||||
} | |||||
// Predicates | |||||
operands.push_back(op.getLhsPred()); | |||||
operands.push_back(op.getRhsPred()); | |||||
// Input vectors | |||||
operands.push_back(op.getLhs()); | |||||
operands.push_back(op.getRhs()); | |||||
Type lhsElTy = | |||||
op.getLhs().getType().template cast<VectorType>().getElementType(); | |||||
Type rhsElTy = | |||||
op.getRhs().getType().template cast<VectorType>().getElementType(); | |||||
ValueRange operandsRange(operands); | |||||
switch (op.isAccumulate()) { | |||||
case true: | |||||
// MOPA ops | |||||
if (lhsElTy.isF32() || lhsElTy.isF64()) | |||||
rewriter.create<FmopaIntrOp>(loc, TypeRange{}, operandsRange); | |||||
else if (lhsElTy.isF16() || lhsElTy.isBF16()) | |||||
rewriter.create<FmopaWidenIntrOp>(loc, TypeRange{}, operandsRange); | |||||
else if (lhsElTy.isSignedInteger() && rhsElTy.isSignedInteger()) | |||||
rewriter.create<SmopaIntrOp>(loc, TypeRange{}, operandsRange); | |||||
else if (lhsElTy.isSignedInteger() && rhsElTy.isUnsignedInteger()) | |||||
rewriter.create<SUmopaIntrOp>(loc, TypeRange{}, operandsRange); | |||||
else if (lhsElTy.isUnsignedInteger() && rhsElTy.isSignedInteger()) | |||||
rewriter.create<USmopaIntrOp>(loc, TypeRange{}, operandsRange); | |||||
else if (lhsElTy.isUnsignedInteger() && rhsElTy.isUnsignedInteger()) | |||||
rewriter.create<UmopaIntrOp>(loc, TypeRange{}, operandsRange); | |||||
else | |||||
return op.emitOpError("unsupported SME vector element type"); | |||||
break; | |||||
case false: | |||||
// MOPS ops | |||||
if (lhsElTy.isF32() || lhsElTy.isF64()) | |||||
rewriter.create<FmopsIntrOp>(loc, TypeRange{}, operandsRange); | |||||
else if (lhsElTy.isF16() || lhsElTy.isBF16()) | |||||
rewriter.create<FmopsWidenIntrOp>(loc, TypeRange{}, operandsRange); | |||||
else if (lhsElTy.isSignedInteger() && rhsElTy.isSignedInteger()) | |||||
rewriter.create<SmopsIntrOp>(loc, TypeRange{}, operandsRange); | |||||
else if (lhsElTy.isSignedInteger() && rhsElTy.isUnsignedInteger()) | |||||
rewriter.create<SUmopsIntrOp>(loc, TypeRange{}, operandsRange); | |||||
else if (lhsElTy.isUnsignedInteger() && rhsElTy.isSignedInteger()) | |||||
rewriter.create<USmopsIntrOp>(loc, TypeRange{}, operandsRange); | |||||
else if (lhsElTy.isUnsignedInteger() && rhsElTy.isUnsignedInteger()) | |||||
rewriter.create<UmopsIntrOp>(loc, TypeRange{}, operandsRange); | |||||
else | |||||
return op.emitOpError("unsupported SME vector element type"); | |||||
} | |||||
rewriter.eraseOp(op); | |||||
return LogicalResult::success(); | |||||
} | |||||
}; | |||||
/// Conversion pattern that replaces a `ZeroOp` with a `ZeroIntrOp`.
///
/// The numeric values of all `TileEnumAttr` entries in the op's `tiles` array
/// are OR-ed into a single 32-bit value. That value is materialised as an i32
/// `llvm.mlir.constant` and passed as the only operand of the intrinsic; the
/// original op is then erased. The pattern always succeeds.
class ZeroOpLowering : public OpConversionPattern<ZeroOp> {
  using OpConversionPattern<ZeroOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ZeroOp op, ZeroOpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const final {
    // Fold every listed tile into one combined value.
    uint32_t combinedTiles = 0;
    for (Attribute tileAttr : op.getTiles())
      combinedTiles |=
          static_cast<uint32_t>(tileAttr.cast<TileEnumAttr>().getValue());

    Location loc = op.getLoc();
    Value combinedTilesConst = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(), combinedTiles);
    rewriter.create<ZeroIntrOp>(loc, combinedTilesConst);

    rewriter.eraseOp(op);
    return success();
  }
};
/// Adds the ArmSME lowering patterns defined in this file to `patterns`:
/// `MOPLowering` for `MopaOp` and for `MopsOp`, followed by `ZeroOpLowering`.
/// Each pattern is constructed from `converter` and the converter's context.
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  MLIRContext *context = &converter.getContext();

  // Outer-product lowerings.
  patterns.add<MOPLowering<MopaOp>>(converter, context);
  patterns.add<MOPLowering<MopsOp>>(converter, context);

  // Tile zeroing lowering.
  patterns.add<ZeroOpLowering>(converter, context);
}
/// Configures `target` for the ArmSME export legalization: every ArmSME
/// intrinsic op listed below is marked legal, and the ArmSME ops that have a
/// lowering pattern in this file (`MopaOp`, `MopsOp`, `ZeroOp`) are marked
/// illegal so that a conversion reports any instance it failed to lower.
void mlir::configureArmSMELegalizeForExportTarget(
    LLVMConversionTarget &target) {
  // Tile zeroing intrinsic.
  target.addLegalOp<ZeroIntrOp>();

  // Outer-product intrinsics, as MOPA/MOPS pairs.
  // clang-format off
  target.addLegalOp<FmopaIntrOp,      FmopsIntrOp,
                    FmopaWidenIntrOp, FmopsWidenIntrOp,
                    SmopaIntrOp,      SmopsIntrOp,
                    UmopaIntrOp,      UmopsIntrOp,
                    SUmopaIntrOp,     SUmopsIntrOp,
                    USmopaIntrOp,     USmopsIntrOp>();
  // clang-format on

  // Load/store intrinsics. None of the patterns visible in this file creates
  // them; they are only declared legal here.
  target.addLegalOp<LoadHorizontalBytesIntrOp, LoadHorizontalHalfsIntrOp,
                    LoadHorizontalWordsIntrOp, LoadHorizontalDoublesIntrOp,
                    LoadHorizontalQuadsIntrOp>();
  target.addLegalOp<StoreVerticalBytesIntrOp, StoreVerticalHalfsIntrOp,
                    StoreVerticalWordsIntrOp, StoreVerticalDoublesIntrOp,
                    StoreVerticalQuadsIntrOp>();

  // Ops with a lowering pattern registered by
  // populateArmSMELegalizeForLLVMExportPatterns must not survive conversion.
  target.addIllegalOp<MopaOp, MopsOp, ZeroOp>();
}