diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -46,6 +46,40 @@
                                     getST()->getFeatureBits());
 }
 
+// Look for patterns of shift followed by AND that can be turned into a pair of
+// shifts. We won't need to materialize an immediate for the AND so these can
+// be considered free.
+//
+// Inst is expected to be the AND whose operand 1 is Imm. Returns true when
+// operand 0 is a single-use (shl X, C2) with a constant C2, and Imm is a run
+// of N contiguous ones whose lowest set bit is bit C2. For a register width W,
+//   (X << C2) & Imm == (X << (W - N)) >>u (W - N - C2)
+// so the AND immediate never needs to be materialized. (This assumes ISel
+// actually forms that shift pair -- confirm against the RISC-V ISel patterns.)
+//
+// Imm must be at most 64 bits wide because getZExtValue() asserts otherwise;
+// the caller only invokes this when Imm's width is <= XLen.
+static bool canUseShiftPair(Instruction *Inst, const APInt &Imm) {
+  uint64_t Mask = Imm.getZExtValue();
+  auto *BO = dyn_cast<BinaryOperator>(Inst->getOperand(0));
+  if (!BO || !BO->hasOneUse())
+    return false;
+
+  if (BO->getOpcode() != Instruction::Shl)
+    return false;
+
+  auto *ShAmtC = dyn_cast<ConstantInt>(BO->getOperand(1));
+  if (!ShAmtC)
+    return false;
+
+  if (!isShiftedMask_64(Mask))
+    return false;
+
+  // Compare in 64 bits: truncating the shift amount to unsigned first could
+  // let an out-of-range amount spuriously match the trailing-zero count.
+  return ShAmtC->getZExtValue() == countTrailingZeros(Mask);
+}
+
 InstructionCost RISCVTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
                                                 const APInt &Imm, Type *Ty,
                                                 TTI::TargetCostKind CostKind,
@@ -75,6 +109,10 @@
     // zext.w
     if (Imm == UINT64_C(0xffffffff) && ST->hasStdExtZba())
       return TTI::TCC_Free;
+    // The AND can be replaced by two shifts; see canUseShiftPair.
+    if (Inst && Idx == 1 && Imm.getBitWidth() <= ST->getXLen() &&
+        canUseShiftPair(Inst, Imm))
+      return TTI::TCC_Free;
     [[fallthrough]];
   case Instruction::Add:
   case Instruction::Or:
diff --git a/llvm/test/Transforms/ConstantHoisting/RISCV/immediates.ll b/llvm/test/Transforms/ConstantHoisting/RISCV/immediates.ll
--- a/llvm/test/Transforms/ConstantHoisting/RISCV/immediates.ll
+++ b/llvm/test/Transforms/ConstantHoisting/RISCV/immediates.ll
@@ -81,3 +81,19 @@
   %2 = mul i64 %1, -4294967296
   ret i64 %2
 }
+
+; The AND masks are shifted masks whose lowest set bit matches the shl amount,
+; so the AND can be done with a shift pair and the immediate is not hoisted.
+define i32 @test10(i32 %a, i32 %b) nounwind {
+; CHECK-LABEL: @test10(
+; CHECK: shl i32 %a, 8
+; CHECK: and i32 %1, 65280
+; CHECK: shl i32 %b, 8
+; CHECK: and i32 %3, 65280
+  %1 = shl i32 %a, 8
+  %2 = and i32 %1, 65280
+  %3 = shl i32 %b, 8
+  %4 = and i32 %3, 65280
+  %5 = mul i32 %2, %4
+  ret i32 %5
+}