diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1068,6 +1068,13 @@
   bool shouldLocalize(const MachineInstr &MI,
                       const TargetTransformInfo *TTI) const override;
 
+  bool SimplifyDemandedBitsForTargetNode(SDValue Op,
+                                         const APInt &OriginalDemandedBits,
+                                         const APInt &OriginalDemandedElts,
+                                         KnownBits &Known,
+                                         TargetLoweringOpt &TLO,
+                                         unsigned Depth) const override;
+
   // Normally SVE is only used for byte size vectors that do not fit within a
   // NEON vector. This changes when OverrideNEON is true, allowing SVE to be
   // used for 64bit and 128bit vectors as well.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14639,6 +14639,29 @@
   return SDValue();
 }
 
+/// Optimize a vector shift instruction and its operand if shifted out
+/// bits are not used.
+static SDValue performVectorShiftCombine(SDNode *N,
+                                         const AArch64TargetLowering &TLI,
+                                         TargetLowering::DAGCombinerInfo &DCI) {
+  assert(N->getOpcode() == AArch64ISD::VASHR ||
+         N->getOpcode() == AArch64ISD::VLSHR);
+
+  SDValue Op = N->getOperand(0);
+  unsigned OpScalarSize = Op.getScalarValueSizeInBits();
+
+  unsigned ShiftImm = N->getConstantOperandVal(1);
+  assert(OpScalarSize > ShiftImm && "Invalid shift imm");
+
+  APInt ShiftedOutBits = APInt::getLowBitsSet(OpScalarSize, ShiftImm);
+  APInt DemandedMask = ~ShiftedOutBits;
+
+  if (TLI.SimplifyDemandedBits(Op, DemandedMask, DCI))
+    return SDValue(N, 0);
+
+  return SDValue();
+}
+
 /// Target-specific DAG combine function for post-increment LD1 (lane) and
 /// post-increment LD1R.
 static SDValue performPostLD1Combine(SDNode *N,
@@ -16115,6 +16138,9 @@
   case AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO:
   case AArch64ISD::GLD1S_IMM_MERGE_ZERO:
     return performGLD1Combine(N, DAG);
+  case AArch64ISD::VASHR:
+  case AArch64ISD::VLSHR:
+    return performVectorShiftCombine(N, *this, DCI);
   case ISD::INSERT_VECTOR_ELT:
     return performInsertVectorEltCombine(N, DCI);
   case ISD::EXTRACT_VECTOR_ELT:
@@ -17805,3 +17831,47 @@
 bool AArch64TargetLowering::isAllActivePredicate(SDValue N) const {
   return ::isAllActivePredicate(N);
 }
+
+bool AArch64TargetLowering::SimplifyDemandedBitsForTargetNode(
+    SDValue Op, const APInt &OriginalDemandedBits,
+    const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO,
+    unsigned Depth) const {
+
+  unsigned Opc = Op.getOpcode();
+  switch (Opc) {
+  case AArch64ISD::VSHL: {
+    // Match (VSHL (VLSHR Val X) X)
+    SDValue ShiftL = Op;
+    SDValue ShiftR = Op->getOperand(0);
+    if (ShiftR->getOpcode() != AArch64ISD::VLSHR)
+      return false;
+
+    if (!ShiftL.hasOneUse() || !ShiftR.hasOneUse())
+      return false;
+
+    unsigned ShiftLBits = ShiftL->getConstantOperandVal(1);
+    unsigned ShiftRBits = ShiftR->getConstantOperandVal(1);
+
+    // Other cases can be handled as well, but this is not
+    // implemented.
+    if (ShiftRBits != ShiftLBits)
+      return false;
+
+    unsigned ScalarSize = Op.getScalarValueSizeInBits();
+    assert(ScalarSize > ShiftLBits && "Invalid shift imm");
+
+    APInt ZeroBits = APInt::getLowBitsSet(ScalarSize, ShiftLBits);
+    APInt UnusedBits = ~OriginalDemandedBits;
+
+    if ((ZeroBits & UnusedBits) != ZeroBits)
+      return false;
+
+    // All bits that are zeroed by (VSHL (VLSHR Val X) X) are not
+    // used - simplify to just Val.
+    return TLO.CombineTo(Op, ShiftR->getOperand(0));
+  }
+  }
+
+  return TargetLowering::SimplifyDemandedBitsForTargetNode(
+      Op, OriginalDemandedBits, OriginalDemandedElts, Known, TLO, Depth);
+}
diff --git a/llvm/test/CodeGen/AArch64/aarch64-bswap-ext.ll b/llvm/test/CodeGen/AArch64/aarch64-bswap-ext.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-bswap-ext.ll
@@ -0,0 +1,27 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-unknown-linux-gnu < %s | FileCheck %s
+
+define <2 x i32> @test1(<2 x i16> %v2i16) {
+; CHECK-LABEL: test1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rev32 v0.8b, v0.8b
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #16
+; CHECK-NEXT:    ret
+  %v2i16_rev = call <2 x i16> @llvm.bswap.v2i16(<2 x i16> %v2i16)
+  %v2i32 = sext <2 x i16> %v2i16_rev to <2 x i32>
+  ret <2 x i32> %v2i32
+}
+
+define <2 x float> @test2(<2 x i16> %v2i16) {
+; CHECK-LABEL: test2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rev32 v0.8b, v0.8b
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #16
+; CHECK-NEXT:    scvtf v0.2s, v0.2s
+; CHECK-NEXT:    ret
+  %v2i16_rev = call <2 x i16> @llvm.bswap.v2i16(<2 x i16> %v2i16)
+  %v2f32 = sitofp <2 x i16> %v2i16_rev to <2 x float>
+  ret <2 x float> %v2f32
+}
+
+declare <2 x i16> @llvm.bswap.v2i16(<2 x i16>) nounwind readnone