diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -568,6 +568,29 @@ if (isAllNull) // This is casting one pointer type to another, always BitCast return ConstantExpr::getPointerCast(CE->getOperand(0), DestTy); + } else if (CE->getType()->isVectorTy() && DestTy->isVectorTy() && + CE->getType()->getVectorElementCount() == + DestTy->getVectorElementCount() && + CE->getOpcode() == Instruction::ShuffleVector && + opc == Instruction::BitCast) { + // For splat vector, fold bitcast to splat value. + // BitCast(ShuffleVector(InsertElement(C1, SplatV, Zero), C2, Zero)) to NewType + // into + // ShuffleVector(InsertElement(C1, BitCast(SplatV) to NewType, Zero), C2, Zero) + Constant *SplatV, *C1, *C2, *ZeroIdx, *ZeroMask; + if (match(CE, m_ShuffleVector( + m_InsertElement( + m_Constant(C1), m_Constant(SplatV), + m_CombineAnd(m_Zero(), m_Constant(ZeroIdx))), + m_Constant(C2), + m_CombineAnd(m_Zero(), m_Constant(ZeroMask))))) { + auto *CastedSplatV = + ConstantExpr::getCast(opc, SplatV, DestTy->getScalarType()); + return ConstantExpr::getShuffleVector( + ConstantExpr::getInsertElement( + ConstantExpr::getCast(opc, C1, DestTy), CastedSplatV, ZeroIdx), + ConstantExpr::getCast(opc, C2, DestTy), ZeroMask); + } } } diff --git a/llvm/test/Analysis/ConstantFolding/bitcast.ll b/llvm/test/Analysis/ConstantFolding/bitcast.ll --- a/llvm/test/Analysis/ConstantFolding/bitcast.ll +++ b/llvm/test/Analysis/ConstantFolding/bitcast.ll @@ -3,7 +3,7 @@ define @bitcast_scalable_constant() { ; CHECK-LABEL: @bitcast_scalable_constant( -; CHECK-NEXT: ret bitcast ( shufflevector ( insertelement ( undef, i32 1, i32 0), undef, zeroinitializer) to ) +; CHECK-NEXT: ret shufflevector ( insertelement ( undef, float 0x36A0000000000000, i32 0), undef, zeroinitializer) ; %i1 = insertelement undef, i32 1, i32 0 %i2 = shufflevector %i1, undef, zeroinitializer