diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -52,6 +52,37 @@
     return getCalledFunction()->getIntrinsicID();
   }
 
+  /// Return true if swapping the first two arguments to the intrinsic produces
+  /// the same result. Const so it is callable through `const IntrinsicInst *`.
+  bool isCommutative() const {
+    switch (getIntrinsicID()) {
+    case Intrinsic::maxnum:
+    case Intrinsic::minnum:
+    case Intrinsic::maximum:
+    case Intrinsic::minimum:
+    case Intrinsic::smax:
+    case Intrinsic::smin:
+    case Intrinsic::umax:
+    case Intrinsic::umin:
+    case Intrinsic::sadd_sat:
+    case Intrinsic::uadd_sat:
+    case Intrinsic::sadd_with_overflow:
+    case Intrinsic::uadd_with_overflow:
+    case Intrinsic::smul_with_overflow:
+    case Intrinsic::umul_with_overflow:
+      // TODO: These fixed-point math intrinsics have commutative first two
+      //       operands, but callers may not handle instructions with more than
+      //       two operands.
+    // case Intrinsic::smul_fix:
+    // case Intrinsic::umul_fix:
+    // case Intrinsic::smul_fix_sat:
+    // case Intrinsic::umul_fix_sat:
+      return true;
+    default:
+      return false;
+    }
+  }
+
   // Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const CallInst *I) {
     if (const Function *CF = I->getCalledFunction())
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -288,6 +288,17 @@
           isa<FreezeInst>(Inst)) &&
          "Invalid/unknown instruction");
 
+  // Handle intrinsics with commutative operands (sorted pair + intrinsic ID).
+  // TODO: Extend this to handle intrinsics with >2 operands where the 1st
+  //       2 operands are commutative.
+  auto *II = dyn_cast<IntrinsicInst>(Inst);
+  if (II && II->isCommutative() && II->getNumArgOperands() == 2) {
+    Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1);
+    if (LHS > RHS)
+      std::swap(LHS, RHS);
+    return hash_combine(II->getOpcode(), II->getIntrinsicID(), LHS, RHS);
+  }
+
   // Mix in the opcode.
   return hash_combine(
       Inst->getOpcode(),
@@ -340,6 +351,15 @@
            LHSCmp->getSwappedPredicate() == RHSCmp->getPredicate();
   }
 
+  // TODO: Extend this for >2 args by matching the trailing N-2 args.
+  auto *LII = dyn_cast<IntrinsicInst>(LHSI);
+  auto *RII = dyn_cast<IntrinsicInst>(RHSI);
+  if (LII && RII && LII->getIntrinsicID() == RII->getIntrinsicID() &&
+      LII->isCommutative() && LII->getNumArgOperands() == 2) {
+    return LII->getArgOperand(0) == RII->getArgOperand(1) &&
+           LII->getArgOperand(1) == RII->getArgOperand(0);
+  }
+
   // Min/max/abs can occur with commuted operands, non-canonical predicates,
   // and/or non-canonical operands.
   // Selects can be non-trivially equivalent via inverted conditions and swaps.
diff --git a/llvm/test/Transforms/EarlyCSE/commute.ll b/llvm/test/Transforms/EarlyCSE/commute.ll
--- a/llvm/test/Transforms/EarlyCSE/commute.ll
+++ b/llvm/test/Transforms/EarlyCSE/commute.ll
@@ -766,9 +766,7 @@
 define float @maxnum(float %a, float %b) {
 ; CHECK-LABEL: @maxnum(
 ; CHECK-NEXT:    [[X:%.*]] = call float @llvm.maxnum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call float @llvm.maxnum.f32(float [[B]], float [[A]])
-; CHECK-NEXT:    [[R:%.*]] = fdiv nnan float [[X]], [[Y]]
-; CHECK-NEXT:    ret float [[R]]
+; CHECK-NEXT:    ret float 1.000000e+00
 ;
   %x = call float @llvm.maxnum.f32(float %a, float %b)
   %y = call float @llvm.maxnum.f32(float %b, float %a)
@@ -779,9 +777,7 @@
 define <2 x float> @minnum(<2 x float> %a, <2 x float> %b) {
 ; CHECK-LABEL: @minnum(
 ; CHECK-NEXT:    [[X:%.*]] = call fast <2 x float> @llvm.minnum.v2f32(<2 x float> [[A:%.*]], <2 x float> [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call fast <2 x float> @llvm.minnum.v2f32(<2 x float> [[B]], <2 x float> [[A]])
-; CHECK-NEXT:    [[R:%.*]] = fdiv nnan <2 x float> [[X]], [[Y]]
-; CHECK-NEXT:    ret <2 x float> [[R]]
+; CHECK-NEXT:    ret <2 x float> <float 1.000000e+00, float 1.000000e+00>
 ;
   %x = call fast <2 x float> @llvm.minnum.v2f32(<2 x float> %a, <2 x float> %b)
   %y = call fast <2 x float> @llvm.minnum.v2f32(<2 x float> %b, <2 x float> %a)
@@ -791,10 +787,8 @@
 
 define <2 x double> @maximum(<2 x double> %a, <2 x double> %b) {
 ; CHECK-LABEL: @maximum(
-; CHECK-NEXT:    [[X:%.*]] = call fast <2 x double> @llvm.maximum.v2f64(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call <2 x double> @llvm.maximum.v2f64(<2 x double> [[B]], <2 x double> [[A]])
-; CHECK-NEXT:    [[R:%.*]] = fdiv nnan <2 x double> [[X]], [[Y]]
-; CHECK-NEXT:    ret <2 x double> [[R]]
+; CHECK-NEXT:    [[X:%.*]] = call <2 x double> @llvm.maximum.v2f64(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]])
+; CHECK-NEXT:    ret <2 x double> <double 1.000000e+00, double 1.000000e+00>
 ;
   %x = call fast <2 x double> @llvm.maximum.v2f64(<2 x double> %a, <2 x double> %b)
   %y = call <2 x double> @llvm.maximum.v2f64(<2 x double> %b, <2 x double> %a)
@@ -804,10 +798,8 @@
 
 define double @minimum(double %a, double %b) {
 ; CHECK-LABEL: @minimum(
-; CHECK-NEXT:    [[X:%.*]] = call nsz double @llvm.minimum.f64(double [[A:%.*]], double [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call ninf double @llvm.minimum.f64(double [[B]], double [[A]])
-; CHECK-NEXT:    [[R:%.*]] = fdiv nnan double [[X]], [[Y]]
-; CHECK-NEXT:    ret double [[R]]
+; CHECK-NEXT:    [[X:%.*]] = call double @llvm.minimum.f64(double [[A:%.*]], double [[B:%.*]])
+; CHECK-NEXT:    ret double 1.000000e+00
 ;
   %x = call nsz double @llvm.minimum.f64(double %a, double %b)
   %y = call ninf double @llvm.minimum.f64(double %b, double %a)
@@ -817,11 +809,8 @@
 define i16 @sadd_ov(i16 %a, i16 %b) {
 ; CHECK-LABEL: @sadd_ov(
 ; CHECK-NEXT:    [[X:%.*]] = call { i16, i1 } @llvm.sadd.with.overflow.i16(i16 [[A:%.*]], i16 [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call { i16, i1 } @llvm.sadd.with.overflow.i16(i16 [[B]], i16 [[A]])
 ; CHECK-NEXT:    [[X1:%.*]] = extractvalue { i16, i1 } [[X]], 0
-; CHECK-NEXT:    [[Y1:%.*]] = extractvalue { i16, i1 } [[Y]], 0
-; CHECK-NEXT:    [[O:%.*]] = or i16 [[X1]], [[Y1]]
-; CHECK-NEXT:    ret i16 [[O]]
+; CHECK-NEXT:    ret i16 [[X1]]
 ;
   %x = call {i16, i1} @llvm.sadd.with.overflow.i16(i16 %a, i16 %b)
   %y = call {i16, i1} @llvm.sadd.with.overflow.i16(i16 %b, i16 %a)
@@ -834,11 +823,8 @@
 define <5 x i65> @uadd_ov(<5 x i65> %a, <5 x i65> %b) {
 ; CHECK-LABEL: @uadd_ov(
 ; CHECK-NEXT:    [[X:%.*]] = call { <5 x i65>, <5 x i1> } @llvm.uadd.with.overflow.v5i65(<5 x i65> [[A:%.*]], <5 x i65> [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call { <5 x i65>, <5 x i1> } @llvm.uadd.with.overflow.v5i65(<5 x i65> [[B]], <5 x i65> [[A]])
 ; CHECK-NEXT:    [[X1:%.*]] = extractvalue { <5 x i65>, <5 x i1> } [[X]], 0
-; CHECK-NEXT:    [[Y1:%.*]] = extractvalue { <5 x i65>, <5 x i1> } [[Y]], 0
-; CHECK-NEXT:    [[O:%.*]] = or <5 x i65> [[X1]], [[Y1]]
-; CHECK-NEXT:    ret <5 x i65> [[O]]
+; CHECK-NEXT:    ret <5 x i65> [[X1]]
 ;
   %x = call {<5 x i65>, <5 x i1>} @llvm.uadd.with.overflow.v5i65(<5 x i65> %a, <5 x i65> %b)
   %y = call {<5 x i65>, <5 x i1>} @llvm.uadd.with.overflow.v5i65(<5 x i65> %b, <5 x i65> %a)
@@ -851,11 +837,8 @@
 define i37 @smul_ov(i37 %a, i37 %b) {
 ; CHECK-LABEL: @smul_ov(
 ; CHECK-NEXT:    [[X:%.*]] = call { i37, i1 } @llvm.smul.with.overflow.i37(i37 [[A:%.*]], i37 [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call { i37, i1 } @llvm.smul.with.overflow.i37(i37 [[B]], i37 [[A]])
 ; CHECK-NEXT:    [[X1:%.*]] = extractvalue { i37, i1 } [[X]], 0
-; CHECK-NEXT:    [[Y1:%.*]] = extractvalue { i37, i1 } [[Y]], 0
-; CHECK-NEXT:    [[O:%.*]] = or i37 [[X1]], [[Y1]]
-; CHECK-NEXT:    ret i37 [[O]]
+; CHECK-NEXT:    ret i37 [[X1]]
 ;
   %x = call {i37, i1} @llvm.smul.with.overflow.i37(i37 %a, i37 %b)
   %y = call {i37, i1} @llvm.smul.with.overflow.i37(i37 %b, i37 %a)
@@ -868,11 +851,8 @@
 define <2 x i31> @umul_ov(<2 x i31> %a, <2 x i31> %b) {
 ; CHECK-LABEL: @umul_ov(
 ; CHECK-NEXT:    [[X:%.*]] = call { <2 x i31>, <2 x i1> } @llvm.umul.with.overflow.v2i31(<2 x i31> [[A:%.*]], <2 x i31> [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call { <2 x i31>, <2 x i1> } @llvm.umul.with.overflow.v2i31(<2 x i31> [[B]], <2 x i31> [[A]])
 ; CHECK-NEXT:    [[X1:%.*]] = extractvalue { <2 x i31>, <2 x i1> } [[X]], 0
-; CHECK-NEXT:    [[Y1:%.*]] = extractvalue { <2 x i31>, <2 x i1> } [[Y]], 0
-; CHECK-NEXT:    [[O:%.*]] = or <2 x i31> [[X1]], [[Y1]]
-; CHECK-NEXT:    ret <2 x i31> [[O]]
+; CHECK-NEXT:    ret <2 x i31> [[X1]]
 ;
   %x = call {<2 x i31>, <2 x i1>} @llvm.umul.with.overflow.v2i31(<2 x i31> %a, <2 x i31> %b)
   %y = call {<2 x i31>, <2 x i1>} @llvm.umul.with.overflow.v2i31(<2 x i31> %b, <2 x i31> %a)
@@ -885,9 +865,7 @@
 define i64 @sadd_sat(i64 %a, i64 %b) {
 ; CHECK-LABEL: @sadd_sat(
 ; CHECK-NEXT:    [[X:%.*]] = call i64 @llvm.sadd.sat.i64(i64 [[A:%.*]], i64 [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call i64 @llvm.sadd.sat.i64(i64 [[B]], i64 [[A]])
-; CHECK-NEXT:    [[O:%.*]] = or i64 [[X]], [[Y]]
-; CHECK-NEXT:    ret i64 [[O]]
+; CHECK-NEXT:    ret i64 [[X]]
 ;
   %x = call i64 @llvm.sadd.sat.i64(i64 %a, i64 %b)
   %y = call i64 @llvm.sadd.sat.i64(i64 %b, i64 %a)
@@ -898,9 +876,7 @@
 define <2 x i64> @uadd_sat(<2 x i64> %a, <2 x i64> %b) {
 ; CHECK-LABEL: @uadd_sat(
 ; CHECK-NEXT:    [[X:%.*]] = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> [[A:%.*]], <2 x i64> [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> [[B]], <2 x i64> [[A]])
-; CHECK-NEXT:    [[O:%.*]] = or <2 x i64> [[X]], [[Y]]
-; CHECK-NEXT:    ret <2 x i64> [[O]]
+; CHECK-NEXT:    ret <2 x i64> [[X]]
 ;
   %x = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> %a, <2 x i64> %b)
   %y = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> %b, <2 x i64> %a)
@@ -911,9 +887,7 @@
 define <2 x i64> @smax(<2 x i64> %a, <2 x i64> %b) {
 ; CHECK-LABEL: @smax(
 ; CHECK-NEXT:    [[X:%.*]] = call <2 x i64> @llvm.smax.v2i64(<2 x i64> [[A:%.*]], <2 x i64> [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call <2 x i64> @llvm.smax.v2i64(<2 x i64> [[B]], <2 x i64> [[A]])
-; CHECK-NEXT:    [[O:%.*]] = or <2 x i64> [[X]], [[Y]]
-; CHECK-NEXT:    ret <2 x i64> [[O]]
+; CHECK-NEXT:    ret <2 x i64> [[X]]
 ;
   %x = call <2 x i64> @llvm.smax.v2i64(<2 x i64> %a, <2 x i64> %b)
   %y = call <2 x i64> @llvm.smax.v2i64(<2 x i64> %b, <2 x i64> %a)
@@ -924,9 +898,7 @@
 define i4 @smin(i4 %a, i4 %b) {
 ; CHECK-LABEL: @smin(
 ; CHECK-NEXT:    [[X:%.*]] = call i4 @llvm.smin.i4(i4 [[A:%.*]], i4 [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call i4 @llvm.smin.i4(i4 [[B]], i4 [[A]])
-; CHECK-NEXT:    [[O:%.*]] = or i4 [[X]], [[Y]]
-; CHECK-NEXT:    ret i4 [[O]]
+; CHECK-NEXT:    ret i4 [[X]]
 ;
   %x = call i4 @llvm.smin.i4(i4 %a, i4 %b)
   %y = call i4 @llvm.smin.i4(i4 %b, i4 %a)
@@ -937,9 +909,7 @@
 define i67 @umax(i67 %a, i67 %b) {
 ; CHECK-LABEL: @umax(
 ; CHECK-NEXT:    [[X:%.*]] = call i67 @llvm.umax.i67(i67 [[A:%.*]], i67 [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call i67 @llvm.umax.i67(i67 [[B]], i67 [[A]])
-; CHECK-NEXT:    [[O:%.*]] = or i67 [[X]], [[Y]]
-; CHECK-NEXT:    ret i67 [[O]]
+; CHECK-NEXT:    ret i67 [[X]]
 ;
   %x = call i67 @llvm.umax.i67(i67 %a, i67 %b)
   %y = call i67 @llvm.umax.i67(i67 %b, i67 %a)
@@ -950,9 +920,7 @@
 define <3 x i17> @umin(<3 x i17> %a, <3 x i17> %b) {
 ; CHECK-LABEL: @umin(
 ; CHECK-NEXT:    [[X:%.*]] = call <3 x i17> @llvm.umin.v3i17(<3 x i17> [[A:%.*]], <3 x i17> [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call <3 x i17> @llvm.umin.v3i17(<3 x i17> [[B]], <3 x i17> [[A]])
-; CHECK-NEXT:    [[O:%.*]] = or <3 x i17> [[X]], [[Y]]
-; CHECK-NEXT:    ret <3 x i17> [[O]]
+; CHECK-NEXT:    ret <3 x i17> [[X]]
 ;
   %x = call <3 x i17> @llvm.umin.v3i17(<3 x i17> %a, <3 x i17> %b)
   %y = call <3 x i17> @llvm.umin.v3i17(<3 x i17> %b, <3 x i17> %a)
@@ -960,6 +928,8 @@
   ret <3 x i17> %o
 }
 
+; Negative test - mismatched intrinsics
+
 define i4 @smin_umin(i4 %a, i4 %b) {
 ; CHECK-LABEL: @smin_umin(
 ; CHECK-NEXT:    [[X:%.*]] = call i4 @llvm.smin.i4(i4 [[A:%.*]], i4 [[B:%.*]])
@@ -973,6 +943,8 @@
   ret i4 %o
 }
 
+; TODO: handle >2 args
+
 define i16 @smul_fix(i16 %a, i16 %b) {
 ; CHECK-LABEL: @smul_fix(
 ; CHECK-NEXT:    [[X:%.*]] = call i16 @llvm.smul.fix.i16(i16 [[A:%.*]], i16 [[B:%.*]], i32 3)
@@ -986,6 +958,8 @@
   ret i16 %o
 }
 
+; TODO: handle >2 args
+
 define i16 @umul_fix(i16 %a, i16 %b, i32 %s) {
 ; CHECK-LABEL: @umul_fix(
 ; CHECK-NEXT:    [[X:%.*]] = call i16 @llvm.umul.fix.i16(i16 [[A:%.*]], i16 [[B:%.*]], i32 1)
@@ -999,6 +973,8 @@
   ret i16 %o
 }
 
+; TODO: handle >2 args
+
 define <3 x i16> @smul_fix_sat(<3 x i16> %a, <3 x i16> %b) {
 ; CHECK-LABEL: @smul_fix_sat(
 ; CHECK-NEXT:    [[X:%.*]] = call <3 x i16> @llvm.smul.fix.sat.v3i16(<3 x i16> [[A:%.*]], <3 x i16> [[B:%.*]], i32 2)
@@ -1012,6 +988,8 @@
   ret <3 x i16> %o
 }
 
+; TODO: handle >2 args
+
 define <3 x i16> @umul_fix_sat(<3 x i16> %a, <3 x i16> %b) {
 ; CHECK-LABEL: @umul_fix_sat(
 ; CHECK-NEXT:    [[X:%.*]] = call <3 x i16> @llvm.umul.fix.sat.v3i16(<3 x i16> [[A:%.*]], <3 x i16> [[B:%.*]], i32 3)