Index: llvm/lib/Target/AArch64/AArch64InstrInfo.td =================================================================== --- llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -6983,5 +6983,121 @@ def : Pat<(i32 (extractelt (v2i32 V64:$V), (i64 0))), (EXTRACT_SUBREG V64:$V, ssub)>; } +// dot_v4i8 +class mul_v4i8 : + PatFrag<(ops node:$Rn, node:$Rm, node:$offset), + (mul (ldop (add node:$Rn, node:$offset)), + (ldop (add node:$Rm, node:$offset)))>; +class mulz_v4i8 : + PatFrag<(ops node:$Rn, node:$Rm), + (mul (ldop node:$Rn), (ldop node:$Rm))>; + +def load_v4i8 : + OutPatFrag<(ops node:$R), + (INSERT_SUBREG + (v2i32 (IMPLICIT_DEF)), + (i32 (COPY_TO_REGCLASS (LDRWui node:$R, (i64 0)), FPR32)), + ssub)>; + +class dot_v4i8 : + Pat<(i32 (add (mul_v4i8 GPR64sp:$Rn, GPR64sp:$Rm, (i64 3)), + (add (mul_v4i8 GPR64sp:$Rn, GPR64sp:$Rm, (i64 2)), + (add (mul_v4i8 GPR64sp:$Rn, GPR64sp:$Rm, (i64 1)), + (mulz_v4i8 GPR64sp:$Rn, GPR64sp:$Rm))))), + (EXTRACT_SUBREG (i64 (DOT (DUPv2i32gpr WZR), + (load_v4i8 GPR64sp:$Rn), + (load_v4i8 GPR64sp:$Rm))), + sub_32)>, Requires<[HasDotProd]>; + +// dot_v8i8 +class ee_v8i8 : + PatFrag<(ops node:$V, node:$K), + (v4i16 (extract_subvector (v8i16 (extend node:$V)), node:$K))>; + +class mul_v8i8 : + PatFrag<(ops node:$M, node:$N, node:$K), + (mulop (v4i16 (ee_v8i8 node:$M, node:$K)), + (v4i16 (ee_v8i8 node:$N, node:$K)))>; + +class idot_v8i8 : + PatFrag<(ops node:$M, node:$N), + (i32 (extractelt + (v4i32 (AArch64uaddv + (add (mul_v8i8 node:$M, node:$N, (i64 0)), + (mul_v8i8 node:$M, node:$N, (i64 4))))), + (i64 0)))>; + +// vaddv_[su]32 is special; -> ADDP Vd.2S,Vn.2S,Vm.2S; return Vd.s[0];Vn==Vm +def VADDV_32 : OutPatFrag<(ops node:$R), (ADDPv2i32 node:$R, node:$R)>; + +class odot_v8i8 : + OutPatFrag<(ops node:$Vo, node:$Vm, node:$Vn), + (EXTRACT_SUBREG + (VADDV_32 + (i64 (DOT (DUPv2i32gpr node:$Vo), + (v8i8 node:$Vm), + (v8i8 node:$Vn)))), + sub_32)>; + +multiclass dot_v8i8 { + def : Pat<(idot_v8i8 V64:$Vm, V64:$Vn), + (odot_v8i8 WZR, V64:$Vm, V64:$Vn)>, + Requires<[HasDotProd]>; + def : Pat<(add GPR32:$Vo, + (idot_v8i8 V64:$Vm, V64:$Vn)), + (odot_v8i8 GPR32:$Vo, V64:$Vm, V64:$Vn)>, + Requires<[HasDotProd]>; +} + +// dot_v16i8 +class ee_v16i8 : + PatFrag<(ops node:$V, node:$K1, node:$K2), + (v4i16 (extract_subvector + (v8i16 (extend + (v8i8 (extract_subvector node:$V, node:$K1)))), node:$K2))>; + +class mul_v16i8 : + PatFrag<(ops node:$M, node:$N, node:$K1, node:$K2), + (v4i32 + (mulop (v4i16 (ee_v16i8 node:$M, node:$K1, node:$K2)), + (v4i16 (ee_v16i8 node:$N, node:$K1, node:$K2))))>; + +class idot_v16i8 : + PatFrag<(ops node:$M, node:$N), + (i32 (extractelt + (v4i32 (AArch64uaddv + (add + (add (mul_v16i8 node:$M, node:$N, (i64 0), (i64 0)), + (mul_v16i8 node:$M, node:$N, (i64 8), (i64 0))), + (add (mul_v16i8 node:$M, node:$N, (i64 0), (i64 4)), + (mul_v16i8 node:$M, node:$N, (i64 8), (i64 4)))))), + (i64 0)))>; + +class odot_v16i8 : + OutPatFrag<(ops node:$Vo, node:$Vm, node:$Vn), + (i32 (ADDVv4i32v + (DOT (DUPv4i32gpr node:$Vo), node:$Vm, node:$Vn)))>; + +multiclass dot_v16i8 { + def : Pat<(idot_v16i8 V128:$Vm, V128:$Vn), + (odot_v16i8 WZR, V128:$Vm, V128:$Vn)>, + Requires<[HasDotProd]>; + def : Pat<(add GPR32:$Vo, + (idot_v16i8 V128:$Vm, V128:$Vn)), + (odot_v16i8 GPR32:$Vo, V128:$Vm, V128:$Vn)>, + Requires<[HasDotProd]>; +} + +let AddedComplexity = 30 in { + def : dot_v4i8; + def : dot_v4i8; + defm : dot_v8i8; + defm : dot_v8i8; + defm : dot_v16i8; + defm : dot_v16i8; +} + include "AArch64InstrAtomics.td" include "AArch64SVEInstrInfo.td" Index: llvm/test/CodeGen/AArch64/neon-dot-product.ll =================================================================== --- llvm/test/CodeGen/AArch64/neon-dot-product.ll +++ llvm/test/CodeGen/AArch64/neon-dot-product.ll @@ -128,3 +128,145 @@ %vdot1.i = call <4 x i32> @llvm.aarch64.neon.sdot.v4i32.v16i8(<4 x i32> %a, <16 x i8> %b, <16 x i8> %.cast3) #2 ret <4 x i32> %vdot1.i } + +define fastcc void @test_sdot_v4i8(i8* noalias nocapture %0, i8* noalias nocapture readonly %1, i8* noalias nocapture readonly %2) { +entry: +; CHECK-LABEL: test_sdot_v4i8: +; CHECK: sdot {{v[0-9]+}}.2s, {{v[0-9]+}}.8b, {{v[0-9]+}}.8b + %3 = bitcast i8* %0 to i32* + %4 = load i8, i8* %1, align 1 + %5 = sext i8 %4 to i32 + %6 = load i8, i8* %2, align 1 + %7 = sext i8 %6 to i32 + %8 = mul nsw i32 %7, %5 + %9 = getelementptr inbounds i8, i8* %1, i64 1 + %10 = load i8, i8* %9, align 1 + %11 = sext i8 %10 to i32 + %12 = getelementptr inbounds i8, i8* %2, i64 1 + %13 = load i8, i8* %12, align 1 + %14 = sext i8 %13 to i32 + %15 = mul nsw i32 %14, %11 + %16 = add nsw i32 %15, %8 + %17 = getelementptr inbounds i8, i8* %1, i64 2 + %18 = load i8, i8* %17, align 1 + %19 = sext i8 %18 to i32 + %20 = getelementptr inbounds i8, i8* %2, i64 2 + %21 = load i8, i8* %20, align 1 + %22 = sext i8 %21 to i32 + %23 = mul nsw i32 %22, %19 + %24 = add nsw i32 %23, %16 + %25 = getelementptr inbounds i8, i8* %1, i64 3 + %26 = load i8, i8* %25, align 1 + %27 = sext i8 %26 to i32 + %28 = getelementptr inbounds i8, i8* %2, i64 3 + %29 = load i8, i8* %28, align 1 + %30 = sext i8 %29 to i32 + %31 = mul nsw i32 %30, %27 + %32 = add nsw i32 %31, %24 + store i32 %32, i32* %3, align 64 + ret void +} + +define fastcc void @test_udot_v4i8(i8* noalias nocapture %0, i8* noalias nocapture readonly %1, i8* noalias nocapture readonly %2) { +entry: +; CHECK-LABEL: test_udot_v4i8: +; CHECK: udot {{v[0-9]+}}.2s, {{v[0-9]+}}.8b, {{v[0-9]+}}.8b + %3 = bitcast i8* %0 to i32* + %4 = load i8, i8* %1, align 1 + %5 = zext i8 %4 to i32 + %6 = load i8, i8* %2, align 1 + %7 = zext i8 %6 to i32 + %8 = mul nsw i32 %7, %5 + %9 = getelementptr inbounds i8, i8* %1, i64 1 + %10 = load i8, i8* %9, align 1 + %11 = zext i8 %10 to i32 + %12 = getelementptr inbounds i8, i8* %2, i64 1 + %13 = load i8, i8* %12, align 1 + %14 = zext i8 %13 to i32 + %15 = mul nsw i32 %14, %11 + %16 = add nsw i32 %15, %8 + %17 = getelementptr inbounds i8, i8* %1, i64 2 + %18 = load i8, i8* %17, align 1 + %19 = zext i8 %18 to i32 + %20 = getelementptr inbounds i8, i8* %2, i64 2 + %21 = load i8, i8* %20, align 1 + %22 = zext i8 %21 to i32 + %23 = mul nsw i32 %22, %19 + %24 = add nsw i32 %23, %16 + %25 = getelementptr inbounds i8, i8* %1, i64 3 + %26 = load i8, i8* %25, align 1 + %27 = zext i8 %26 to i32 + %28 = getelementptr inbounds i8, i8* %2, i64 3 + %29 = load i8, i8* %28, align 1 + %30 = zext i8 %29 to i32 + %31 = mul nsw i32 %30, %27 + %32 = add nsw i32 %31, %24 + store i32 %32, i32* %3, align 64 + ret void +} + +declare i32 @llvm.experimental.vector.reduce.add.v8i32(<8 x i32>) + +define i32 @test_udot_v8i8(i8* nocapture readonly %a, i8* nocapture readonly %b) { +entry: +; CHECK-LABEL: test_udot_v8i8: +; CHECK: udot {{v[0-9]+}}.2s, {{v[0-9]+}}.8b, {{v[0-9]+}}.8b + %0 = bitcast i8* %a to <8 x i8>* + %1 = load <8 x i8>, <8 x i8>* %0 + %2 = zext <8 x i8> %1 to <8 x i32> + %3 = bitcast i8* %b to <8 x i8>* + %4 = load <8 x i8>, <8 x i8>* %3 + %5 = zext <8 x i8> %4 to <8 x i32> + %6 = mul nuw nsw <8 x i32> %5, %2 + %7 = call i32 @llvm.experimental.vector.reduce.add.v8i32(<8 x i32> %6) + ret i32 %7 +} + +define i32 @test_sdot_v8i8(i8* nocapture readonly %a, i8* nocapture readonly %b) { +entry: +; CHECK-LABEL: test_sdot_v8i8: +; CHECK: sdot {{v[0-9]+}}.2s, {{v[0-9]+}}.8b, {{v[0-9]+}}.8b + %0 = bitcast i8* %a to <8 x i8>* + %1 = load <8 x i8>, <8 x i8>* %0 + %2 = sext <8 x i8> %1 to <8 x i32> + %3 = bitcast i8* %b to <8 x i8>* + %4 = load <8 x i8>, <8 x i8>* %3 + %5 = sext <8 x i8> %4 to <8 x i32> + %6 = mul nsw <8 x i32> %5, %2 + %7 = call i32 @llvm.experimental.vector.reduce.add.v8i32(<8 x i32> %6) + ret i32 %7 +} + +declare i32 @llvm.experimental.vector.reduce.add.v16i32(<16 x i32>) + +define i32 @test_udot_v16i8(i8* nocapture readonly %a, i8* nocapture readonly %b, i32 %sum) { +entry: +; CHECK-LABEL: test_udot_v16i8: +; CHECK: udot {{v[0-9]+}}.2s, {{v[0-9]+}}.8b, {{v[0-9]+}}.8b + %0 = bitcast i8* %a to <16 x i8>* + %1 = load <16 x i8>, <16 x i8>* %0 + %2 = zext <16 x i8> %1 to <16 x i32> + %3 = bitcast i8* %b to <16 x i8>* + %4 = load <16 x i8>, <16 x i8>* %3 + %5 = zext <16 x i8> %4 to <16 x i32> + %6 = mul nuw nsw <16 x i32> %5, %2 + %7 = call i32 @llvm.experimental.vector.reduce.add.v16i32(<16 x i32> %6) + %op.extra = add i32 %7, %sum + ret i32 %op.extra +} + +define i32 @test_sdot_v16i8(i8* nocapture readonly %a, i8* nocapture readonly %b, i32 %sum) { +entry: +; CHECK-LABEL: test_sdot_v16i8: +; CHECK: sdot {{v[0-9]+}}.2s, {{v[0-9]+}}.8b, {{v[0-9]+}}.8b + %0 = bitcast i8* %a to <16 x i8>* + %1 = load <16 x i8>, <16 x i8>* %0 + %2 = sext <16 x i8> %1 to <16 x i32> + %3 = bitcast i8* %b to <16 x i8>* + %4 = load <16 x i8>, <16 x i8>* %3 + %5 = sext <16 x i8> %4 to <16 x i32> + %6 = mul nsw <16 x i32> %5, %2 + %7 = call i32 @llvm.experimental.vector.reduce.add.v16i32(<16 x i32> %6) + %op.extra = add nsw i32 %7, %sum + ret i32 %op.extra +}