Index: include/llvm/IR/PatternMatch.h
===================================================================
--- include/llvm/IR/PatternMatch.h
+++ include/llvm/IR/PatternMatch.h
@@ -180,6 +180,44 @@
 /// specified pointer to the contained APFloat.
 inline apfloat_match m_APFloat(const APFloat *&Res) { return Res; }
 
+/// Matcher that accepts a ConstantInt, or a vector-typed Constant whose every
+/// element is a ConstantInt, and binds the matched value to a Constant
+/// pointer. On a failed match the bound pointer is left unmodified.
+struct constantint_or_constantint_vec_match {
+  Constant *&Res;
+  constantint_or_constantint_vec_match(Constant *&R) : Res(R) {}
+  template <typename ITy> bool match(ITy *V) {
+    if (auto *CI = dyn_cast<ConstantInt>(V)) {
+      Res = CI;
+      return true;
+    }
+
+    if (V->getType()->isVectorTy()) {
+      if (auto *C = dyn_cast<Constant>(V)) {
+        unsigned NumElts = V->getType()->getVectorNumElements();
+        assert(NumElts != 0 && "Constant vector with no elements?");
+        for (unsigned i = 0; i != NumElts; ++i) {
+          // Reject the vector if any lane is missing or is not a ConstantInt.
+          Constant *Elt = C->getAggregateElement(i);
+          if (!Elt || !isa<ConstantInt>(Elt))
+            return false;
+        }
+
+        Res = C;
+        return true;
+      }
+    }
+
+    return false;
+  }
+};
+
+/// Match a ConstantInt or a vector of ConstantInts, binding the specified
+/// pointer to the Constant. This is more specific than m_Constant in that it
+/// explicitly ensures no ConstantExpr or Undef elements.
+inline constantint_or_constantint_vec_match
+m_ConstantIntOrConstantIntVec(Constant *&Res) { return Res; }
+
 template <int64_t Val> struct constantint_match {
   template <typename ITy> bool match(ITy *V) {
     if (const auto *CI = dyn_cast<ConstantInt>(V)) {
Index: lib/Transforms/InstCombine/InstCombineAddSub.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1195,7 +1195,8 @@
   // integer add followed by a sext.
   if (SExtInst *LHSConv = dyn_cast<SExtInst>(LHS)) {
     // (add (sext x), cst) --> (sext (add x, cst'))
-    if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) {
+    Constant *RHSC;
+    if (match(RHS, m_ConstantIntOrConstantIntVec(RHSC))) {
       if (LHSConv->hasOneUse()) {
         Constant *CI =
             ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType());
@@ -1231,7 +1232,8 @@
   // integer add followed by a zext.
   if (auto *LHSConv = dyn_cast<ZExtInst>(LHS)) {
     // (add (zext x), cst) --> (zext (add x, cst'))
-    if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) {
+    Constant *RHSC;
+    if (match(RHS, m_ConstantIntOrConstantIntVec(RHSC))) {
       if (LHSConv->hasOneUse()) {
         Constant *CI =
             ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType());
Index: test/Transforms/InstCombine/sink-zext.ll
===================================================================
--- test/Transforms/InstCombine/sink-zext.ll
+++ test/Transforms/InstCombine/sink-zext.ll
@@ -68,4 +68,134 @@
   ret i64 %zext
 }
 
+define i64 @test5(i32 %V) {
+; CHECK-LABEL: @test5(
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr i32 [[V:%.*]], 1
+; CHECK-NEXT:    [[ADDCONV:%.*]] = add nsw i32 [[ASHR]], 1073741823
+; CHECK-NEXT:    [[ADD:%.*]] = sext i32 [[ADDCONV]] to i64
+; CHECK-NEXT:    ret i64 [[ADD]]
+;
+  %ashr = ashr i32 %V, 1
+  %sext = sext i32 %ashr to i64
+  %add = add i64 %sext, 1073741823
+  ret i64 %add
+}
+
+define <2 x i64> @test5_splat(<2 x i32> %V) {
+; CHECK-LABEL: @test5_splat(
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr <2 x i32> [[V:%.*]], <i32 1, i32 1>
+; CHECK-NEXT:    [[ADDCONV:%.*]] = add nsw <2 x i32> [[ASHR]], <i32 1073741823, i32 1073741823>
+; CHECK-NEXT:    [[ADD:%.*]] = sext <2 x i32> [[ADDCONV]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[ADD]]
+;
+  %ashr = ashr <2 x i32> %V, <i32 1, i32 1>
+  %sext = sext <2 x i32> %ashr to <2 x i64>
+  %add = add <2 x i64> %sext, <i64 1073741823, i64 1073741823>
+  ret <2 x i64> %add
+}
+
+define <2 x i64> @test5_vec(<2 x i32> %V) {
+; CHECK-LABEL: @test5_vec(
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr <2 x i32> [[V:%.*]], <i32 1, i32 1>
+; CHECK-NEXT:    [[ADDCONV:%.*]] = add nsw <2 x i32> [[ASHR]], <i32 1, i32 2>
+; CHECK-NEXT:    [[ADD:%.*]] = sext <2 x i32> [[ADDCONV]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[ADD]]
+;
+  %ashr = ashr <2 x i32> %V, <i32 1, i32 1>
+  %sext = sext <2 x i32> %ashr to <2 x i64>
+  %add = add <2 x i64> %sext, <i64 1, i64 2>
+  ret <2 x i64> %add
+}
+
+define i64 @test6(i32 %V) {
+; CHECK-LABEL: @test6(
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr i32 [[V:%.*]], 1
+; CHECK-NEXT:    [[ADDCONV:%.*]] = add nsw i32 [[ASHR]], -1073741824
+; CHECK-NEXT:    [[ADD:%.*]] = sext i32 [[ADDCONV]] to i64
+; CHECK-NEXT:    ret i64 [[ADD]]
+;
+  %ashr = ashr i32 %V, 1
+  %sext = sext i32 %ashr to i64
+  %add = add i64 %sext, -1073741824
+  ret i64 %add
+}
+
+define <2 x i64> @test6_splat(<2 x i32> %V) {
+; CHECK-LABEL: @test6_splat(
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr <2 x i32> [[V:%.*]], <i32 1, i32 1>
+; CHECK-NEXT:    [[ADDCONV:%.*]] = add nsw <2 x i32> [[ASHR]], <i32 -1073741824, i32 -1073741824>
+; CHECK-NEXT:    [[ADD:%.*]] = sext <2 x i32> [[ADDCONV]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[ADD]]
+;
+  %ashr = ashr <2 x i32> %V, <i32 1, i32 1>
+  %sext = sext <2 x i32> %ashr to <2 x i64>
+  %add = add <2 x i64> %sext, <i64 -1073741824, i64 -1073741824>
+  ret <2 x i64> %add
+}
+
+define <2 x i64> @test6_vec(<2 x i32> %V) {
+; CHECK-LABEL: @test6_vec(
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr <2 x i32> [[V:%.*]], <i32 1, i32 1>
+; CHECK-NEXT:    [[ADDCONV:%.*]] = add nsw <2 x i32> [[ASHR]], <i32 -1, i32 -2>
+; CHECK-NEXT:    [[ADD:%.*]] = sext <2 x i32> [[ADDCONV]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[ADD]]
+;
+  %ashr = ashr <2 x i32> %V, <i32 1, i32 1>
+  %sext = sext <2 x i32> %ashr to <2 x i64>
+  %add = add <2 x i64> %sext, <i64 -1, i64 -2>
+  ret <2 x i64> %add
+}
+
+define <2 x i64> @test6_vec2(<2 x i32> %V) {
+; CHECK-LABEL: @test6_vec2(
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr <2 x i32> [[V:%.*]], <i32 1, i32 1>
+; CHECK-NEXT:    [[ADDCONV:%.*]] = add nsw <2 x i32> [[ASHR]], <i32 -1, i32 1>
+; CHECK-NEXT:    [[ADD:%.*]] = sext <2 x i32> [[ADDCONV]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[ADD]]
+;
+  %ashr = ashr <2 x i32> %V, <i32 1, i32 1>
+  %sext = sext <2 x i32> %ashr to <2 x i64>
+  %add = add <2 x i64> %sext, <i64 -1, i64 1>
+  ret <2 x i64> %add
+}
+
+define i64 @test7(i32 %V) {
+; CHECK-LABEL: @test7(
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[V:%.*]], 1
+; CHECK-NEXT:    [[ADDCONV:%.*]] = add nuw i32 [[LSHR]], 2147483647
+; CHECK-NEXT:    [[ADD:%.*]] = zext i32 [[ADDCONV]] to i64
+; CHECK-NEXT:    ret i64 [[ADD]]
+;
+  %lshr = lshr i32 %V, 1
+  %zext = zext i32 %lshr to i64
+  %add = add i64 %zext, 2147483647
+  ret i64 %add
+}
+
+define <2 x i64> @test7_splat(<2 x i32> %V) {
+; CHECK-LABEL: @test7_splat(
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr <2 x i32> [[V:%.*]], <i32 1, i32 1>
+; CHECK-NEXT:    [[ADDCONV:%.*]] = add nuw <2 x i32> [[LSHR]], <i32 2147483647, i32 2147483647>
+; CHECK-NEXT:    [[ADD:%.*]] = zext <2 x i32> [[ADDCONV]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[ADD]]
+;
+  %lshr = lshr <2 x i32> %V, <i32 1, i32 1>
+  %zext = zext <2 x i32> %lshr to <2 x i64>
+  %add = add <2 x i64> %zext, <i64 2147483647, i64 2147483647>
+  ret <2 x i64> %add
+}
+
+define <2 x i64> @test7_vec(<2 x i32> %V) {
+; CHECK-LABEL: @test7_vec(
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr <2 x i32> [[V:%.*]], <i32 1, i32 1>
+; CHECK-NEXT:    [[ADDCONV:%.*]] = add nuw <2 x i32> [[LSHR]], <i32 1, i32 2>
+; CHECK-NEXT:    [[ADD:%.*]] = zext <2 x i32> [[ADDCONV]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[ADD]]
+;
+  %lshr = lshr <2 x i32> %V, <i32 1, i32 1>
+  %zext = zext <2 x i32> %lshr to <2 x i64>
+  %add = add <2 x i64> %zext, <i64 1, i64 2>
+  ret <2 x i64> %add
+}
+
 !0 = !{ i32 0, i32 2000 }