diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td --- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td +++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td @@ -39,7 +39,7 @@ list<int> overloadedOperands, int numResults, list<OpTrait> traits = [], bit requiresAccessGroup = 0> : LLVM_IntrOpBase<...>; class ArmNeon_OverloadedOneResultIntrOp<string mnemonic, list<OpTrait> traits = []> : ArmNeon_IntrOp<mnemonic, [0], [], 1, traits>; +// ArmNeon dialect op that corresponds to an LLVM IR intrinsic with one +// overloaded result and overloaded operands list. +class ArmNeon_OverloadedOperandsWithOneResultIntrOp<string mnemonic, + list<int> overloadedOperands, + list<OpTrait> traits = []> + : ArmNeon_IntrOp<mnemonic, [0], overloadedOperands, 1, traits>; + def SMullOp : ArmNeon_OverloadedOneResultIntrOp<"smull", [ NoSideEffect, AllTypesMatch<["a", "b"]>, @@ -82,5 +89,32 @@ "$a `,` $b attr-dict `:` type($a) `to` type($res)"; } +def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot",[1], [ + NoSideEffect, + AllTypesMatch<["b", "c"]>, + AllTypesMatch<["a", "res"]>, + TypesMatchWith<"res has the same number of elements as operand b", + "b", "res", + "VectorType::get({$_self.cast<VectorType>().getShape()[0] / 4}," + "IntegerType::get($_self.getContext(), 32))">]> { + let summary = "sdot op"; + let description = [{ + Signed integer addition of dot product (vector). This instruction performs + the following operation on signed integer vectors: res = dot(b, c) + a, + where vector operands are partitioned into groups of four elements. 
+ + Source: + https://developer.arm.com/architectures/instruction-sets/simd-isas/neon/intrinsics + }]; + // Supports either: + // (vector<2xi32>, vector<8xi8>, vector<8xi8>) -> vector<2xi32> + // (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> vector<4xi32> + let arguments = (ins VectorOfLengthAndType<[4, 2], [I32]>:$a, + VectorOfLengthAndType<[16, 8], [I8]>:$b, + VectorOfLengthAndType<[16, 8], [I8]>:$c); + let results = (outs VectorOfLengthAndType<[4, 2], [I32]>:$res); + let assemblyFormat = + "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)"; + } #endif // ARMNEON_OPS diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir --- a/mlir/test/Dialect/ArmNeon/roundtrip.mlir +++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir @@ -3,18 +3,25 @@ // CHECK-LABEL: arm_neon_smull func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>) -> (vector<8xi16>, vector<4xi32>, vector<2xi64>) { - // CHECK: arm_neon.smull {{.*}}: vector<8xi8> to vector<8xi16> - %0 = arm_neon.smull %a, %b : vector<8xi8> to vector<8xi16> + // CHECK: arm_neon.intr.smull {{.*}}: vector<8xi8> to vector<8xi16> + %0 = arm_neon.intr.smull %a, %b : vector<8xi8> to vector<8xi16> %00 = vector.extract_strided_slice %0 {offsets = [3], sizes = [4], strides = [1]}: vector<8xi16> to vector<4xi16> - // CHECK: arm_neon.smull {{.*}}: vector<4xi16> to vector<4xi32> - %1 = arm_neon.smull %00, %00 : vector<4xi16> to vector<4xi32> + // CHECK: arm_neon.intr.smull {{.*}}: vector<4xi16> to vector<4xi32> + %1 = arm_neon.intr.smull %00, %00 : vector<4xi16> to vector<4xi32> %11 = vector.extract_strided_slice %1 {offsets = [1], sizes = [2], strides = [1]}: vector<4xi32> to vector<2xi32> - // CHECK: arm_neon.smull {{.*}}: vector<2xi32> to vector<2xi64> - %2 = arm_neon.smull %11, %11 : vector<2xi32> to vector<2xi64> + // CHECK: arm_neon.intr.smull {{.*}}: vector<2xi32> to vector<2xi64> + %2 = arm_neon.intr.smull %11, %11 : vector<2xi32> to vector<2xi64> return %0, %1, %2 : 
vector<8xi16>, vector<4xi32>, vector<2xi64> } + +// CHECK-LABEL: arm_neon_sdot +func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> { + // CHECK: arm_neon.intr.sdot {{.*}}: vector<8xi8>, vector<8xi8> to vector<2xi32> + %0 = arm_neon.intr.sdot %a, %b, %c : vector<8xi8>, vector<8xi8> to vector<2xi32> + return %0 : vector<2xi32> +} diff --git a/mlir/test/Target/arm-neon.mlir b/mlir/test/Target/arm-neon.mlir --- a/mlir/test/Target/arm-neon.mlir +++ b/mlir/test/Target/arm-neon.mlir @@ -4,16 +4,16 @@ llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)> { // CHECK: %[[V0:.*]] = call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %{{.*}}, <8 x i8> %{{.*}}) // CHECK-NEXT: %[[V00:.*]] = shufflevector <8 x i16> %3, <8 x i16> %[[V0]], <4 x i32> - %0 = arm_neon.smull %arg0, %arg1 : vector<8xi8> to vector<8xi16> + %0 = arm_neon.intr.smull %arg0, %arg1 : vector<8xi8> to vector<8xi16> %1 = llvm.shufflevector %0, %0 [3, 4, 5, 6] : vector<8xi16>, vector<8xi16> // CHECK-NEXT: %[[V1:.*]] = call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %[[V00]], <4 x i16> %[[V00]]) // CHECK-NEXT: %[[V11:.*]] = shufflevector <4 x i32> %[[V1]], <4 x i32> %[[V1]], <2 x i32> - %2 = arm_neon.smull %1, %1 : vector<4xi16> to vector<4xi32> + %2 = arm_neon.intr.smull %1, %1 : vector<4xi16> to vector<4xi32> %3 = llvm.shufflevector %2, %2 [1, 2] : vector<4xi32>, vector<4xi32> // CHECK-NEXT: %[[V1:.*]] = call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %[[V11]], <2 x i32> %[[V11]]) - %4 = arm_neon.smull %3, %3 : vector<2xi32> to vector<2xi64> + %4 = arm_neon.intr.smull %3, %3 : vector<2xi32> to vector<2xi64> %5 = llvm.mlir.undef : !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)> %6 = llvm.insertvalue %0, %5[0] : !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)> @@ -23,3 +23,19 @@ // CHECK: ret { <8 x i16>, <4 x i32>, <2 x i64> } llvm.return %8 : 
!llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)> } + +// CHECK-LABEL: arm_neon_sdot_i8i8 +llvm.func @arm_neon_sdot_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> { + // CHECK: %[[V0:.*]] = call <2 x i32> @llvm.aarch64.neon.sdot.v2i32.v8i8(<2 x i32> %{{.*}}, <8 x i8> %{{.*}}, <8 x i8> %{{.*}}) + // CHECK-NEXT: ret <2 x i32> + %0 = arm_neon.intr.sdot %a, %b, %c : vector<8xi8>, vector<8xi8> to vector<2xi32> + llvm.return %0 : vector<2xi32> +} + +// CHECK-LABEL: arm_neon_sdot_i16i16 +llvm.func @arm_neon_sdot_i16i16(%a: vector<4xi32>, %b: vector<16xi8>, %c: vector<16xi8>) -> vector<4xi32> { + // CHECK: %[[V0:.*]] = call <4 x i32> @llvm.aarch64.neon.sdot.v4i32.v16i8(<4 x i32> %{{.*}}, <16 x i8> %{{.*}}, <16 x i8> %{{.*}}) + // CHECK-NEXT: ret <4 x i32> + %0 = arm_neon.intr.sdot %a, %b, %c : vector<16xi8>, vector<16xi8> to vector<4xi32> + llvm.return %0 : vector<4xi32> +}