diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -2110,10 +2110,56 @@ auto matchShiftAmount = [](Value *L, Value *R, unsigned Width) -> Value * { // Check for constant shift amounts that sum to the bitwidth. // TODO: Support non-uniform shift amounts. - const APInt *LC, *RC; - if (match(L, m_APIntAllowUndef(LC)) && match(R, m_APIntAllowUndef(RC))) - if (LC->ult(Width) && RC->ult(Width) && (*LC + *RC) == Width) - return ConstantInt::get(L->getType(), *LC); + Constant *LC, *RC; + if (match(L, m_Constant(LC)) && match(R, m_Constant(RC))) { + // Attempt to find matching (inrange) rotation amounts and return the left + // constant, if one is undef then assume it matches, if both are undef + // then return undef. + auto IsRotate = [Width](Constant *LC, Constant *RC) -> Constant * { + auto *LI = dyn_cast_or_null<ConstantInt>(LC); + auto *RI = dyn_cast_or_null<ConstantInt>(RC); + if ((LI && LI->getValue().uge(Width)) || + (RI && RI->getValue().uge(Width)) || !LC || !RC) + return nullptr; + if (isa<UndefValue>(LC) && isa<UndefValue>(RC)) + return LC; + if ((LI && isa<UndefValue>(RC)) || + (LI && RI && (LI->getValue() + RI->getValue()) == Width)) + return LI; + if (RI && isa<UndefValue>(LC)) { + APInt AmtL = APInt(Width, Width) - RI->getValue(); + return ConstantInt::get(RC->getType()->getContext(), AmtL); + } + return nullptr; + }; + + if (auto *RotateAmt = IsRotate(LC, RC)) + return RotateAmt; + + if (auto *LVecType = dyn_cast<FixedVectorType>(LC->getType())) { + Constant *SplatL = LC->getSplatValue(/*AllowUndefs*/ true); + Constant *SplatR = RC->getSplatValue(/*AllowUndefs*/ true); + if (auto *RotateAmt = IsRotate(SplatL, SplatR)) + return ConstantVector::getSplat(LVecType->getElementCount(), + RotateAmt); + + SmallVector<Constant *, 16> RotateAmts; + unsigned NumElts = LVecType->getNumElements(); + for (unsigned I = 0; I != NumElts; ++I) { + Constant *LI = 
LC->getAggregateElement(I); + Constant *RI = RC->getAggregateElement(I); + if (auto *RotateAmt = IsRotate(LI, RI)) { + RotateAmts.push_back(RotateAmt); + continue; + } + break; + } + if (RotateAmts.size() == NumElts) + return ConstantVector::get(RotateAmts); + } + + return nullptr; + } // For non-constant cases we don't support non-pow2 shift masks. // TODO: Is it worth matching urem as well? diff --git a/llvm/test/Transforms/InstCombine/rotate.ll b/llvm/test/Transforms/InstCombine/rotate.ll --- a/llvm/test/Transforms/InstCombine/rotate.ll +++ b/llvm/test/Transforms/InstCombine/rotate.ll @@ -122,13 +122,11 @@ ret <2 x i17> %r } -; TODO: Allow arbitrary shift constants. +; Allow arbitrary shift constants. define <2 x i32> @rotr_v2i32_constant_nonsplat(<2 x i32> %x) { ; CHECK-LABEL: @rotr_v2i32_constant_nonsplat( -; CHECK-NEXT: [[SHL:%.*]] = shl <2 x i32> [[X:%.*]], -; CHECK-NEXT: [[SHR:%.*]] = lshr <2 x i32> [[X]], -; CHECK-NEXT: [[R:%.*]] = or <2 x i32> [[SHL]], [[SHR]] +; CHECK-NEXT: [[R:%.*]] = call <2 x i32> @llvm.fshl.v2i32(<2 x i32> [[X:%.*]], <2 x i32> [[X]], <2 x i32> ) ; CHECK-NEXT: ret <2 x i32> [[R]] ; %shl = shl <2 x i32> %x, @@ -139,9 +137,7 @@ define <2 x i32> @rotr_v2i32_constant_nonsplat_undef0(<2 x i32> %x) { ; CHECK-LABEL: @rotr_v2i32_constant_nonsplat_undef0( -; CHECK-NEXT: [[SHL:%.*]] = shl <2 x i32> [[X:%.*]], -; CHECK-NEXT: [[SHR:%.*]] = lshr <2 x i32> [[X]], -; CHECK-NEXT: [[R:%.*]] = or <2 x i32> [[SHL]], [[SHR]] +; CHECK-NEXT: [[R:%.*]] = call <2 x i32> @llvm.fshl.v2i32(<2 x i32> [[X:%.*]], <2 x i32> [[X]], <2 x i32> ) ; CHECK-NEXT: ret <2 x i32> [[R]] ; %shl = shl <2 x i32> %x, @@ -152,9 +148,7 @@ define <2 x i32> @rotr_v2i32_constant_nonsplat_undef1(<2 x i32> %x) { ; CHECK-LABEL: @rotr_v2i32_constant_nonsplat_undef1( -; CHECK-NEXT: [[SHL:%.*]] = shl <2 x i32> [[X:%.*]], -; CHECK-NEXT: [[SHR:%.*]] = lshr <2 x i32> [[X]], -; CHECK-NEXT: [[R:%.*]] = or <2 x i32> [[SHL]], [[SHR]] +; CHECK-NEXT: [[R:%.*]] = call <2 x i32> @llvm.fshl.v2i32(<2 x 
i32> [[X:%.*]], <2 x i32> [[X]], <2 x i32> ) ; CHECK-NEXT: ret <2 x i32> [[R]] ; %shl = shl <2 x i32> %x, @@ -165,9 +159,7 @@ define <2 x i36> @rotl_v2i36_constant_nonsplat(<2 x i36> %x) { ; CHECK-LABEL: @rotl_v2i36_constant_nonsplat( -; CHECK-NEXT: [[SHL:%.*]] = shl <2 x i36> [[X:%.*]], -; CHECK-NEXT: [[SHR:%.*]] = lshr <2 x i36> [[X]], -; CHECK-NEXT: [[R:%.*]] = or <2 x i36> [[SHL]], [[SHR]] +; CHECK-NEXT: [[R:%.*]] = call <2 x i36> @llvm.fshl.v2i36(<2 x i36> [[X:%.*]], <2 x i36> [[X]], <2 x i36> ) ; CHECK-NEXT: ret <2 x i36> [[R]] ; %shl = shl <2 x i36> %x, @@ -178,9 +170,7 @@ define <3 x i36> @rotl_v3i36_constant_nonsplat_undef0(<3 x i36> %x) { ; CHECK-LABEL: @rotl_v3i36_constant_nonsplat_undef0( -; CHECK-NEXT: [[SHL:%.*]] = shl <3 x i36> [[X:%.*]], -; CHECK-NEXT: [[SHR:%.*]] = lshr <3 x i36> [[X]], -; CHECK-NEXT: [[R:%.*]] = or <3 x i36> [[SHL]], [[SHR]] +; CHECK-NEXT: [[R:%.*]] = call <3 x i36> @llvm.fshl.v3i36(<3 x i36> [[X:%.*]], <3 x i36> [[X]], <3 x i36> ) ; CHECK-NEXT: ret <3 x i36> [[R]] ; %shl = shl <3 x i36> %x,