diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -14,6 +14,7 @@
 #ifndef ARMSME_OPS
 #define ARMSME_OPS
 
+include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -61,6 +62,12 @@
 def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
                          nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
 
+def SVEVector : ScalableVectorOfRankAndLengthAndType<
+  [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
+
+def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
+  [1], [16, 8, 4, 2, 1], [I1]>;
+
 // A type constraint that verifies the bitwidth of the scalar integer returned
 // from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
 def TileElementWidthMatchesTileID : TypesMatchWith<
@@ -496,6 +503,18 @@
   Arguments<(ins Arg<I32, "Index">,
                  Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
 
+// Vector to tile
+class LLVM_aarch64_sme_write<string direction>
+    : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
+                    [AllShapesMatch<["pg", "vector"]>]>,
+      Arguments<(ins Arg<I32, "Virtual tile ID">,
+                     Arg<I32, "Tile slice">,
+                     Arg<SVEPredicate, "Vector predicate">:$pg,
+                     Arg<SVEVector, "Vector operand">:$vector)>;
+
+def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
+def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
+
 def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
 def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
 
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -533,6 +533,19 @@
   ScalableVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
+// Any scalable vector where the rank is from the given `allowedRanks` list and
+// the number of elements is from the given `allowedLengths` list and the type
+// is from the given `allowedTypes` list
+class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
+                                           list<int> allowedLengths,
+                                           list<Type> allowedTypes> : AllOfType<
+  [ScalableVectorOfRank<allowedRanks>, ScalableVectorOf<allowedTypes>,
+   ScalableVectorOfLength<allowedLengths>],
+  ScalableVectorOfRank<allowedRanks>.summary #
+  ScalableVectorOf<allowedTypes>.summary #
+  ScalableVectorOfLength<allowedLengths>.summary,
+  "::mlir::VectorType">;
+
 def AnyVector : VectorOf<[AnyType]>;
 
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
new file
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+// Verify shape of predicate and vector must match
+llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
+                                                %nxv4i1 : vector<[4]xi1>,
+                                                %nxv16i8 : vector<[16]xi8>) {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // expected-error @+1 {{failed to verify that all of {pg, vector} have same shape}}
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv16i8) :
+    (i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -236,3 +236,101 @@
   "arm_sme.intr.za.disable"() : () -> ()
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sme_vector_to_tile_horiz
+llvm.func @arm_sme_vector_to_tile_horiz(%tileslice : i32,
+                                        %nxv16i1 : vector<[16]xi1>,
+                                        %nxv8i1 : vector<[8]xi1>,
+                                        %nxv4i1 : vector<[4]xi1>,
+                                        %nxv2i1 : vector<[2]xi1>,
+                                        %nxv1i1 : vector<[1]xi1>,
+                                        %nxv16i8 : vector<[16]xi8>,
+                                        %nxv8i16 : vector<[8]xi16>,
+                                        %nxv4i32 : vector<[4]xi32>,
+                                        %nxv2i64 : vector<[2]xi64>,
+                                        %nxv1i128 : vector<[1]xi128>,
+                                        %nxv8f16 : vector<[8]xf16>,
+                                        %nxv8bf16 : vector<[8]xbf16>,
+                                        %nxv4f32 : vector<[4]xf32>,
+                                        %nxv2f64 : vector<[2]xf64>) {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv16i8
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
+    (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8i16
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
+    (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4i32
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
+    (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2i64
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
+    (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv1i128
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
+    (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8f16
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
+    (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8bf16
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
+    (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4f32
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
+    (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2f64
+  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
+    (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_vector_to_tile_vert
+llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
+                                       %nxv16i1 : vector<[16]xi1>,
+                                       %nxv8i1 : vector<[8]xi1>,
+                                       %nxv4i1 : vector<[4]xi1>,
+                                       %nxv2i1 : vector<[2]xi1>,
+                                       %nxv1i1 : vector<[1]xi1>,
+                                       %nxv16i8 : vector<[16]xi8>,
+                                       %nxv8i16 : vector<[8]xi16>,
+                                       %nxv4i32 : vector<[4]xi32>,
+                                       %nxv2i64 : vector<[2]xi64>,
+                                       %nxv1i128 : vector<[1]xi128>,
+                                       %nxv8f16 : vector<[8]xf16>,
+                                       %nxv8bf16 : vector<[8]xbf16>,
+                                       %nxv4f32 : vector<[4]xf32>,
+                                       %nxv2f64 : vector<[2]xf64>) {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv16i8
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
+    (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8i16
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
+    (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv4i32
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
+    (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv2i64
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
+    (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv1i128
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
+    (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8f16
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
+    (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8bf16
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
+    (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv4f32
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
+    (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.write.vert.nxv2f64
+  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
+    (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  llvm.return
+}