diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -494,6 +494,11 @@
   explicit AArch64TargetLowering(const TargetMachine &TM,
                                  const AArch64Subtarget &STI);
 
+  /// Control the following reassociation of operands: (op (op x, c1), y) -> (op
+  /// (op x, y), c1) where N0 is (op x, c1) and N1 is y.
+  bool isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                           SDValue N1) const override;
+
   /// Selects the correct CCAssignFn for a given CallingConvention value.
   CCAssignFn *CCAssignFnForCall(CallingConv::ID CC, bool IsVarArg) const;
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5468,6 +5468,36 @@
 //                      Calling Convention Implementation
 //===----------------------------------------------------------------------===//
 
+static unsigned getIntrinsicID(const SDNode *N) {
+  unsigned Opcode = N->getOpcode();
+  switch (Opcode) {
+  default:
+    return Intrinsic::not_intrinsic;
+  case ISD::INTRINSIC_WO_CHAIN: {
+    unsigned IID = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
+    if (IID < Intrinsic::num_intrinsics)
+      return IID;
+    return Intrinsic::not_intrinsic;
+  }
+  }
+}
+
+bool AArch64TargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                                                SDValue N1) const {
+  if (!N0.hasOneUse())
+    return false;
+
+  unsigned IID = getIntrinsicID(N1.getNode());
+  // Avoid reassociating expressions that can be lowered to smlal/umlal.
+  if (IID == Intrinsic::aarch64_neon_umull ||
+      N1.getOpcode() == AArch64ISD::UMULL ||
+      IID == Intrinsic::aarch64_neon_smull ||
+      N1.getOpcode() == AArch64ISD::SMULL)
+    return N0.getOpcode() != ISD::ADD;
+
+  return true;
+}
+
 /// Selects the correct CCAssignFn for a given CallingConvention value.
 CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
                                                      bool IsVarArg) const {
@@ -10692,20 +10722,6 @@
   return true;
 }
 
-static unsigned getIntrinsicID(const SDNode *N) {
-  unsigned Opcode = N->getOpcode();
-  switch (Opcode) {
-  default:
-    return Intrinsic::not_intrinsic;
-  case ISD::INTRINSIC_WO_CHAIN: {
-    unsigned IID = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
-    if (IID < Intrinsic::num_intrinsics)
-      return IID;
-    return Intrinsic::not_intrinsic;
-  }
-  }
-}
-
 // Attempt to form a vector S[LR]I from (or (and X, BvecC1), (lsl Y, C2)),
 // to (SLI X, Y, C2), where X and Y have matching vector types, BvecC1 is a
 // BUILD_VECTORs with constant element C1, C2 is a constant, and:
diff --git a/llvm/test/CodeGen/AArch64/arm64-vmul.ll b/llvm/test/CodeGen/AArch64/arm64-vmul.ll
--- a/llvm/test/CodeGen/AArch64/arm64-vmul.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-vmul.ll
@@ -388,12 +388,11 @@
 define void @smlal8h_chain_with_constant(<8 x i16>* %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) {
 ; CHECK-LABEL: smlal8h_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull.8h v0, v0, v2
-; CHECK-NEXT:    mvn.8b v2, v2
 ; CHECK-NEXT:    movi.16b v3, #1
-; CHECK-NEXT:    smlal.8h v0, v1, v2
-; CHECK-NEXT:    add.8h v0, v0, v3
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    smlal.8h v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    smlal.8h v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <8 x i8> %v3, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
   %smull.1 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %v1, <8 x i8> %v3)
@@ -407,13 +406,12 @@
 define void @smlal2d_chain_with_constant(<2 x i64>* %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) {
 ; CHECK-LABEL: smlal2d_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull.2d v0, v0, v2
 ; CHECK-NEXT:    mov w8, #257
-; CHECK-NEXT:    mvn.8b v2, v2
-; CHECK-NEXT:    smlal.2d v0, v1, v2
-; CHECK-NEXT:    dup.2d v1, x8
-; CHECK-NEXT:    add.2d v0, v0, v1
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    dup.2d v3, x8
+; CHECK-NEXT:    smlal.2d v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    smlal.2d v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <2 x i32> %v3, <i32 -1, i32 -1>
   %smull.1 = tail call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %v1, <2 x i32> %v3)
@@ -671,12 +669,11 @@
 define void @umlal8h_chain_with_constant(<8 x i16>* %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) {
 ; CHECK-LABEL: umlal8h_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull.8h v0, v0, v2
-; CHECK-NEXT:    mvn.8b v2, v2
 ; CHECK-NEXT:    movi.16b v3, #1
-; CHECK-NEXT:    umlal.8h v0, v1, v2
-; CHECK-NEXT:    add.8h v0, v0, v3
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    umlal.8h v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    umlal.8h v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <8 x i8> %v3, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
   %umull.1 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %v1, <8 x i8> %v3)
@@ -690,13 +687,12 @@
 define void @umlal2d_chain_with_constant(<2 x i64>* %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) {
 ; CHECK-LABEL: umlal2d_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull.2d v0, v0, v2
 ; CHECK-NEXT:    mov w8, #257
-; CHECK-NEXT:    mvn.8b v2, v2
-; CHECK-NEXT:    umlal.2d v0, v1, v2
-; CHECK-NEXT:    dup.2d v1, x8
-; CHECK-NEXT:    add.2d v0, v0, v1
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    dup.2d v3, x8
+; CHECK-NEXT:    umlal.2d v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    umlal.2d v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <2 x i32> %v3, <i32 -1, i32 -1>
   %umull.1 = tail call <2 x i64> @llvm.aarch64.neon.umull.v2i64(<2 x i32> %v1, <2 x i32> %v3)
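
For context on what the new hook decides (not part of the patch): DAGCombiner only performs the reassociation (op (op x, c1), y) -> (op (op x, y), c1) when the target's isReassocProfitable hook agrees. The override above refuses when y is a widening multiply (either the NEON smull/umull intrinsic or the corresponding AArch64ISD node) and the inner op is an add, because hoisting the constant out of the add chain would leave a trailing add that cannot fold, whereas keeping it in place lets instruction selection absorb both adds as smlal/umlal accumulators. The sketch below is a standalone, compilable model of that decision table; NodeKind and the boolean parameter are stand-ins for the real SelectionDAG types, not LLVM API.

    // Illustrative model of AArch64TargetLowering::isReassocProfitable.
    #include <cassert>

    enum class NodeKind { Add, SMull, UMull, Other };

    // N0 models the inner (op x, c1) node, N1 models y in the combine
    // (op (op x, c1), y) -> (op (op x, y), c1).
    static bool isReassocProfitable(bool N0HasOneUse, NodeKind N0, NodeKind N1) {
      // Reassociation duplicates the inner op unless it has a single use.
      if (!N0HasOneUse)
        return false;
      // If y is a widening multiply and the inner op is an add, keep
      // (add (add x, C), (smull a, b)) intact: ISel can then fold each add
      // into an smlal/umlal accumulator. Hoisting C back out would leave a
      // trailing add that cannot be folded.
      if (N1 == NodeKind::SMull || N1 == NodeKind::UMull)
        return N0 != NodeKind::Add;
      return true;
    }

    int main() {
      // add(add(smull, C), smull): blocked, so both adds become smlal.
      assert(!isReassocProfitable(true, NodeKind::Add, NodeKind::SMull));
      // No widening multiply involved: reassociate as before.
      assert(isReassocProfitable(true, NodeKind::Add, NodeKind::Other));
      // Inner op has extra uses: never profitable.
      assert(!isReassocProfitable(false, NodeKind::Add, NodeKind::SMull));
    }

The updated CHECK lines show the payoff: the constant becomes the initial accumulator (movi/dup into v3) and both multiplies land in smlal/umlal, replacing the old smull + smlal + add sequence and saving one instruction per chain.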