diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1796,6 +1796,38 @@
     Known = KnownBits::commonBits(Known, Known2);
     break;
   }
+  case AArch64ISD::BICi: {
+    KnownBits Known2, Known3;
+    Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
+    Known3 = DAG.computeKnownBits(Op->getOperand(2), Depth + 1);
+    if (!Known2.isConstant() || !Known3.isConstant())
+      break;
+
+    // BICi clears the bits (Imm << Shift) in every lane, so the result is
+    // the first operand ANDed with the mask ~(Imm << Shift).
+    uint64_t Imm = Known2.getConstant().getZExtValue();
+    uint64_t Shift = Known3.getConstant().getZExtValue();
+    APInt Mask(Known.getBitWidth(), ~(Imm << Shift));
+
+    // Compute the bit cleared value.
+    Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
+    Known &= KnownBits::makeConstant(Mask);
+    break;
+  }
+  case AArch64ISD::VLSHR: {
+    KnownBits Known2;
+    Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
+    Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
+    Known = KnownBits::lshr(Known, Known2);
+    break;
+  }
+  case AArch64ISD::VASHR: {
+    KnownBits Known2;
+    Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
+    Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
+    Known = KnownBits::ashr(Known, Known2);
+    break;
+  }
   case AArch64ISD::LOADgot:
   case AArch64ISD::ADDlow: {
     if (!Subtarget->isTargetILP32())
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -8217,6 +8217,26 @@
 def : Pat<(AArch64AssertZExtBool GPR32:$op), (i32 GPR32:$op)>;
 
+// Treat an 'or' node as an 'add' if the or'ed operands have no common bits.
+def or_is_add : PatFrag<(ops node:$lhs, node:$rhs), (or node:$lhs, node:$rhs), [{
+  return CurDAG->haveNoCommonBitsSet(N->getOperand(0), N->getOperand(1));
+}]>;
+
+multiclass ShiftAccum<ValueType VT, RegisterOperand RO, Operand ImmT,
+                      Instruction USRA_INST, Instruction SSRA_INST> {
+  def _usra : Pat<(or_is_add (AArch64vlshr (VT RO:$src1), ImmT:$imm), (VT RO:$src2)),
+                  (USRA_INST (VT RO:$src2), (VT RO:$src1), ImmT:$imm)>;
+  def _ssra : Pat<(or_is_add (AArch64vashr (VT RO:$src1), ImmT:$imm), (VT RO:$src2)),
+                  (SSRA_INST (VT RO:$src2), (VT RO:$src1), ImmT:$imm)>;
+}
+
+defm v8i8  : ShiftAccum<v8i8,  V64,  vecshiftR8,  USRAv8i8_shift,  SSRAv8i8_shift>;
+defm v16i8 : ShiftAccum<v16i8, V128, vecshiftR8,  USRAv16i8_shift, SSRAv16i8_shift>;
+defm v2i32 : ShiftAccum<v2i32, V64,  vecshiftR32, USRAv2i32_shift, SSRAv2i32_shift>;
+defm v2i64 : ShiftAccum<v2i64, V128, vecshiftR64, USRAv2i64_shift, SSRAv2i64_shift>;
+defm v4i16 : ShiftAccum<v4i16, V64,  vecshiftR16, USRAv4i16_shift, SSRAv4i16_shift>;
+defm v4i32 : ShiftAccum<v4i32, V128, vecshiftR32, USRAv4i32_shift, SSRAv4i32_shift>;
+defm v8i16 : ShiftAccum<v8i16, V128, vecshiftR16, USRAv8i16_shift, SSRAv8i16_shift>;
 
 include "AArch64InstrAtomics.td"
 include "AArch64SVEInstrInfo.td"
 include "AArch64SMEInstrInfo.td"
diff --git a/llvm/test/CodeGen/AArch64/shift-accumulate.ll b/llvm/test/CodeGen/AArch64/shift-accumulate.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/shift-accumulate.ll
@@ -0,0 +1,27 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-unknown-unknown | FileCheck %s
+
+define dso_local <8 x i16> @usra(<16 x i8> %0) local_unnamed_addr align 32 {
+; CHECK-LABEL: usra:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushr v0.16b, v0.16b, #7
+; CHECK-NEXT:    usra v0.8h, v0.8h, #7
+  %2 = lshr <16 x i8> %0, <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7>
+  %3 = bitcast <16 x i8> %2 to <8 x i16>
+  %4 = lshr <8 x i16> %3, <i16 7, i16 7, i16 7, i16 7, i16 7, i16 7, i16 7, i16 7>
+  %5 = or <8 x i16> %4, %3
+  ret <8 x i16> %5
+}
+
+define dso_local <8 x i16> @ssra(<8 x i16> %0) local_unnamed_addr align 32 {
+; CHECK-LABEL: ssra:
+; CHECK: ssra v0.8h, v1.8h, #14
+  ; Clear bit 14 (the 15th bit).
+  %2 = and <8 x i16> %0, <i16 -16385, i16 -16385, i16 -16385, i16 -16385, i16 -16385, i16 -16385, i16 -16385, i16 -16385>
+  ; The 15 most significant bits are zero; the low bit can be zero or one.
+  %3 = lshr <8 x i16> %0, <i16 15, i16 15, i16 15, i16 15, i16 15, i16 15, i16 15, i16 15>
+  ; The 15 most significant bits may be one; the low bit is zero.
+  %4 = ashr <8 x i16> %2, <i16 14, i16 14, i16 14, i16 14, i16 14, i16 14, i16 14, i16 14>
+  %5 = or <8 x i16> %4, %3
+  ret <8 x i16> %5
+}
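
The or_is_add PatFrag above rests on a simple identity: when two values have no set bits in common, OR-ing them and adding them produce the same result, which is what allows the 'or' to be selected as the accumulate step of USRA/SSRA. The following standalone C++ sketch is illustrative only and not part of the patch; it mirrors the values of the ssra test above and assumes arithmetic right shift of negative integers (guaranteed since C++20, universal in practice).

// Illustrative sketch, not part of the patch: checks that OR equals ADD
// when the operands share no set bits, mirroring the ssra test above.
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t i = 0; i < 256; ++i) {
    uint16_t v = static_cast<uint16_t>(i * 257); // arbitrary 16-bit test value
    uint16_t hi = v >> 15;                       // only bit 0 can be set
    uint16_t masked = v & 0xBFFF;                // bit 14 cleared
    // Arithmetic shift right by 14: bit 0 of the result is bit 14 of
    // 'masked', which is known to be zero.
    uint16_t acc = static_cast<uint16_t>(static_cast<int16_t>(masked) >> 14);
    // No common set bits, so OR and ADD agree.
    assert((acc | hi) == static_cast<uint16_t>(acc + hi));
  }
  return 0;
}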