diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -53,6 +53,13 @@
                                         list<OpTrait> traits = []>
   : ArmNeon_IntrOp<mnemonic, [0], [], 1, traits>;
 
+// ArmNeon dialect op that corresponds to an LLVM IR intrinsic with one
+// overloaded result and the overloaded operands listed in overloadedOperands.
+class ArmNeon_OverloadedOperandsWithOneResultIntrOp<
+    string mnemonic, list<int> overloadedOperands,
+    list<OpTrait> traits = []>
+  : ArmNeon_IntrOp<mnemonic, [0], overloadedOperands, 1, traits>;
+
 def SMullOp : ArmNeon_OverloadedOneResultIntrOp<"smull", [
   NoSideEffect,
   AllTypesMatch<["a", "b"]>,
@@ -82,5 +89,32 @@
     "$a `,` $b attr-dict `:` type($a) `to` type($res)";
 }
 
+def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot", [1], [
+    NoSideEffect,
+    AllTypesMatch<["b", "c"]>,
+    AllTypesMatch<["a", "res"]>,
+    TypesMatchWith<"res is an i32 vector with a quarter of the elements of b",
+                   "b", "res",
+                   "VectorType::get({$_self.cast<VectorType>().getShape()[0] / 4},"
+                   "IntegerType::get($_self.getContext(), 32))">]> {
+  let summary = "sdot op";
+  let description = [{
+    Signed integer addition of dot product (vector). Operands b and c are
+    partitioned into groups of four consecutive signed i8 elements, and each
+    i32 lane i of res is a[i] + b[4*i]*c[4*i] + ... + b[4*i+3]*c[4*i+3].
+
+    Source:
+    https://developer.arm.com/architectures/instruction-sets/simd-isas/neon/intrinsics
+  }];
+  // Supports either:
+  //   (vector<2xi32>, vector<8xi8>, vector<8xi8>) -> vector<2xi32>
+  //   (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> vector<4xi32>
+  let arguments = (ins VectorOfLengthAndType<[4, 2], [I32]>:$a,
+                   VectorOfLengthAndType<[16, 8], [I8]>:$b,
+                   VectorOfLengthAndType<[16, 8], [I8]>:$c);
+  let results = (outs VectorOfLengthAndType<[4, 2], [I32]>:$res);
+  let assemblyFormat =
+    "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
+}
 
 #endif // ARMNEON_OPS
diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
--- a/mlir/test/Dialect/ArmNeon/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -18,3 +18,10 @@
 
   return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64>
 }
+
+// CHECK-LABEL: arm_neon_sdot
+func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
+  // CHECK: arm_neon.sdot {{.*}}: vector<8xi8>, vector<8xi8> to vector<2xi32>
+  %0 = arm_neon.sdot %a, %b, %c : vector<8xi8>, vector<8xi8> to vector<2xi32>
+  return %0 : vector<2xi32>
+}
diff --git a/mlir/test/Target/arm-neon.mlir b/mlir/test/Target/arm-neon.mlir
--- a/mlir/test/Target/arm-neon.mlir
+++ b/mlir/test/Target/arm-neon.mlir
@@ -23,3 +23,19 @@
   // CHECK: ret { <8 x i16>, <4 x i32>, <2 x i64> }
   llvm.return %8 : !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)>
 }
+
+// CHECK-LABEL: arm_neon_sdot_8_i8i8
+llvm.func @arm_neon_sdot_8_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
+  // CHECK: %[[V0:.*]] = call <2 x i32> @llvm.aarch64.neon.sdot.v2i32.v8i8(<2 x i32> %{{.*}}, <8 x i8> %{{.*}}, <8 x i8> %{{.*}})
+  // CHECK-NEXT: ret <2 x i32> %[[V0]]
+  %0 = arm_neon.sdot %a, %b, %c : vector<8xi8>, vector<8xi8> to vector<2xi32>
+  llvm.return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: arm_neon_sdot_16_i8i8
+llvm.func @arm_neon_sdot_16_i8i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vector<16xi8>) -> vector<4xi32> {
+  // CHECK: %[[V0:.*]] = call <4 x i32> @llvm.aarch64.neon.sdot.v4i32.v16i8(<4 x i32> %{{.*}}, <16 x i8> %{{.*}}, <16 x i8> %{{.*}})
+  // CHECK-NEXT: ret <4 x i32> %[[V0]]
+  %0 = arm_neon.sdot %a, %b, %c : vector<16xi8>, vector<16xi8> to vector<4xi32>
+  llvm.return %0 : vector<4xi32>
+}