diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1064,6 +1064,13 @@
   bool shouldLocalize(const MachineInstr &MI,
                       const TargetTransformInfo *TTI) const override;
 
+  bool SimplifyDemandedBitsForTargetNode(SDValue Op,
+                                         const APInt &OriginalDemandedBits,
+                                         const APInt &OriginalDemandedElts,
+                                         KnownBits &Known,
+                                         TargetLoweringOpt &TLO,
+                                         unsigned Depth) const override;
+
   // Normally SVE is only used for byte size vectors that do not fit within a
   // NEON vector. This changes when OverrideNEON is true, allowing SVE to be
   // used for 64bit and 128bit vectors as well.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14558,6 +14558,26 @@
   return SDValue();
 }
 
+static SDValue performShiftCombine(SDNode *N, const AArch64TargetLowering &TLI,
+                                   TargetLowering::DAGCombinerInfo &DCI) {
+  unsigned Opc = N->getOpcode();
+  if (Opc != AArch64ISD::VASHR && Opc != AArch64ISD::VLSHR)
+    return SDValue();
+
+  SDValue Op = N->getOperand(0);
+  uint64_t OpScalarSize = Op.getScalarValueSizeInBits();
+
+  uint64_t ShiftImm = N->getConstantOperandVal(1);
+  assert(OpScalarSize > ShiftImm && "Invalid shift imm");
+
+  APInt ShiftedOutBits =
+      APInt::getLowBitsSet((unsigned)OpScalarSize, (unsigned)ShiftImm);
+  APInt DemandedMask = ~ShiftedOutBits;
+
+  TLI.SimplifyDemandedBits(Op, DemandedMask, DCI);
+  return SDValue();
+}
+
 /// Target-specific DAG combine function for post-increment LD1 (lane) and
 /// post-increment LD1R.
 static SDValue performPostLD1Combine(SDNode *N,
@@ -16034,6 +16054,9 @@
   case AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO:
   case AArch64ISD::GLD1S_IMM_MERGE_ZERO:
     return performGLD1Combine(N, DAG);
+  case AArch64ISD::VASHR:
+  case AArch64ISD::VLSHR:
+    return performShiftCombine(N, *this, DCI);
   case ISD::INSERT_VECTOR_ELT:
     return performInsertVectorEltCombine(N, DCI);
   case ISD::EXTRACT_VECTOR_ELT:
@@ -17724,3 +17747,48 @@
 bool AArch64TargetLowering::isAllActivePredicate(SDValue N) const {
   return ::isAllActivePredicate(N);
 }
+
+bool AArch64TargetLowering::SimplifyDemandedBitsForTargetNode(
+    SDValue Op, const APInt &OriginalDemandedBits,
+    const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO,
+    unsigned Depth) const {
+
+  unsigned Opc = Op.getOpcode();
+  switch (Opc) {
+  case AArch64ISD::VSHL: {
+    // Match (VSHL (VLSHR Val X) X)
+    SDValue ShiftL = Op;
+    SDValue ShiftR = Op->getOperand(0);
+    if (ShiftR->getOpcode() != AArch64ISD::VLSHR)
+      return false;
+
+    if (!ShiftL.hasOneUse() || !ShiftR.hasOneUse())
+      return false;
+
+    uint64_t ShiftLBits = ShiftL->getConstantOperandVal(1);
+    uint64_t ShiftRBits = ShiftR->getConstantOperandVal(1);
+
+    // Other cases can be handled as well, but this is not
+    // implemented.
+    if (ShiftRBits != ShiftLBits)
+      return false;
+
+    uint64_t ScalarSize = Op.getScalarValueSizeInBits();
+    assert(ScalarSize > ShiftLBits && "Invalid shift imm");
+
+    APInt ZeroBits =
+        APInt::getLowBitsSet((unsigned)ScalarSize, (unsigned)ShiftLBits);
+    APInt UnusedBits = ~OriginalDemandedBits;
+
+    if ((ZeroBits & UnusedBits) != ZeroBits)
+      return false;
+
+    // All bits that are zeroed by (VSHL (VLSHR Val X) X) are not
+    // used - simplify to just Val.
+    return TLO.CombineTo(Op, ShiftR->getOperand(0));
+  }
+  }
+
+  return TargetLowering::SimplifyDemandedBitsForTargetNode(
+      Op, OriginalDemandedBits, OriginalDemandedElts, Known, TLO, Depth);
+}
diff --git a/llvm/test/CodeGen/AArch64/aarch64-bswap-ext.ll b/llvm/test/CodeGen/AArch64/aarch64-bswap-ext.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-bswap-ext.ll
@@ -0,0 +1,27 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -O2 -mtriple=aarch64-unknown-linux-gnu < %s | FileCheck %s
+
+define <2 x i32> @test1(<2 x i16> %v2i16) {
+; CHECK-LABEL: test1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rev32 v0.8b, v0.8b
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #16
+; CHECK-NEXT:    ret
+  %v2i16_rev = call <2 x i16> @llvm.bswap.v2i16(<2 x i16> %v2i16)
+  %v2i32 = sext <2 x i16> %v2i16_rev to <2 x i32>
+  ret <2 x i32> %v2i32
+}
+
+define <2 x float> @test2(<2 x i16> %v2i16) {
+; CHECK-LABEL: test2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rev32 v0.8b, v0.8b
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #16
+; CHECK-NEXT:    scvtf v0.2s, v0.2s
+; CHECK-NEXT:    ret
+  %v2i16_rev = call <2 x i16> @llvm.bswap.v2i16(<2 x i16> %v2i16)
+  %v2f32 = sitofp <2 x i16> %v2i16_rev to <2 x float>
+  ret <2 x float> %v2f32
+}
+
+declare <2 x i16> @llvm.bswap.v2i16(<2 x i16>) nounwind readnone