diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h --- a/llvm/include/llvm/Analysis/InstructionSimplify.h +++ b/llvm/include/llvm/Analysis/InstructionSimplify.h @@ -145,8 +145,7 @@ // Please use the SimplifyQuery versions in new code. /// Given operand for an FNeg, fold the result or return null. -Value *SimplifyFNegInst(Value *Op, FastMathFlags FMF, - const SimplifyQuery &Q); +Value *SimplifyFNegInst(Value *Op, FastMathFlags FMF, const SimplifyQuery &Q); /// Given operands for an Add, fold the result or return null. Value *SimplifyAddInst(Value *LHS, Value *RHS, bool isNSW, bool isNUW, @@ -297,8 +296,8 @@ /// Given operands for a BinaryOperator, fold the result or return null. /// Try to use FastMathFlags when folding the result. -Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - FastMathFlags FMF, const SimplifyQuery &Q); +Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, FastMathFlags FMF, + const SimplifyQuery &Q); /// Given a callsite, fold the result or return null. Value *SimplifyCall(CallBase *Call, const SimplifyQuery &Q); @@ -312,6 +311,13 @@ Value *SimplifyInstruction(Instruction *I, const SimplifyQuery &Q, OptimizationRemarkEmitter *ORE = nullptr); +/// Like \p SimplifyInstruction but the operands of \p I are replaced with +/// \p NewOps. Returns a simplified value, or null if none was found. +Value * +SimplifyInstructionWithOperands(Instruction *I, ArrayRef NewOps, + const SimplifyQuery &Q, + OptimizationRemarkEmitter *ORE = nullptr); + /// See if V simplifies when its operand Op is replaced with RepOp. If not, /// return null. /// AllowRefinement specifies whether the simplification can be a refinement @@ -345,4 +351,3 @@ } // end namespace llvm #endif - diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -17,6 +17,8 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/InstructionSimplify.h" + +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" @@ -4567,7 +4569,8 @@ } /// See if we can fold the given phi. If not, returns null. -static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) { +static Value *SimplifyPHINode(PHINode *PN, ArrayRef IncomingValues, + const SimplifyQuery &Q) { // WARNING: no matter how worthwhile it may seem, we can not perform PHI CSE // here, because the PHI we may succeed simplifying to was not // def-reachable from the original PHI! @@ -4576,7 +4579,7 @@ // with the common value. Value *CommonValue = nullptr; bool HasUndefInput = false; - for (Value *Incoming : PN->incoming_values()) { + for (Value *Incoming : IncomingValues) { // If the incoming value is the phi node itself, it can safely be skipped. if (Incoming == PN) continue; if (Q.isUndefValue(Incoming)) { @@ -6040,21 +6043,19 @@ return NewOp; } -static Value *SimplifyLoadInst(LoadInst *LI, const SimplifyQuery &Q) { +static Value *SimplifyLoadInst(LoadInst *LI, Value *PtrOp, + const SimplifyQuery &Q) { if (LI->isVolatile()) return nullptr; - if (auto *C = ConstantFoldInstruction(LI, Q.DL)) - return C; + // Try to make the load operand a constant, specifically handle + // invariant.group intrinsics. + auto *PtrOpC = dyn_cast(PtrOp); + if (!PtrOpC) + PtrOpC = ConstructLoadOperandConstant(PtrOp); - // The following only catches more cases than ConstantFoldInstruction() if the - // load operand wasn't a constant. Specifically, invariant.group intrinsics. - if (isa(LI->getPointerOperand())) - return nullptr; - - if (auto *C = dyn_cast_or_null( - ConstructLoadOperandConstant(LI->getPointerOperand()))) - return ConstantFoldLoadFromConstPtr(C, LI->getType(), Q.DL); + if (PtrOpC) + return ConstantFoldLoadFromConstPtr(PtrOpC, LI->getType(), Q.DL); return nullptr; } @@ -6062,161 +6063,149 @@ /// See if we can compute a simplified version of this instruction. /// If not, this returns null. -Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, - OptimizationRemarkEmitter *ORE) { +static Value *simplifyInstructionWithOperands(Instruction *I, + ArrayRef NewOps, + const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I); - Value *Result; + Value *Result = nullptr; switch (I->getOpcode()) { default: - Result = ConstantFoldInstruction(I, Q.DL, Q.TLI); + if (llvm::all_of(NewOps, [](Value *V) { return isa(V); })) { + SmallVector NewConstOps(NewOps.size()); + transform(NewOps, NewConstOps.begin(), + [](Value *V) { return cast(V); }); + Result = ConstantFoldInstOperands(I, NewConstOps, Q.DL, Q.TLI); + } break; case Instruction::FNeg: - Result = SimplifyFNegInst(I->getOperand(0), I->getFastMathFlags(), Q); + Result = SimplifyFNegInst(NewOps[0], I->getFastMathFlags(), Q); break; case Instruction::FAdd: - Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFAddInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Add: - Result = - SimplifyAddInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.hasNoSignedWrap(cast(I)), - Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); + Result = SimplifyAddInst( + NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast(I)), + Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); break; case Instruction::FSub: - Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFSubInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Sub: - Result = - SimplifySubInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.hasNoSignedWrap(cast(I)), - Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); + Result = SimplifySubInst( + NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast(I)), + Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); break; case Instruction::FMul: - Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFMulInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Mul: - Result = SimplifyMulInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyMulInst(NewOps[0], NewOps[1], Q); break; case Instruction::SDiv: - Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifySDivInst(NewOps[0], NewOps[1], Q); break; case Instruction::UDiv: - Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyUDivInst(NewOps[0], NewOps[1], Q); break; case Instruction::FDiv: - Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFDivInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::SRem: - Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifySRemInst(NewOps[0], NewOps[1], Q); break; case Instruction::URem: - Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyURemInst(NewOps[0], NewOps[1], Q); break; case Instruction::FRem: - Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFRemInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Shl: - Result = - SimplifyShlInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.hasNoSignedWrap(cast(I)), - Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); + Result = SimplifyShlInst( + NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast(I)), + Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); break; case Instruction::LShr: - Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1), + Result = SimplifyLShrInst(NewOps[0], NewOps[1], Q.IIQ.isExact(cast(I)), Q); break; case Instruction::AShr: - Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1), + Result = SimplifyAShrInst(NewOps[0], NewOps[1], Q.IIQ.isExact(cast(I)), Q); break; case Instruction::And: - Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyAndInst(NewOps[0], NewOps[1], Q); break; case Instruction::Or: - Result = SimplifyOrInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyOrInst(NewOps[0], NewOps[1], Q); break; case Instruction::Xor: - Result = SimplifyXorInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyXorInst(NewOps[0], NewOps[1], Q); break; case Instruction::ICmp: - Result = SimplifyICmpInst(cast(I)->getPredicate(), - I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyICmpInst(cast(I)->getPredicate(), NewOps[0], + NewOps[1], Q); break; case Instruction::FCmp: - Result = - SimplifyFCmpInst(cast(I)->getPredicate(), I->getOperand(0), - I->getOperand(1), I->getFastMathFlags(), Q); + Result = SimplifyFCmpInst(cast(I)->getPredicate(), NewOps[0], + NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Select: - Result = SimplifySelectInst(I->getOperand(0), I->getOperand(1), - I->getOperand(2), Q); + Result = SimplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q); break; case Instruction::GetElementPtr: { - SmallVector Ops(I->operands()); Result = SimplifyGEPInst(cast(I)->getSourceElementType(), - Ops, Q); + NewOps, Q); break; } case Instruction::InsertValue: { InsertValueInst *IV = cast(I); - Result = SimplifyInsertValueInst(IV->getAggregateOperand(), - IV->getInsertedValueOperand(), - IV->getIndices(), Q); + Result = SimplifyInsertValueInst(NewOps[0], NewOps[1], IV->getIndices(), Q); break; } case Instruction::InsertElement: { - auto *IE = cast(I); - Result = SimplifyInsertElementInst(IE->getOperand(0), IE->getOperand(1), - IE->getOperand(2), Q); + Result = SimplifyInsertElementInst(NewOps[0], NewOps[1], NewOps[2], Q); break; } case Instruction::ExtractValue: { auto *EVI = cast(I); - Result = SimplifyExtractValueInst(EVI->getAggregateOperand(), - EVI->getIndices(), Q); + Result = SimplifyExtractValueInst(NewOps[0], EVI->getIndices(), Q); break; } case Instruction::ExtractElement: { - auto *EEI = cast(I); - Result = SimplifyExtractElementInst(EEI->getVectorOperand(), - EEI->getIndexOperand(), Q); + Result = SimplifyExtractElementInst(NewOps[0], NewOps[1], Q); break; } case Instruction::ShuffleVector: { auto *SVI = cast(I); - Result = - SimplifyShuffleVectorInst(SVI->getOperand(0), SVI->getOperand(1), - SVI->getShuffleMask(), SVI->getType(), Q); + Result = SimplifyShuffleVectorInst( + NewOps[0], NewOps[1], SVI->getShuffleMask(), SVI->getType(), Q); break; } case Instruction::PHI: - Result = SimplifyPHINode(cast(I), Q); + Result = SimplifyPHINode(cast(I), NewOps, Q); break; case Instruction::Call: { + // TODO: Use NewOps Result = SimplifyCall(cast(I), Q); break; } case Instruction::Freeze: - Result = SimplifyFreezeInst(I->getOperand(0), Q); + Result = llvm::SimplifyFreezeInst(NewOps[0], Q); break; #define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc: #include "llvm/IR/Instruction.def" #undef HANDLE_CAST_INST - Result = - SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), Q); + Result = SimplifyCastInst(I->getOpcode(), NewOps[0], I->getType(), Q); break; case Instruction::Alloca: // No simplifications for Alloca and it can't be constant folded. Result = nullptr; break; case Instruction::Load: - Result = SimplifyLoadInst(cast(I), Q); + Result = SimplifyLoadInst(cast(I), NewOps[0], Q); break; } @@ -6226,6 +6215,19 @@ return Result == I ? UndefValue::get(I->getType()) : Result; } +Value *llvm::SimplifyInstructionWithOperands(Instruction *I, + ArrayRef NewOps, + const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { + return ::simplifyInstructionWithOperands(I, NewOps, SQ, ORE); +} + +Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { + SmallVector Ops(I->operands()); + return ::simplifyInstructionWithOperands(I, Ops, SQ, ORE); +} + /// Implementation of recursive simplification through an instruction's /// uses. ///