[PatternMatch] m_ConstantExpr: also match constants that contain a ConstantExpr

Summary of the hunks below (as far as this copy shows them):
  * PatternMatch.h: m_ConstantExpr() now returns a new constantexpr_match
    functor. It matches a Constant that is itself a constant expression
    (an isa<> check whose template argument is not legible here; presumably
    ConstantExpr -- confirm) OR for which containsConstantExpression() is
    true. m_ImmConstant() is built from m_Unless(m_ConstantExpr()), so it
    now also rejects such constants; its spelled-out return types change
    accordingly (their template arguments are not legible in this copy).
  * InstCombineCalls.cpp: the shift-amount canonicalization on argument 2
    drops its separate `!ShAmtC->containsConstantExpression()` check, since
    m_ImmConstant(ShAmtC) now rejects that case by itself.
  * sub-of-negatible.ll: adds the PR56601 regression test (per its comment,
    an infinite loop caused by not matching a vector constant with constant
    expression elements as a constant expression).
  * unittests/IR/PatternMatch.cpp: the expectation for V (built with
    ConstantExpr::getInsertElement) flips from EXPECT_FALSE to EXPECT_TRUE.

NOTE(review): this copy of the patch looks damaged and will not apply as-is:
  * all line breaks -- and with them the ' '/'+'/'-' line prefixes and the
    original indentation -- are collapsed onto two physical lines;
  * text inside angle brackets appears stripped, e.g. `template bool
    match(ITy *V)`, `dyn_cast(V)`, `isa(C)`, `class_match()`, and the
    unbalanced `match_combine_and, ... match_unless>>`;
  * the IR test's vector-constant operands are missing, e.g.
    `%m1 = mul nsw <1 x i64> %x, ` ends with no second operand, and the
    corresponding CHECK-NEXT lines are truncated the same way.
  Regenerate the patch from its upstream source rather than hand-repairing
  it; the missing constants cannot be recovered from this text.

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -153,10 +153,16 @@ return class_match(); } -/// Match an arbitrary ConstantExpr and ignore it. -inline class_match m_ConstantExpr() { - return class_match(); -} +struct constantexpr_match { + template bool match(ITy *V) { + auto *C = dyn_cast(V); + return C && (isa(C) || C->containsConstantExpression()); + } +}; + +/// Match a constant expression or a constant that contains a constant +/// expression. +inline constantexpr_match m_ConstantExpr() { return constantexpr_match(); } /// Match an arbitrary basic block value and ignore it. inline class_match m_BasicBlock() { @@ -741,14 +747,14 @@ /// Match an arbitrary immediate Constant and ignore it. inline match_combine_and, - match_unless>> + match_unless> m_ImmConstant() { return m_CombineAnd(m_Constant(), m_Unless(m_ConstantExpr())); } /// Match an immediate Constant, capturing the value if we match. inline match_combine_and, - match_unless>> + match_unless> m_ImmConstant(Constant *&C) { return m_CombineAnd(m_Constant(C), m_Unless(m_ConstantExpr())); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -1539,8 +1539,7 @@ Type *Ty = II->getType(); unsigned BitWidth = Ty->getScalarSizeInBits(); Constant *ShAmtC; - if (match(II->getArgOperand(2), m_ImmConstant(ShAmtC)) && - !ShAmtC->containsConstantExpression()) { + if (match(II->getArgOperand(2), m_ImmConstant(ShAmtC))) { // Canonicalize a shift amount constant operand to modulo the bit-width. 
Constant *WidthC = ConstantInt::get(Ty, BitWidth); Constant *ModuloC = diff --git a/llvm/test/Transforms/InstCombine/sub-of-negatible.ll b/llvm/test/Transforms/InstCombine/sub-of-negatible.ll --- a/llvm/test/Transforms/InstCombine/sub-of-negatible.ll +++ b/llvm/test/Transforms/InstCombine/sub-of-negatible.ll @@ -1426,5 +1426,28 @@ br label %if.end } +; This would infinite loop because we failed to match a +; vector constant with constant expression elements as +; a constant expression. + +@g = external hidden global [1 x [1 x double]] + +define <1 x i64> @PR56601(<1 x i64> %x, <1 x i64> %y) { +; CHECK-LABEL: @PR56601( +; CHECK-NEXT: [[M1:%.*]] = mul nsw <1 x i64> [[X:%.*]], +; CHECK-NEXT: [[M2:%.*]] = mul nsw <1 x i64> [[Y:%.*]], +; CHECK-NEXT: [[A1:%.*]] = add <1 x i64> [[M1]], +; CHECK-NEXT: [[A2:%.*]] = add <1 x i64> [[M2]], +; CHECK-NEXT: [[R:%.*]] = sub <1 x i64> [[A1]], [[A2]] +; CHECK-NEXT: ret <1 x i64> [[R]] +; + %m1 = mul nsw <1 x i64> %x, + %m2 = mul nsw <1 x i64> %y, + %a1 = add <1 x i64> %m1, + %a2 = add <1 x i64> %m2, + %r = sub <1 x i64> %a1, %a2 + ret <1 x i64> %r +} + ; CHECK: !0 = !{!"branch_weights", i32 40, i32 1} !0 = !{!"branch_weights", i32 40, i32 1} diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp --- a/llvm/unittests/IR/PatternMatch.cpp +++ b/llvm/unittests/IR/PatternMatch.cpp @@ -1798,8 +1798,10 @@ PoisonValue *P = PoisonValue::get(VecTy); Constant *V = ConstantExpr::getInsertElement(P, S, IRB.getInt32(0)); + // The match succeeds on a constant that is a constant expression itself + // or a constant that contains a constant expression. EXPECT_TRUE(match(S, m_ConstantExpr())); - EXPECT_FALSE(match(V, m_ConstantExpr())); + EXPECT_TRUE(match(V, m_ConstantExpr())); } } // anonymous namespace.