diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1026,60 +1026,94 @@
   return SelectInst::Create(X, TVal, FVal);
 }
 
-static Constant *constantFoldOperationIntoSelectOperand(
-    Instruction &I, SelectInst *SI, Value *SO) {
-  auto *ConstSO = dyn_cast<Constant>(SO);
-  if (!ConstSO)
-    return nullptr;
-
-  SmallVector<Constant *> ConstOps;
+/// Return the operands \p I would have on one arm of \p SI (the true arm if
+/// \p IsTrueArm, otherwise the false arm). Uses of \p SI are replaced by that
+/// arm's value, and an operand X is replaced by the constant C when the select
+/// condition `icmp eq X, C` (true arm) or `icmp ne X, C` (false arm) proves
+/// that X == C on that arm.
+static SmallVector<Value *>
+getInstSelectArmOperands(Instruction &I, SelectInst *SI, bool IsTrueArm) {
+  // Skip constants with undef/poison elements: an undef lane of C can compare
+  // equal to any X, so substituting C for X there would be unsound.
+  CmpInst::Predicate Pred;
+  Value *CmpOp = nullptr;
+  Constant *CondC = nullptr;
+  bool CanSubstituteCmpOp =
+      match(SI->getCondition(),
+            m_ICmp(Pred, m_Value(CmpOp), m_Constant(CondC))) &&
+      Pred == (IsTrueArm ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) &&
+      isGuaranteedNotToBeUndefOrPoison(CondC);
+
+  SmallVector<Value *> Ops;
   for (Value *Op : I.operands()) {
     if (Op == SI)
-      ConstOps.push_back(ConstSO);
-    else if (auto *C = dyn_cast<Constant>(Op))
+      Ops.push_back(IsTrueArm ? SI->getTrueValue() : SI->getFalseValue());
+    else if (CanSubstituteCmpOp && Op == CmpOp)
+      Ops.push_back(CondC);
+    else
+      Ops.push_back(Op);
+  }
+
+  return Ops;
+}
+
+/// Constant fold \p I on one arm of \p SI, using the operands from
+/// getInstSelectArmOperands. Returns null if any of those operands is not a
+/// constant or if folding fails.
+static Constant *constantFoldOperationIntoSelectOperand(Instruction &I,
+                                                        SelectInst *SI,
+                                                        bool IsTrueArm) {
+  SmallVector<Constant *> ConstOps;
+  for (Value *Op : getInstSelectArmOperands(I, SI, IsTrueArm)) {
+    if (auto *C = dyn_cast<Constant>(Op))
       ConstOps.push_back(C);
     else
-      llvm_unreachable("Operands should be select or constant");
+      return nullptr;
   }
   return ConstantFoldInstOperands(&I, ConstOps, I.getModule()->getDataLayout());
 }
 
-static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO,
-                                             InstCombiner::BuilderTy &Builder) {
+/// Emit a copy of \p I for one arm of \p SI, with the select replaced by that
+/// arm's value; binary operators and intrinsics take both operands from
+/// getInstSelectArmOperands. Returns null, without emitting anything, if the
+/// operation cannot be rebuilt for that arm.
+static Value *foldOperationIntoSelectOperand(Instruction &I, SelectInst *SI,
+                                             InstCombiner::BuilderTy &Builder,
+                                             bool IsTrueArm) {
+  Value *SO = IsTrueArm ? SI->getTrueValue() : SI->getFalseValue();
   if (auto *Cast = dyn_cast<CastInst>(&I))
     return Builder.CreateCast(Cast->getOpcode(), SO, I.getType());
 
+  SmallVector<Value *> Ops = getInstSelectArmOperands(I, SI, IsTrueArm);
+
   if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
-    assert(canConstantFoldCallTo(II, cast<Function>(II->getCalledOperand())) &&
-           "Expected constant-foldable intrinsic");
+    if (!canConstantFoldCallTo(II, cast<Function>(II->getCalledOperand())))
+      return nullptr;
     Intrinsic::ID IID = II->getIntrinsicID();
     if (II->arg_size() == 1)
       return Builder.CreateUnaryIntrinsic(IID, SO);
 
     // This works for real binary ops like min/max (where we always expect the
     // constant operand to be canonicalized as op1) and unary ops with a bonus
     // constant argument like ctlz/cttz.
-    // TODO: Handle non-commutative binary intrinsics as below for binops.
     assert(II->arg_size() == 2 && "Expected binary intrinsic");
-    assert(isa<Constant>(II->getArgOperand(1)) && "Expected constant operand");
-    return Builder.CreateBinaryIntrinsic(IID, SO, II->getArgOperand(1));
+    if (!(isa<Constant>(Ops[0]) || isa<Constant>(Ops[1])))
+      return nullptr;
+    return Builder.CreateBinaryIntrinsic(IID, Ops[0], Ops[1]);
   }
 
   if (auto *EI = dyn_cast<ExtractElementInst>(&I))
     return Builder.CreateExtractElement(SO, EI->getIndexOperand());
 
-  assert(I.isBinaryOp() && "Unexpected opcode for select folding");
-
-  // Figure out if the constant is the left or the right argument.
-  bool ConstIsRHS = isa<Constant>(I.getOperand(1));
-  Constant *ConstOperand = cast<Constant>(I.getOperand(ConstIsRHS));
+  assert(I.isBinaryOp() && Ops.size() == 2 &&
+         "Unexpected opcode for select folding");
 
-  Value *Op0 = SO, *Op1 = ConstOperand;
-  if (!ConstIsRHS)
-    std::swap(Op0, Op1);
+  // Only fold when one of the rewritten operands is a constant.
+  if (!(isa<Constant>(Ops[0]) || isa<Constant>(Ops[1])))
+    return nullptr;
 
-  Value *NewBO = Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), Op0,
-                                     Op1, SO->getName() + ".op");
+  Value *NewBO = Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(),
+                                     Ops[0], Ops[1], SO->getName() + ".op");
   if (auto *NewBOI = dyn_cast<Instruction>(NewBO))
     NewBOI->copyIRFlags(&I);
   return NewBO;
@@ -1155,16 +1189,23 @@
   }
 
   // Make sure that one of the select arms constant folds successfully.
-  Value *NewTV = constantFoldOperationIntoSelectOperand(Op, SI, TV);
-  Value *NewFV = constantFoldOperationIntoSelectOperand(Op, SI, FV);
+  Value *NewTV =
+      constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm=*/true);
+  Value *NewFV =
+      constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm=*/false);
   if (!NewTV && !NewFV)
     return nullptr;
 
   // Create an instruction for the arm that did not fold.
   if (!NewTV)
-    NewTV = foldOperationIntoSelectOperand(Op, TV, Builder);
+    NewTV = foldOperationIntoSelectOperand(Op, SI, Builder, /*IsTrueArm=*/true);
   if (!NewFV)
-    NewFV = foldOperationIntoSelectOperand(Op, FV, Builder);
+    NewFV =
+        foldOperationIntoSelectOperand(Op, SI, Builder, /*IsTrueArm=*/false);
+  // foldOperationIntoSelectOperand emits nothing when it fails, so bailing out
+  // here leaves the IR untouched.
+  if (!NewTV || !NewFV)
+    return nullptr;
   return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI);
 }
 