diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -657,23 +657,62 @@
 /// Match a specified Value*.
 struct specificval_ty {
   const Value *Val;
+  bool AllowFrozenVal;
 
-  specificval_ty(const Value *V) : Val(V) {}
+  specificval_ty(const Value *V, bool AllowFrozenVal = false)
+      : Val(V), AllowFrozenVal(AllowFrozenVal) {}
 
-  template <typename ITy> bool match(ITy *V) { return V == Val; }
+  template <typename ITy> bool match(ITy *V) {
+    if (V == Val)
+      return true;
+    if (!AllowFrozenVal)
+      return false;
+    // Val == freeze(V)
+    if (const auto *FI = dyn_cast<FreezeInst>(Val))
+      if (FI->getOperand(0) == V)
+        return true;
+    // V == freeze(Val)
+    if (const auto *FI = dyn_cast<FreezeInst>(V))
+      return FI->getOperand(0) == Val;
+    // Two distinct freezes of the same operand are deliberately not matched:
+    // each freeze may pick a different value for an undef/poison operand.
+    return false;
+  }
 };
 
 /// Match if we have a specific specified value.
 inline specificval_ty m_Specific(const Value *V) { return V; }
 
+/// Like m_Specific(), but also match when one side is a freeze of the other:
+/// V matches if V == Val, Val == freeze(V), or V == freeze(Val). Two distinct
+/// freeze instructions of the same operand are NOT considered equal.
+inline specificval_ty m_SpecificMaybeFrozen(const Value *V) {
+  return {V, true};
+}
+
 /// Stores a reference to the Value *, not the Value * itself,
 /// thus can be used in commutative matchers.
 template <typename Class> struct deferredval_ty {
   Class *const &Val;
+  bool AllowFrozenVal;
 
-  deferredval_ty(Class *const &V) : Val(V) {}
+  deferredval_ty(Class *const &V, bool AllowFrozenVal = false)
+      : Val(V), AllowFrozenVal(AllowFrozenVal) {}
 
-  template <typename ITy> bool match(ITy *const V) { return V == Val; }
+  template <typename ITy> bool match(ITy *const V) {
+    if (V == Val)
+      return true;
+    if (!AllowFrozenVal)
+      return false;
+    // Val == freeze(V)
+    if (const auto *FI = dyn_cast<FreezeInst>(Val))
+      if (FI->getOperand(0) == V)
+        return true;
+    // V == freeze(Val); distinct freezes of one operand are not matched.
+    if (const auto *FI = dyn_cast<FreezeInst>(V))
+      return FI->getOperand(0) == Val;
+    return false;
+  }
 };
 
 /// A commutative-friendly version of m_Specific().
@@ -682,6 +721,15 @@
   return V;
 }
 
+/// A commutative-friendly version of m_SpecificMaybeFrozen().
+inline deferredval_ty<Value> m_DeferredMaybeFrozen(Value *const &V) {
+  return {V, true};
+}
+inline deferredval_ty<const Value>
+m_DeferredMaybeFrozen(const Value *const &V) {
+  return {V, true};
+}
+
 /// Match a specified floating point value or vector of all elements of
 /// that value.
 struct specific_fpval {
@@ -1339,6 +1387,15 @@
   return OneOps_match<OpTy, Instruction::Freeze>(Op);
 }
 
+/// Matches FreezeInst(Op) or Op.
+/// Note: do not wrap m_SpecificMaybeFrozen() in this matcher; that would also
+/// accept a second, distinct freeze of the same value.
+template <typename OpTy>
+inline match_combine_or<OneOps_match<OpTy, Instruction::Freeze>, OpTy>
+m_FreezeOrSelf(const OpTy &Op) {
+  return m_CombineOr(m_Freeze(Op), Op);
+}
+
 /// Matches InsertElementInst.
 template <typename Val_t, typename Elt_t, typename Idx_t>
 inline ThreeOps_match<Val_t, Elt_t, Idx_t, Instruction::InsertElement>
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -634,9 +634,10 @@
   // X + (Y - X) -> Y
   // (Y - X) + X -> Y
   // Eg: X + -X -> 0
+  // This also holds when exactly one X is replaced with freeze(X).
   Value *Y = nullptr;
-  if (match(Op1, m_Sub(m_Value(Y), m_Specific(Op0))) ||
-      match(Op0, m_Sub(m_Value(Y), m_Specific(Op1))))
+  if (match(Op1, m_Sub(m_Value(Y), m_SpecificMaybeFrozen(Op0))) ||
+      match(Op0, m_Sub(m_Value(Y), m_SpecificMaybeFrozen(Op1))))
     return Y;
 
   // X + ~X -> -1 since ~X = -X-1
@@ -748,7 +749,8 @@
     return Op0;
 
   // X - X -> 0
-  if (Op0 == Op1)
+  // This also holds when exactly one X is replaced with freeze(X).
+  if (match(Op0, m_SpecificMaybeFrozen(Op1)))
     return Constant::getNullValue(Op0->getType());
 
   // Is this a negation?
@@ -966,7 +968,8 @@
 
   // X / X -> 1
   // X % X -> 0
-  if (Op0 == Op1)
+  // This also holds when exactly one X is replaced with freeze(X).
+  if (match(Op0, m_SpecificMaybeFrozen(Op1)))
     return IsDiv ? ConstantInt::get(Ty, 1) : Constant::getNullValue(Ty);
 
   // X / 1 -> X
@@ -1056,8 +1059,9 @@
   bool IsSigned = Opcode == Instruction::SDiv;
 
   // (X * Y) / Y -> X if the multiplication does not overflow.
+  // This also holds when exactly one Y is replaced with freeze(Y).
   Value *X;
-  if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) {
+  if (match(Op0, m_c_Mul(m_Value(X), m_SpecificMaybeFrozen(Op1)))) {
     auto *Mul = cast<OverflowingBinaryOperator>(Op0);
     // If the Mul does not overflow, then we are good to go.
     if ((IsSigned && Q.IIQ.hasNoSignedWrap(Mul)) ||
@@ -1070,8 +1074,11 @@
   }
 
   // (X rem Y) / Y -> 0
-  if ((IsSigned && match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) ||
-      (!IsSigned && match(Op0, m_URem(m_Value(), m_Specific(Op1)))))
+  // This also holds when exactly one Y is replaced with freeze(Y).
+  if ((IsSigned &&
+       match(Op0, m_SRem(m_Value(), m_SpecificMaybeFrozen(Op1)))) ||
+      (!IsSigned &&
+       match(Op0, m_URem(m_Value(), m_SpecificMaybeFrozen(Op1)))))
     return Constant::getNullValue(Op0->getType());
 
   // (X /u C1) /u C2 -> 0 if C1 * C2 overflow
diff --git a/llvm/test/Transforms/InstSimplify/add.ll b/llvm/test/Transforms/InstSimplify/add.ll
--- a/llvm/test/Transforms/InstSimplify/add.ll
+++ b/llvm/test/Transforms/InstSimplify/add.ll
@@ -10,6 +10,52 @@
   ret i32 %Q
 }
 
+define void @common_sub_operand_fr(i8 %x, i8 %y) {
+; CHECK-LABEL: @common_sub_operand_fr(
+; CHECK-NEXT:    call void @f(i8 [[Y:%.*]])
+; CHECK-NEXT:    call void @f(i8 [[Y]])
+; CHECK-NEXT:    ret void
+;
+  %x.fr = freeze i8 %x
+  %op1 = sub i8 %y, %x.fr
+  %res = add i8 %x, %op1
+  call void @f(i8 %res)
+  %res2 = add i8 %op1, %x
+  call void @f(i8 %res2)
+  ret void
+}
+
+define void @common_sub_operand_fr2(i8 %x, i8 %y) {
+; CHECK-LABEL: @common_sub_operand_fr2(
+; CHECK-NEXT:    call void @f(i8 [[Y:%.*]])
+; CHECK-NEXT:    call void @f(i8 [[Y]])
+; CHECK-NEXT:    ret void
+;
+  %x.fr = freeze i8 %x
+  %op1 = sub i8 %y, %x
+  %res = add i8 %x.fr, %op1
+  call void @f(i8 %res)
+  %res2 = add i8 %op1, %x.fr
+  call void @f(i8 %res2)
+  ret void
+}
+
+; Two distinct freezes of %x may yield different values: must not fold to %y.
+define i8 @common_sub_operand_two_freezes(i8 %x, i8 %y) {
+; CHECK-LABEL: @common_sub_operand_two_freezes(
+; CHECK-NEXT:    [[X_FR1:%.*]] = freeze i8 [[X:%.*]]
+; CHECK-NEXT:    [[X_FR2:%.*]] = freeze i8 [[X]]
+; CHECK-NEXT:    [[OP1:%.*]] = sub i8 [[Y:%.*]], [[X_FR2]]
+; CHECK-NEXT:    [[RES:%.*]] = add i8 [[X_FR1]], [[OP1]]
+; CHECK-NEXT:    ret i8 [[RES]]
+;
+  %x.fr1 = freeze i8 %x
+  %x.fr2 = freeze i8 %x
+  %op1 = sub i8 %y, %x.fr2
+  %res = add i8 %x.fr1, %op1
+  ret i8 %res
+}
+
 define i32 @negated_operand(i32 %x) {
 ; CHECK-LABEL: @negated_operand(
 ; CHECK-NEXT:    ret i32 0
@@ -48,4 +94,4 @@
   ret <2 x i8> %r
 }
 
-
+declare void @f(i8)
diff --git a/llvm/test/Transforms/InstSimplify/div.ll b/llvm/test/Transforms/InstSimplify/div.ll
--- a/llvm/test/Transforms/InstSimplify/div.ll
+++ b/llvm/test/Transforms/InstSimplify/div.ll
@@ -193,4 +193,127 @@
   ret i32 %urem
 }
 
+define void @div_self_fr(i8 %x) {
+; CHECK-LABEL: @div_self_fr(
+; CHECK-NEXT:    call void @f(i8 1)
+; CHECK-NEXT:    call void @f(i8 1)
+; CHECK-NEXT:    call void @f(i8 1)
+; CHECK-NEXT:    call void @f(i8 1)
+; CHECK-NEXT:    call void @f(i8 0)
+; CHECK-NEXT:    call void @f(i8 0)
+; CHECK-NEXT:    call void @f(i8 0)
+; CHECK-NEXT:    call void @f(i8 0)
+; CHECK-NEXT:    ret void
+;
+  %x.fr = freeze i8 %x
+
+  %res1 = sdiv i8 %x, %x.fr
+  call void @f(i8 %res1)
+  %res2 = sdiv i8 %x.fr, %x
+  call void @f(i8 %res2)
+  %res3 = udiv i8 %x, %x.fr
+  call void @f(i8 %res3)
+  %res4 = udiv i8 %x.fr, %x
+  call void @f(i8 %res4)
+
+  %res5 = srem i8 %x, %x.fr
+  call void @f(i8 %res5)
+  %res6 = srem i8 %x.fr, %x
+  call void @f(i8 %res6)
+  %res7 = urem i8 %x, %x.fr
+  call void @f(i8 %res7)
+  %res8 = urem i8 %x.fr, %x
+  call void @f(i8 %res8)
+
+  ret void
+}
+
+define void @mul_div_fr(i8 %x, i8 %y) {
+; CHECK-LABEL: @mul_div_fr(
+; CHECK-NEXT:    call void @f(i8 [[X:%.*]])
+; CHECK-NEXT:    call void @f(i8 [[X]])
+; CHECK-NEXT:    call void @f(i8 [[X]])
+; CHECK-NEXT:    call void @f(i8 [[X]])
+; CHECK-NEXT:    ret void
+;
+  %y.fr = freeze i8 %y
+  %op0 = mul nuw i8 %x, %y.fr
+  %res = udiv i8 %op0, %y
+  call void @f(i8 %res)
+
+  %op0_2 = mul nuw i8 %x, %y
+  %res2 = udiv i8 %op0_2, %y.fr
+  call void @f(i8 %res2)
+
+  %op0_3 = mul nsw i8 %x, %y.fr
+  %res3 = sdiv i8 %op0_3, %y
+  call void @f(i8 %res3)
+
+  %op0_4 = mul nsw i8 %x, %y
+  %res4 = sdiv i8 %op0_4, %y.fr
+  call void @f(i8 %res4)
+
+  ret void
+}
+
+define void @rem_div_fr(i8 %x, i8 %y) {
+; CHECK-LABEL: @rem_div_fr(
+; CHECK-NEXT:    call void @f(i8 0)
+; CHECK-NEXT:    call void @f(i8 0)
+; CHECK-NEXT:    call void @f(i8 0)
+; CHECK-NEXT:    call void @f(i8 0)
+; CHECK-NEXT:    ret void
+;
+  %y.fr = freeze i8 %y
+  %op0 = urem i8 %x, %y.fr
+  %res = udiv i8 %op0, %y
+  call void @f(i8 %res)
+
+  %op0_2 = urem i8 %x, %y
+  %res2 = udiv i8 %op0_2, %y.fr
+  call void @f(i8 %res2)
+
+  %op0_3 = srem i8 %x, %y.fr
+  %res3 = sdiv i8 %op0_3, %y
+  call void @f(i8 %res3)
+
+  %op0_4 = srem i8 %x, %y
+  %res4 = sdiv i8 %op0_4, %y.fr
+  call void @f(i8 %res4)
+
+  ret void
+}
+
+; Two distinct freezes of %x may yield different values: must not fold to 1.
+define i8 @div_two_freezes(i8 %x) {
+; CHECK-LABEL: @div_two_freezes(
+; CHECK-NEXT:    [[X_FR1:%.*]] = freeze i8 [[X:%.*]]
+; CHECK-NEXT:    [[X_FR2:%.*]] = freeze i8 [[X]]
+; CHECK-NEXT:    [[RES:%.*]] = udiv i8 [[X_FR1]], [[X_FR2]]
+; CHECK-NEXT:    ret i8 [[RES]]
+;
+  %x.fr1 = freeze i8 %x
+  %x.fr2 = freeze i8 %x
+  %res = udiv i8 %x.fr1, %x.fr2
+  ret i8 %res
+}
+
+; Two distinct freezes of %y may yield different values: must not fold to %x.
+define i8 @mul_div_two_freezes(i8 %x, i8 %y) {
+; CHECK-LABEL: @mul_div_two_freezes(
+; CHECK-NEXT:    [[Y_FR1:%.*]] = freeze i8 [[Y:%.*]]
+; CHECK-NEXT:    [[Y_FR2:%.*]] = freeze i8 [[Y]]
+; CHECK-NEXT:    [[OP0:%.*]] = mul nuw i8 [[X:%.*]], [[Y_FR1]]
+; CHECK-NEXT:    [[RES:%.*]] = udiv i8 [[OP0]], [[Y_FR2]]
+; CHECK-NEXT:    ret i8 [[RES]]
+;
+  %y.fr1 = freeze i8 %y
+  %y.fr2 = freeze i8 %y
+  %op0 = mul nuw i8 %x, %y.fr1
+  %res = udiv i8 %op0, %y.fr2
+  ret i8 %res
+}
+
+declare void @f(i8)
+
 !0 = !{i32 0, i32 3}
diff --git a/llvm/test/Transforms/InstSimplify/rem.ll b/llvm/test/Transforms/InstSimplify/rem.ll
--- a/llvm/test/Transforms/InstSimplify/rem.ll
+++ b/llvm/test/Transforms/InstSimplify/rem.ll
@@ -325,3 +325,26 @@
   ret <2 x i32> %r
 }
 
+define void @rem_self_fr(i32 %x) {
+; CHECK-LABEL: @rem_self_fr(
+; CHECK-NEXT:    call void @f(i32 0)
+; CHECK-NEXT:    call void @f(i32 0)
+; CHECK-NEXT:    call void @f(i32 0)
+; CHECK-NEXT:    call void @f(i32 0)
+; CHECK-NEXT:    ret void
+;
+  %x.fr = freeze i32 %x
+
+  %res1 = srem i32 %x, %x.fr
+  call void @f(i32 %res1)
+  %res2 = srem i32 %x.fr, %x
+  call void @f(i32 %res2)
+  %res3 = urem i32 %x, %x.fr
+  call void @f(i32 %res3)
+  %res4 = urem i32 %x.fr, %x
+  call void @f(i32 %res4)
+
+  ret void
+}
+
+declare void @f(i32)
diff --git a/llvm/test/Transforms/InstSimplify/sub.ll b/llvm/test/Transforms/InstSimplify/sub.ll
--- a/llvm/test/Transforms/InstSimplify/sub.ll
+++ b/llvm/test/Transforms/InstSimplify/sub.ll
@@ -9,6 +9,34 @@
   ret i32 %B
 }
 
+define void @sub_self_fr(i32 %x) {
+; CHECK-LABEL: @sub_self_fr(
+; CHECK-NEXT:    call void @f(i32 0)
+; CHECK-NEXT:    call void @f(i32 0)
+; CHECK-NEXT:    ret void
+;
+  %x.fr = freeze i32 %x
+  %res = sub i32 %x.fr, %x
+  call void @f(i32 %res)
+  %res2 = sub i32 %x, %x.fr
+  call void @f(i32 %res2)
+  ret void
+}
+
+; Two distinct freezes of %x may yield different values: must not fold to 0.
+define i32 @sub_two_freezes(i32 %x) {
+; CHECK-LABEL: @sub_two_freezes(
+; CHECK-NEXT:    [[X_FR1:%.*]] = freeze i32 [[X:%.*]]
+; CHECK-NEXT:    [[X_FR2:%.*]] = freeze i32 [[X]]
+; CHECK-NEXT:    [[RES:%.*]] = sub i32 [[X_FR1]], [[X_FR2]]
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %x.fr1 = freeze i32 %x
+  %x.fr2 = freeze i32 %x
+  %res = sub i32 %x.fr1, %x.fr2
+  ret i32 %res
+}
+
 define <2 x i32> @sub_self_vec(<2 x i32> %A) {
 ; CHECK-LABEL: @sub_self_vec(
 ; CHECK-NEXT:    ret <2 x i32> zeroinitializer
@@ -51,3 +79,4 @@
   ret <2 x i32> %C
 }
 
+declare void @f(i32)