Index: Transforms/InstCombine/InstCombineCalls.cpp
===================================================================
--- Transforms/InstCombine/InstCombineCalls.cpp
+++ Transforms/InstCombine/InstCombineCalls.cpp
@@ -1554,6 +1554,59 @@
   return false;
 }
 
+// Convert Cmp(Const, a) into the canonical representation Cmp(a, Const).
+// This function canonicalizes a compare by reordering its operands if needed.
+// As a result, the compare direction may also be changed in order to preserve
+// the original compare meaning, e.g. (a < b) -> (b > a).
+// Otherwise, the comparison immediate remains valid for the new operand
+// order, e.g. (a == b) -> (b == a).
+//
+// TODO: Add a canonical representation for SSE's compare intrinsics.
+// The target must have VEX encoding.
+
+static bool X86CreateCanonicalCMP(IntrinsicInst *II) {
+  Value *LHS = II->getOperand(0);
+  Value *RHS = II->getOperand(1);
+  StringRef IntrinsicName = II->getCalledFunction()->getName();
+  // This assertion ensures that only AVX and later intrinsics reach the
+  // compare canonicalization.
+  assert(IntrinsicName.contains("avx") &&
+         "Canonical representation supports only intrinsics with VEX encoding");
+  if (isa<Constant>(LHS) && !isa<Constant>(RHS)) {
+    ConstantInt *ComparisonValue = cast<ConstantInt>(II->getOperand(2));
+    uint64_t ConstantValue = ComparisonValue->getZExtValue();
+    // When the two lower bits of the comparison immediate are "01" or "10"
+    // (i.e. 1 or 2), the predicate is a "direction" comparison
+    // ('<', '<=', ...), whose direction must be reversed when the operands
+    // are exchanged.
+    //
+    // In the 128-bit legacy SSE version the comparison predicate operand is
+    // an 8-bit immediate; bits 2:0 of the immediate define the type of
+    // comparison to be performed and bits 7:3 are reserved.
+    //
+    // The first three bits select one of:
+    // Equal, Less-than, Less-than-or-equal, Unordered, Not-equal,
+    // Not-less-than, Not-less-than-or-equal and Ordered.
+    //
+    // In the VEX version two more bits were added to the immediate. For the
+    // "direction" predicates (<, <=, !<, !<=) the fourth bit selects the
+    // "greater" relation, so flipping the four low bits of such a predicate
+    // maps each "less" predicate one-to-one to its equivalent "greater"
+    // predicate:
+    // xor(immediate('<'), 0x0F) = immediate('>').
+    // This also holds for immediates with bit 0x10 set.
+    if ((ConstantValue & 0x3) == 1 || (ConstantValue & 0x3) == 2) {
+      const APInt NewComparison(32, (ConstantValue ^ 0xf));
+      II->setOperand(2, ConstantExpr::getIntegerValue(
+                            Type::getInt32Ty(II->getContext()), NewComparison));
+    }
+    II->setArgOperand(0, RHS);
+    II->setArgOperand(1, LHS);
+    return true;
+  }
+  return false;
+}
+
 // Convert NVVM intrinsics to target-generic LLVM code where possible.
 static Instruction *SimplifyNVVMIntrinsic(IntrinsicInst *II, InstCombiner &IC) {
   // Each NVVM intrinsic we can simplify can be replaced with one of:
@@ -2290,6 +2343,11 @@
     break;
   }
+  case Intrinsic::x86_avx512_mask_cmp_ss:
+  case Intrinsic::x86_avx512_mask_cmp_sd:
+    if (X86CreateCanonicalCMP(II))
+      return II;
+    LLVM_FALLTHROUGH;
   case Intrinsic::x86_sse_comieq_ss:
   case Intrinsic::x86_sse_comige_ss:
   case Intrinsic::x86_sse_comigt_ss:
@@ -2315,9 +2373,7 @@
   case Intrinsic::x86_sse2_ucomilt_sd:
   case Intrinsic::x86_sse2_ucomineq_sd:
   case Intrinsic::x86_avx512_vcomi_ss:
-  case Intrinsic::x86_avx512_vcomi_sd:
-  case Intrinsic::x86_avx512_mask_cmp_ss:
-  case Intrinsic::x86_avx512_mask_cmp_sd: {
+  case Intrinsic::x86_avx512_vcomi_sd: {
     // These intrinsics only demand the 0th element of their input vectors. If
     // we can simplify the input based on that, do so now.
     bool MadeChange = false;
@@ -2336,6 +2392,15 @@
       return II;
     break;
   }
+  case Intrinsic::x86_avx512_mask_cmp_pd_128:
+  case Intrinsic::x86_avx512_mask_cmp_pd_256:
+  case Intrinsic::x86_avx512_mask_cmp_pd_512:
+  case Intrinsic::x86_avx512_mask_cmp_ps_128:
+  case Intrinsic::x86_avx512_mask_cmp_ps_256:
+  case Intrinsic::x86_avx512_mask_cmp_ps_512:
+    if (X86CreateCanonicalCMP(II))
+      return II;
+    break;
   case Intrinsic::x86_avx512_mask_add_ps_512:
   case Intrinsic::x86_avx512_mask_div_ps_512:
Index: Transforms/InstCombine/X86CanonicCmp.ll
===================================================================
--- /dev/null
+++ Transforms/InstCombine/X86CanonicCmp.ll
@@ -0,0 +1,61 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+; This test checks the canonicalization of compares: Cmp(Const, a) is rewritten as Cmp(a, Const).
+
+define zeroext i32 @functionTest(<2 x double> %a, <4 x double> %b, <8 x double> %c, <4 x float> %d, <8 x float> %e, <16 x float> %f) local_unnamed_addr #0 {
+entry:
+  ; CHECK: %0 = tail call i8 @llvm.x86.avx512.mask.cmp.pd.128(<2 x double> zeroinitializer, <2 x double> , i32 1, i8 -1)
+  ; CHECK: %1 = tail call i8 @llvm.x86.avx512.mask.cmp.pd.128(<2 x double> %a, <2 x double> zeroinitializer, i32 14, i8 -1)
+  ; CHECK: %2 = tail call i8 @llvm.x86.avx512.mask.cmp.pd.256(<4 x double> %b, <4 x double> zeroinitializer, i32 10, i8 -1)
+  ; CHECK: %3 = tail call i8 @llvm.x86.avx512.mask.cmp.pd.512(<8 x double> %c, <8 x double> zeroinitializer, i32 11, i8 -1, i32 4)
+  ; CHECK: %4 = tail call i8 @llvm.x86.avx512.mask.cmp.ps.128(<4 x float> %d, <4 x float> zeroinitializer, i32 12, i8 -1)
+  ; CHECK: %5 = tail call i8 @llvm.x86.avx512.mask.cmp.ps.256(<8 x float> %e, <8 x float> zeroinitializer, i32 15, i8 -1)
+  ; CHECK: %6 = tail call i8 @llvm.x86.avx512.mask.cmp.sd(<2 x double> %a, <2 x double> , i32 16, i8 -1, i32 4)
+  ; CHECK: %7 = tail call i16 @llvm.x86.avx512.mask.cmp.ps.512(<16 x float> %f, <16 x float> zeroinitializer, i32 0, i16 -1, i32 4)
+  ; CHECK: %8 = tail call <4 x float> @llvm.x86.sse.cmp.ss(<4 x float> , <4 x float> %d, i8 3)
+  ; CHECK: %9 = tail call <2 x double> @llvm.x86.sse2.cmp.sd(<2 x double> , <2 x double> %a, i8 0)
+  ; Check that a CMP with two constant comparison operands is left unchanged.
+  %0 = tail call i8 @llvm.x86.avx512.mask.cmp.pd.128(<2 x double> zeroinitializer, <2 x double> , i32 1, i8 -1)
+
+  ; Check that canonicalization converts Cmp(const, a) into Cmp(a, const) with the correct comparison immediate.
+  %1 = tail call i8 @llvm.x86.avx512.mask.cmp.pd.128(<2 x double> zeroinitializer, <2 x double> %a, i32 1, i8 -1)
+  %2 = tail call i8 @llvm.x86.avx512.mask.cmp.pd.256(<4 x double> zeroinitializer, <4 x double> %b, i32 5, i8 -1)
+  %3 = tail call i8 @llvm.x86.avx512.mask.cmp.pd.512(<8 x double> zeroinitializer, <8 x double> %c, i32 11, i8 -1, i32 4)
+  %4 = tail call i8 @llvm.x86.avx512.mask.cmp.ps.128(<4 x float> zeroinitializer, <4 x float> %d, i32 12, i8 -1)
+  %5 = tail call i8 @llvm.x86.avx512.mask.cmp.ps.256(<8 x float> zeroinitializer, <8 x float> %e, i32 15, i8 -1)
+  %6 = tail call i8 @llvm.x86.avx512.mask.cmp.sd(<2 x double> zeroinitializer, <2 x double> %a, i32 16, i8 -1, i32 4)
+
+
+  ; Check that canonicalization converts Cmp(const, a) into Cmp(a, const) without changing the comparison immediate.
+  %7 = tail call i16 @llvm.x86.avx512.mask.cmp.ps.512(<16 x float> zeroinitializer, <16 x float> %f, i32 0, i16 -1, i32 4)
+
+  %8 = tail call <4 x float> @llvm.x86.sse.cmp.ss(<4 x float> zeroinitializer, <4 x float> %d, i8 3)
+  %9 = tail call <2 x double> @llvm.x86.sse2.cmp.sd(<2 x double> zeroinitializer, <2 x double> %a, i8 0)
+  %10 = tail call i32 @llvm.x86.avx512.vcvtss2si32(<4 x float> %8, i32 4)
+  %11 = tail call i32 @llvm.x86.avx512.vcvtsd2si32(<2 x double> %9, i32 4)
+
+  %and38 = and i8 %1, %0
+  %and39 = and i8 %and38, %2
+  %and2039 = and i8 %and39, %3
+  %and2240 = and i8 %and2039, %4
+  %and2441 = and i8 %and2240, %5
+  %and2442 = and i8 %and2441, %6
+  %conv27 = zext i8 %and2442 to i16
+  %and28 = and i16 %7, %conv27
+  %and1011 = and i32 %10, %11
+  %conv28 = zext i16 %and28 to i32
+  %andF = and i32 %and1011, %conv28
+  ret i32 %andF
+}
+
+declare i8 @llvm.x86.avx512.mask.cmp.pd.128(<2 x double>, <2 x double>, i32, i8)
+declare i8 @llvm.x86.avx512.mask.cmp.pd.256(<4 x double>, <4 x double>, i32, i8)
+declare i8 @llvm.x86.avx512.mask.cmp.pd.512(<8 x double>, <8 x double>, i32, i8, i32)
+declare i8 @llvm.x86.avx512.mask.cmp.ps.128(<4 x float>, <4 x float>, i32, i8)
+declare i8 @llvm.x86.avx512.mask.cmp.ps.256(<8 x float>, <8 x float>, i32, i8)
+declare i16 @llvm.x86.avx512.mask.cmp.ps.512(<16 x float>, <16 x float>, i32, i16, i32)
+declare i8 @llvm.x86.avx512.mask.cmp.sd(<2 x double>, <2 x double>, i32, i8, i32)
+declare <2 x double> @llvm.x86.sse2.cmp.sd(<2 x double>, <2 x double>, i8)
+declare <4 x float> @llvm.x86.sse.cmp.ss(<4 x float>, <4 x float>, i8)
+declare i32 @llvm.x86.avx512.vcvtss2si32(<4 x float>, i32)
+declare i32 @llvm.x86.avx512.vcvtsd2si32(<2 x double>, i32)
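
For reviewers, here is a minimal standalone sketch (not part of the patch) of the predicate remapping that X86CreateCanonicalCMP performs when it swaps operands. The helper name swapComparePredicate and the driver below are illustrative assumptions; the predicate mnemonics in the comments follow the standard VEX compare-immediate encoding.

#include <cassert>
#include <cstdint>
#include <cstdio>

// Returns the comparison immediate to use after rewriting Cmp(Const, a)
// as Cmp(a, Const). Only the "direction" predicates (low two bits 01 or 10)
// need a new immediate; flipping their four low bits swaps "less" and
// "greater". Symmetric predicates (EQ, NEQ, ORD, UNORD, ...) are unchanged.
static uint64_t swapComparePredicate(uint64_t Imm) {
  if ((Imm & 0x3) == 1 || (Imm & 0x3) == 2)
    return Imm ^ 0xF;
  return Imm;
}

int main() {
  assert(swapComparePredicate(0x01) == 0x0E); // LT_OS  -> GT_OS  (a < b  <=>  b > a)
  assert(swapComparePredicate(0x05) == 0x0A); // NLT_US -> NGT_US
  assert(swapComparePredicate(0x11) == 0x1E); // LT_OQ  -> GT_OQ  (bit 0x10 preserved)
  assert(swapComparePredicate(0x00) == 0x00); // EQ_OQ is symmetric, unchanged
  std::printf("predicate remapping checks passed\n");
  return 0;
}

This mirrors the test expectations above, e.g. the i32 1 immediate of %1 becoming i32 14 and the i32 5 immediate of %2 becoming i32 10 once the operands are swapped.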