Please use GitHub pull requests for new patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/include/mlir/Dialect/ArmSME/ArmSME.td
- This file was added.
//===-- ArmSME.td - ArmSME dialect operation definitions ---*- tablegen -*-===// | |||||
// | |||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||||
// See https://llvm.org/LICENSE.txt for license information. | |||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||||
// | |||||
//===----------------------------------------------------------------------===// | |||||
// | |||||
// This file defines the basic operations for the ArmSME dialect. | |||||
// | |||||
//===----------------------------------------------------------------------===// | |||||
#ifndef ARMSME_OPS | |||||
#define ARMSME_OPS | |||||
include "mlir/Interfaces/SideEffectInterfaces.td" | |||||
include "mlir/Dialect/LLVMIR/LLVMOpBase.td" | |||||
//===----------------------------------------------------------------------===// | |||||
// ArmSME dialect definition | |||||
//===----------------------------------------------------------------------===// | |||||
def ArmSME_Dialect : Dialect { | |||||
let name = "arm_sme"; | |||||
let cppNamespace = "::mlir::arm_sme"; | |||||
let summary = "Basic dialect to target Arm SME architectures"; | |||||
let description = [{ | |||||
This dialect contains the definitions necessary to target specific Arm SME | |||||
scalable vector operations. | |||||
Source: | |||||
https://developer.arm.com/documentation/ddi0616/aa | |||||
}]; | |||||
let dependentDialects = ["arm_sve::ArmSVEDialect"]; | |||||
} | |||||
//===----------------------------------------------------------------------===// | |||||
// ArmSME Tile enum definitions | |||||
//===----------------------------------------------------------------------===// | |||||
def ZA0D : I32EnumAttrCase<"za0d", 1>; | |||||
def ZA1D : I32EnumAttrCase<"za1d", 2>; | |||||
def ZA2D : I32EnumAttrCase<"za2d", 4>; | |||||
def ZA3D : I32EnumAttrCase<"za3d", 8>; | |||||
def ZA4D : I32EnumAttrCase<"za4d", 16>; | |||||
def ZA5D : I32EnumAttrCase<"za5d", 32>; | |||||
def ZA6D : I32EnumAttrCase<"za6d", 64>; | |||||
def ZA7D : I32EnumAttrCase<"za7d", 128>; | |||||
def ZA0S : I32EnumAttrCase<"za0s", 17>; // = ZA0D | ZA4D | |||||
def ZA1S : I32EnumAttrCase<"za1s", 34>; // = ZA1D | ZA5D | |||||
def ZA2S : I32EnumAttrCase<"za2s", 68>; // = ZA2D | ZA6D | |||||
def ZA3S : I32EnumAttrCase<"za3s", 136>; // = ZA3D | ZA7D | |||||
def ArmSME_TileAttr : I32EnumAttr<"TileEnum", | |||||
"Enum representation the SME matrix tiles", | |||||
[ZA0D, ZA1D, ZA2D, ZA3D, ZA4D, ZA5D, ZA6D, | |||||
ZA7D, ZA0S, ZA1S, ZA2S, ZA3S]> { | |||||
let cppNamespace = "::mlir::arm_sme"; | |||||
} | |||||
//===----------------------------------------------------------------------===// | |||||
// ArmSME op definitions | |||||
//===----------------------------------------------------------------------===// | |||||
class ArmSME_Op<string mnemonic, list<Trait> traits = []> : | |||||
Op<ArmSME_Dialect, mnemonic, traits> {} | |||||
def Predicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2], [I1]>; | |||||
def SMEVector : ScalableVectorOfLengthAndType< | |||||
[16, 8, 4, 2], [SI8, SI16, UI8, UI16, BF16, F16, F32, F64]>; | |||||
def TileList : TypedArrayAttrBase<ArmSME_TileAttr, "list of SME matrix tiles">; | |||||
class MOPOpBase<string mnemonic, bit accumulate> | |||||
: ArmSME_Op<mnemonic, | |||||
[AllShapesMatch<["lhs", "lhsPred", "rhs", "rhsPred"]>]> { | |||||
let arguments = (ins | |||||
ArmSME_TileAttr:$tile, | |||||
Predicate:$lhsPred, | |||||
Predicate:$rhsPred, | |||||
SMEVector:$lhs, | |||||
SMEVector:$rhs | |||||
); | |||||
let extraClassDeclaration = [{ | |||||
bool isAccumulate() { return }] # accumulate # [{; | |||||
} | |||||
bool isSubtract() { return }] # !not(accumulate) # [{; } | |||||
bool isWidening() { | |||||
auto elTy = this->getLhs().getType().cast<VectorType>().getElementType(); | |||||
if (elTy.isF32() || elTy.isF64()) | |||||
return false; | |||||
else | |||||
return true; | |||||
} | |||||
}]; | |||||
let assemblyFormat =[{ $tile`,` $lhsPred`,` $rhsPred`,` $lhs`,` $rhs attr-dict | |||||
`:` type($lhsPred)`,` type($rhsPred)`,` type($lhs)`,` type($rhs) }]; | |||||
let hasVerifier = 1; | |||||
} | |||||
def MopaOp : MOPOpBase<"mopa", /*accumulate=*/true> { | |||||
let summary = "Vector-vector outer product and accumulate op"; | |||||
let description = [{ | |||||
MOPA: Outer product product accumulate. | |||||
peixin: typo? | |||||
This function maps to the *MOPA instructions, it takes scalable vector | |||||
operands which will be used to compute the outer product matrix. Two | |||||
masking predicate operands for each of the floating point operands will also | |||||
be provided, such that elements marked inactive by the predicate will not | |||||
update the corresponding row/column in the result matrix tile, specified by | |||||
the attribute. | |||||
Theere are two variations of MOPA instructions - widening and non-widening. | |||||
Non-widening MOPAs will take a 1D vector of f32 or f64 as input and | |||||
accumulate into 32b and 64b tiles respectively (za*s and za*d). | |||||
Widening MOPAs will pack two f16/bf16 or four (signed or unsigned) i8 | |||||
elements into a single 32b lane of the vector and accumulate into 32b tiles | |||||
(za*s); Or it will pack four (signed or unsigned) i16 elements into a 64b | |||||
lane and accumulate into 64b tiles (za*d). Hence widening MOPAs will take | |||||
2D scalable vectors as input, i.e. `<[4x2]xf16>, <[2x4]xsi16>, <[4x4]xsi8>` | |||||
Example: Assume `vscale == 2`, `%lhs = %rhs = <1, 2, 3, 4> : <[2]xfp64>`, | |||||
`%lhsPred = %rhsPred = <true, true, false, true>`, then: | |||||
``` | |||||
arm_sme.zero za0d | |||||
arm_sme.fmopa za0d, %lhsPred, %rhsPred, %lhs, %rhs | |||||
: vector<[2]xi1>, vector<[2]xf64> | |||||
``` | |||||
Would result in za0d containing: | |||||
``` | |||||
1 2 0 4 | |||||
2 4 0 8 | |||||
0 0 0 0 | |||||
4 8 0 16 | |||||
``` | |||||
}]; | |||||
} | |||||
def MopsOp : MOPOpBase<"mops", /*accumulate=*/false> { | |||||
let summary = "Vector-vector outer product and subtract op"; | |||||
let description = [{ | |||||
FMOPA: Outer product product accumulate. | |||||
peixinUnsubmitted Not Done ReplyInline Actionstypo? peixin: typo? | |||||
This function maps to the *MOPS instructions, it functions similarily to | |||||
the *MOPA instructions, but differs in that it subtracts the outer product | |||||
computed from the input vectors from the existing values within the tile | |||||
provided. | |||||
}]; | |||||
} | |||||
def ZeroOp : ArmSME_Op<"zero"> { | |||||
let summary = "Zeroes a list of SME matrix tiles"; | |||||
let description = [{ | |||||
ZERO: Sets the contents of specified matrix tiles to zero"; | |||||
Source: | |||||
https: // developer.arm.com/documentation/ddi0616/aa | |||||
}]; | |||||
let arguments = (ins TileList:$tiles); | |||||
let assemblyFormat = "custom<TileEnumList>($tiles) attr-dict"; | |||||
} | |||||
//===----------------------------------------------------------------------===// | |||||
// ArmSME Intrinsic op definitions | |||||
//===----------------------------------------------------------------------===// | |||||
class ArmSME_IntrOverloadedOp<string mnemonic, list<int> overloadOperands = []> | |||||
: LLVM_IntrOpBase< | |||||
/*Dialect dialect=*/ArmSME_Dialect, | |||||
/*string opName=*/"intr." #mnemonic, | |||||
/*string enumName=*/"aarch64_sme_" #!subst(".", "_", mnemonic), | |||||
/*list<int> overloadedResults=*/[], | |||||
/*list<int> overloadedOperands=*/overloadOperands, | |||||
/*list<Trait> traits=*/[], | |||||
/*int numResults=*/0>; | |||||
def ZeroIntrOp : ArmSME_IntrOverloadedOp<"zero">, | |||||
Arguments<(ins Arg<I32, "Tile register ID">)>; | |||||
class ArmSME_IntrMopOverloadedOp<string mnemonic> | |||||
: ArmSME_IntrOverloadedOp<mnemonic, [4]>, | |||||
Arguments<(ins Arg<I32, "Tile register ID">, | |||||
Arg<Predicate, "LHS predicate">, | |||||
Arg<Predicate, "RHS predicate">, | |||||
Arg<AnyScalableVector, "LHS vector operand">, | |||||
Arg<AnyScalableVector, "RHS vector operand">)>; | |||||
def FmopaIntrOp : ArmSME_IntrMopOverloadedOp<"mopa">; | |||||
def FmopsIntrOp : ArmSME_IntrMopOverloadedOp<"mops">; | |||||
def FmopaWidenIntrOp : ArmSME_IntrMopOverloadedOp<"mopa_wide">; | |||||
def FmopsWidenIntrOp : ArmSME_IntrMopOverloadedOp<"mops_wide">; | |||||
def SmopaIntrOp : ArmSME_IntrMopOverloadedOp<"smopa_wide">; | |||||
def SmopsIntrOp : ArmSME_IntrMopOverloadedOp<"smops_wide">; | |||||
def UmopaIntrOp : ArmSME_IntrMopOverloadedOp<"umopa_wide">; | |||||
def UmopsIntrOp : ArmSME_IntrMopOverloadedOp<"umops_wide">; | |||||
def SUmopaIntrOp : ArmSME_IntrMopOverloadedOp<"sumopa_wide">; | |||||
def SUmopsIntrOp : ArmSME_IntrMopOverloadedOp<"sumops_wide">; | |||||
def USmopaIntrOp : ArmSME_IntrMopOverloadedOp<"usmopa_wide">; | |||||
def USmopsIntrOp : ArmSME_IntrMopOverloadedOp<"usmops_wide">; | |||||
class ArmSME_IntrLoadStoreOverloadedOp<string mnemonic> | |||||
: ArmSME_IntrOverloadedOp<mnemonic>, | |||||
Arguments<(ins Arg<Predicate, "Vector predicate">, | |||||
Arg<LLVM_AnyPointer, "The location to store to", [MemWrite]>, | |||||
Arg<I32, "Tile register ID">, Arg<I32, "Vector number">)>; | |||||
// Loads | |||||
def LoadHorizontalBytesIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"ld1b_horiz">; | |||||
def LoadHorizontalHalfsIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"ld1h_horiz">; | |||||
def LoadHorizontalWordsIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"ld1w_horiz">; | |||||
def LoadHorizontalDoublesIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"ld1d_horiz">; | |||||
def LoadHorizontalQuadsIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"ld1q_horiz">; | |||||
// Stores | |||||
def StoreVerticalBytesIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"st1b_vert">; | |||||
def StoreVerticalHalfsIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"st1h_vert">; | |||||
def StoreVerticalWordsIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"st1w_vert">; | |||||
def StoreVerticalDoublesIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"st1d_vert">; | |||||
def StoreVerticalQuadsIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"st1q_vert">; | |||||
#endif // ARMSME_OPS |
typo?