Index: llvm/lib/Target/X86/X86InstrFMA.td =================================================================== --- llvm/lib/Target/X86/X86InstrFMA.td +++ llvm/lib/Target/X86/X86InstrFMA.td @@ -60,27 +60,47 @@ string OpcodeStr, string PackTy, PatFrag MemFrag128, PatFrag MemFrag256, SDNode Op, ValueType OpTy128, ValueType OpTy256> { - // For 213, both the register and memory variant are commutable. - // Indeed, the commutable operands are 1 and 2 and both live in registers - // for both variants. +let hasSideEffects = 0 in { + // For 213, both the register and memory variants are commutable. + // For the register form the commutable operands are 1, 2 and 3. + // For the memory variant the folded operand must be in 3. Thus, + // in that case, only the operands 1 and 2 can be swapped. + // Commuting some of the operands may require an opcode change: + // operands 1 and 2 (memory & register forms): *213* --> *213* (no changes); + // operands 1 and 3 (register forms only): *213* --> *231*; + // operands 2 and 3 (register forms only): *213* --> *132*. defm r213 : fma3p_rm; -let hasSideEffects = 0 in { + // For 132, both the register and memory variants are commutable. + // For the register form the commutable operands are 1, 2 and 3. + // For the memory variant the folded operand must be in 3. Thus, + // in that case, only the operands 1 and 2 can be swapped. + // Commuting some of the operands may require an opcode change: + // operands 1 and 2 (memory & register forms): *132* --> *231*; + // operands 1 and 3 (register forms only): *132* --> *132* (no changes); + // operands 2 and 3 (register forms only): *132* --> *213*. defm r132 : fma3p_rm; - // For 231, only the register variant is commutable. + MemFrag128, MemFrag256, OpTy128, OpTy256, + /* IsRVariantCommutable */ 1, + /* IsMVariantCommutable */ 1>; + // For 231, both the register and memory variants are commutable. + // For the register form the commutable operands are 1, 2 and 3.
// For the memory variant the folded operand must be in 3. Thus, - // in that case, it cannot be swapped with 2. + // in that case, only the operands 1 and 2 can be swapped. + // Commuting some of the operands may require an opcode change: + // operands 1 and 2 (memory & register forms): *231* --> *132*; + // operands 1 and 3 (register forms only): *231* --> *213*; + // operands 2 and 3 (register forms only): *231* --> *231* (no changes). defm r231 : fma3p_rm; + /* IsMVariantCommutable */ 1>; } // hasSideEffects = 0 } @@ -156,23 +176,54 @@ X86MemOperand x86memop, Operand memop, PatFrag mem_frag, ComplexPattern mem_cpat> { let hasSideEffects = 0 in { + // For 132, both the register and memory variants are commutable. + // For the register form the commutable operands are 1, 2 and 3. + // For the memory variant the folded operand must be in 3. Thus, + // in that case, only the operands 1 and 2 can be swapped. + // Commuting some of the operands may require an opcode change: + // operands 1 and 2 (memory & register forms): *132* --> *231*; + // operands 1 and 3 (register forms only): *132* --> *132* (no changes); + // operands 2 and 3 (register forms only): *132* --> *213*. + // Commuting operand 1 with any other operand changes the upper bits + // of the FMA result. Thus, it is only valid when it is known that + // only the lowest element of the result is used. defm r132 : fma3s_rm; - // See the other defm of r231 for the explanation regarding the - // commutable flags. + x86memop, RC, OpVT, mem_frag, + /* IsRVariantCommutable */ 1, + /* IsMVariantCommutable */ 1>; + // For 231, both the register and memory variants are commutable. + // For the register form the commutable operands are 1, 2 and 3. + // For the memory variant the folded operand must be in 3. Thus, + // in that case, only the operands 1 and 2 can be swapped.
+ // Commuting some of the operands may require an opcode change: + // operands 1 and 2 (memory & register forms): *231* --> *132*; + // operands 1 and 3 (register forms only): *231* --> *213*; + // operands 2 and 3 (register forms only): *231* --> *231* (no changes). + // Commuting operand 1 with any other operand changes the upper bits + // of the FMA result. Thus, it is only valid when it is known that + // only the lowest element of the result is used. defm r231 : fma3s_rm; -} + /* IsMVariantCommutable */ 1>; -// See the other defm of r213 for the explanation regarding the -// commutable flags. -defm r213 : fma3s_rm; + // For 213, both the register and memory variants are commutable. + // For the register form the commutable operands are 1, 2 and 3. + // For the memory variant the folded operand must be in 3. Thus, + // in that case, only the operands 1 and 2 can be swapped. + // Commuting some of the operands may require an opcode change: + // operands 1 and 2 (memory & register forms): *213* --> *213* (no changes); + // operands 1 and 3 (register forms only): *213* --> *231*; + // operands 2 and 3 (register forms only): *213* --> *132*. + // Commuting operand 1 with any other operand changes the upper bits + // of the FMA result. Thus, it is only valid when it is known that + // only the lowest element of the result is used. + defm r213 : fma3s_rm; +} } multiclass fma3s opc132, bits<8> opc213, bits<8> opc231, Index: llvm/lib/Target/X86/X86InstrInfo.h =================================================================== --- llvm/lib/Target/X86/X86InstrInfo.h +++ llvm/lib/Target/X86/X86InstrInfo.h @@ -264,6 +264,46 @@ bool findCommutedOpIndices(MachineInstr *MI, unsigned &SrcOpIdx1, unsigned &SrcOpIdx2) const override; + /// Returns true if the routine could find two commutable operands + /// in the given FMA instruction. Otherwise, returns false. + /// + /// \p SrcOpIdx1 and \p SrcOpIdx2 are INPUT and OUTPUT arguments.
+ /// The output indices of the commuted operands are returned in these + /// arguments. Also, the input values of these arguments may be preset either + /// to indices of operands that must be commuted or be equal to a special + /// value 'CommuteAnyOperandIndex' which means that the corresponding + /// operand index is not set and this method is free to pick any of the + /// available commutable operands. + /// + /// For example, calling this method this way: + /// unsigned Idx1 = 1, Idx2 = CommuteAnyOperandIndex; + /// findFMA3CommutedOpIndices(MI, Idx1, Idx2); + /// can be interpreted as a query asking if the operand #1 can be swapped + /// with any other available operand (e.g. operand #2, operand #3, etc.). + /// + /// Commuting the operands may require changing the opcode of \p MI. + /// For example, commuting the operands #1 and #3 in the following FMA + /// FMA213 #1, #2, #3 + /// results in an instruction with an adjusted opcode: + /// FMA231 #3, #2, #1 + bool findFMA3CommutedOpIndices(MachineInstr *MI, + unsigned &SrcOpIdx1, + unsigned &SrcOpIdx2) const; + + /// Returns an adjusted FMA opcode that must be used in an FMA instruction that + /// performs the same computations as the given MI but which has the operands + /// \p SrcOpIdx1 and \p SrcOpIdx2 commuted. + /// It may return 0 if it is unsafe to commute the operands. + /// + /// The returned FMA opcode may differ from the opcode in the given \p MI. + /// For example, commuting the operands #1 and #3 in the following FMA + /// FMA213 #1, #2, #3 + /// results in an instruction with an adjusted opcode: + /// FMA231 #3, #2, #1 + unsigned getFMA3OpcodeToCommuteOperands(MachineInstr *MI, + unsigned SrcOpIdx1, + unsigned SrcOpIdx2) const; + + // Branch analysis.
bool isUnpredicatedTerminator(const MachineInstr* MI) const override; bool AnalyzeBranch(MachineBasicBlock &MBB, MachineBasicBlock *&TBB, Index: llvm/lib/Target/X86/X86InstrInfo.cpp =================================================================== --- llvm/lib/Target/X86/X86InstrInfo.cpp +++ llvm/lib/Target/X86/X86InstrInfo.cpp @@ -2917,6 +2917,121 @@ return NewMI; } +/// Returns true if the given instruction opcode is FMA3. +/// Otherwise, returns false. +static bool isFMA3(unsigned Opcode) { + switch (Opcode) { + case X86::VFMADDSDr132r: case X86::VFMADDSDr132m: + case X86::VFMADDSSr132r: case X86::VFMADDSSr132m: + case X86::VFMSUBSDr132r: case X86::VFMSUBSDr132m: + case X86::VFMSUBSSr132r: case X86::VFMSUBSSr132m: + case X86::VFNMADDSDr132r: case X86::VFNMADDSDr132m: + case X86::VFNMADDSSr132r: case X86::VFNMADDSSr132m: + case X86::VFNMSUBSDr132r: case X86::VFNMSUBSDr132m: + case X86::VFNMSUBSSr132r: case X86::VFNMSUBSSr132m: + + case X86::VFMADDSDr213r: case X86::VFMADDSDr213m: + case X86::VFMADDSSr213r: case X86::VFMADDSSr213m: + case X86::VFMSUBSDr213r: case X86::VFMSUBSDr213m: + case X86::VFMSUBSSr213r: case X86::VFMSUBSSr213m: + case X86::VFNMADDSDr213r: case X86::VFNMADDSDr213m: + case X86::VFNMADDSSr213r: case X86::VFNMADDSSr213m: + case X86::VFNMSUBSDr213r: case X86::VFNMSUBSDr213m: + case X86::VFNMSUBSSr213r: case X86::VFNMSUBSSr213m: + + case X86::VFMADDSDr231r: case X86::VFMADDSDr231m: + case X86::VFMADDSSr231r: case X86::VFMADDSSr231m: + case X86::VFMSUBSDr231r: case X86::VFMSUBSDr231m: + case X86::VFMSUBSSr231r: case X86::VFMSUBSSr231m: + case X86::VFNMADDSDr231r: case X86::VFNMADDSDr231m: + case X86::VFNMADDSSr231r: case X86::VFNMADDSSr231m: + case X86::VFNMSUBSDr231r: case X86::VFNMSUBSDr231m: + case X86::VFNMSUBSSr231r: case X86::VFNMSUBSSr231m: + + case X86::VFMADDSUBPDr132r: case X86::VFMADDSUBPDr132m: + case X86::VFMADDSUBPSr132r: case X86::VFMADDSUBPSr132m: + case X86::VFMSUBADDPDr132r: case X86::VFMSUBADDPDr132m: + case 
X86::VFMSUBADDPSr132r: case X86::VFMSUBADDPSr132m: + case X86::VFMADDSUBPDr132rY: case X86::VFMADDSUBPDr132mY: + case X86::VFMADDSUBPSr132rY: case X86::VFMADDSUBPSr132mY: + case X86::VFMSUBADDPDr132rY: case X86::VFMSUBADDPDr132mY: + case X86::VFMSUBADDPSr132rY: case X86::VFMSUBADDPSr132mY: + + case X86::VFMADDPDr132r: case X86::VFMADDPDr132m: + case X86::VFMADDPSr132r: case X86::VFMADDPSr132m: + case X86::VFMSUBPDr132r: case X86::VFMSUBPDr132m: + case X86::VFMSUBPSr132r: case X86::VFMSUBPSr132m: + case X86::VFNMADDPDr132r: case X86::VFNMADDPDr132m: + case X86::VFNMADDPSr132r: case X86::VFNMADDPSr132m: + case X86::VFNMSUBPDr132r: case X86::VFNMSUBPDr132m: + case X86::VFNMSUBPSr132r: case X86::VFNMSUBPSr132m: + case X86::VFMADDPDr132rY: case X86::VFMADDPDr132mY: + case X86::VFMADDPSr132rY: case X86::VFMADDPSr132mY: + case X86::VFMSUBPDr132rY: case X86::VFMSUBPDr132mY: + case X86::VFMSUBPSr132rY: case X86::VFMSUBPSr132mY: + case X86::VFNMADDPDr132rY: case X86::VFNMADDPDr132mY: + case X86::VFNMADDPSr132rY: case X86::VFNMADDPSr132mY: + case X86::VFNMSUBPDr132rY: case X86::VFNMSUBPDr132mY: + case X86::VFNMSUBPSr132rY: case X86::VFNMSUBPSr132mY: + + case X86::VFMADDSUBPDr213r: case X86::VFMADDSUBPDr213m: + case X86::VFMADDSUBPSr213r: case X86::VFMADDSUBPSr213m: + case X86::VFMSUBADDPDr213r: case X86::VFMSUBADDPDr213m: + case X86::VFMSUBADDPSr213r: case X86::VFMSUBADDPSr213m: + case X86::VFMADDSUBPDr213rY: case X86::VFMADDSUBPDr213mY: + case X86::VFMADDSUBPSr213rY: case X86::VFMADDSUBPSr213mY: + case X86::VFMSUBADDPDr213rY: case X86::VFMSUBADDPDr213mY: + case X86::VFMSUBADDPSr213rY: case X86::VFMSUBADDPSr213mY: + + case X86::VFMADDPDr213r: case X86::VFMADDPDr213m: + case X86::VFMADDPSr213r: case X86::VFMADDPSr213m: + case X86::VFMSUBPDr213r: case X86::VFMSUBPDr213m: + case X86::VFMSUBPSr213r: case X86::VFMSUBPSr213m: + case X86::VFNMADDPDr213r: case X86::VFNMADDPDr213m: + case X86::VFNMADDPSr213r: case X86::VFNMADDPSr213m: + case X86::VFNMSUBPDr213r: case 
X86::VFNMSUBPDr213m: + case X86::VFNMSUBPSr213r: case X86::VFNMSUBPSr213m: + case X86::VFMADDPDr213rY: case X86::VFMADDPDr213mY: + case X86::VFMADDPSr213rY: case X86::VFMADDPSr213mY: + case X86::VFMSUBPDr213rY: case X86::VFMSUBPDr213mY: + case X86::VFMSUBPSr213rY: case X86::VFMSUBPSr213mY: + case X86::VFNMADDPDr213rY: case X86::VFNMADDPDr213mY: + case X86::VFNMADDPSr213rY: case X86::VFNMADDPSr213mY: + case X86::VFNMSUBPDr213rY: case X86::VFNMSUBPDr213mY: + case X86::VFNMSUBPSr213rY: case X86::VFNMSUBPSr213mY: + + case X86::VFMADDSUBPDr231r: case X86::VFMADDSUBPDr231m: + case X86::VFMADDSUBPSr231r: case X86::VFMADDSUBPSr231m: + case X86::VFMSUBADDPDr231r: case X86::VFMSUBADDPDr231m: + case X86::VFMSUBADDPSr231r: case X86::VFMSUBADDPSr231m: + case X86::VFMADDSUBPDr231rY: case X86::VFMADDSUBPDr231mY: + case X86::VFMADDSUBPSr231rY: case X86::VFMADDSUBPSr231mY: + case X86::VFMSUBADDPDr231rY: case X86::VFMSUBADDPDr231mY: + case X86::VFMSUBADDPSr231rY: case X86::VFMSUBADDPSr231mY: + + case X86::VFMADDPDr231r: case X86::VFMADDPDr231m: + case X86::VFMADDPSr231r: case X86::VFMADDPSr231m: + case X86::VFMSUBPDr231r: case X86::VFMSUBPDr231m: + case X86::VFMSUBPSr231r: case X86::VFMSUBPSr231m: + case X86::VFNMADDPDr231r: case X86::VFNMADDPDr231m: + case X86::VFNMADDPSr231r: case X86::VFNMADDPSr231m: + case X86::VFNMSUBPDr231r: case X86::VFNMSUBPDr231m: + case X86::VFNMSUBPSr231r: case X86::VFNMSUBPSr231m: + case X86::VFMADDPDr231rY: case X86::VFMADDPDr231mY: + case X86::VFMADDPSr231rY: case X86::VFMADDPSr231mY: + case X86::VFMSUBPDr231rY: case X86::VFMSUBPDr231mY: + case X86::VFMSUBPSr231rY: case X86::VFMSUBPSr231mY: + case X86::VFNMADDPDr231rY: case X86::VFNMADDPDr231mY: + case X86::VFNMADDPSr231rY: case X86::VFNMADDPSr231mY: + case X86::VFNMSUBPDr231rY: case X86::VFNMSUBPDr231mY: + case X86::VFNMSUBPSr231rY: case X86::VFNMSUBPSr231mY: + return true; + default: + return false; + } + llvm_unreachable("Opcode not handled by the switch"); +} + MachineInstr 
*X86InstrInfo::commuteInstructionImpl(MachineInstr *MI, bool NewMI, unsigned OpIdx1, @@ -3127,10 +3242,232 @@ // Fallthrough intended. } default: + if (isFMA3(MI->getOpcode())) { + unsigned Opc = getFMA3OpcodeToCommuteOperands(MI, OpIdx1, OpIdx2); + if (Opc == 0) + return nullptr; + if (NewMI) { + MachineFunction &MF = *MI->getParent()->getParent(); + MI = MF.CloneMachineInstr(MI); + NewMI = false; + } + MI->setDesc(get(Opc)); + } return TargetInstrInfo::commuteInstructionImpl(MI, NewMI, OpIdx1, OpIdx2); } } +bool X86InstrInfo::findFMA3CommutedOpIndices(MachineInstr *MI, + unsigned &SrcOpIdx1, + unsigned &SrcOpIdx2) const { + + unsigned RegOpsNum = isMem(MI, 3) ? 2 : 3; + + // Only the first RegOpsNum operands are commutable. + // Also, the value 'CommuteAnyOperandIndex' is valid here as it means + // that the operand is not specified/fixed. + if (SrcOpIdx1 != CommuteAnyOperandIndex && + (SrcOpIdx1 < 1 || SrcOpIdx1 > RegOpsNum)) + return false; + if (SrcOpIdx2 != CommuteAnyOperandIndex && + (SrcOpIdx2 < 1 || SrcOpIdx2 > RegOpsNum)) + return false; + + // Look for two different register operands assumed to be commutable + // regardless of the FMA opcode. The FMA opcode is adjusted later. + if (SrcOpIdx1 == CommuteAnyOperandIndex || + SrcOpIdx2 == CommuteAnyOperandIndex) { + unsigned CommutableOpIdx1 = SrcOpIdx1; + unsigned CommutableOpIdx2 = SrcOpIdx2; + + // At least one of the operands to be commuted is not specified and + // this method is free to choose appropriate commutable operands. + if (SrcOpIdx1 == SrcOpIdx2) + // Neither operand is fixed. By default, set one of the commutable + // operands to the last register operand of the instruction. + CommutableOpIdx2 = RegOpsNum; + else if (SrcOpIdx2 == CommuteAnyOperandIndex) + // Only SrcOpIdx1 is fixed, so use it as CommutableOpIdx2. + CommutableOpIdx2 = SrcOpIdx1; + + // CommutableOpIdx2 is well defined now. Let's choose another commutable + // operand and assign its index to CommutableOpIdx1.
+ unsigned Op2Reg = MI->getOperand(CommutableOpIdx2).getReg(); + for (CommutableOpIdx1 = RegOpsNum; CommutableOpIdx1 > 0; CommutableOpIdx1--) { + // The commuted operands must have different registers. + // Otherwise, the commute transformation does not change anything and + // is therefore useless. + if (Op2Reg != MI->getOperand(CommutableOpIdx1).getReg()) + break; + } + + // No appropriate commutable operands were found. + if (CommutableOpIdx1 == 0) + return false; + + // Assign the found pair of commutable indices to SrcOpIdx1 and SrcOpIdx2 + // to return those values. + if (!fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, + CommutableOpIdx1, CommutableOpIdx2)) + return false; + } + + // Check if we can adjust the opcode to preserve the semantics when + // the register operands are commuted. + return getFMA3OpcodeToCommuteOperands(MI, SrcOpIdx1, SrcOpIdx2) != 0; +} + +unsigned X86InstrInfo::getFMA3OpcodeToCommuteOperands(MachineInstr *MI, + unsigned SrcOpIdx1, + unsigned SrcOpIdx2) const { + int RetOpc = 0; + int Opc = MI->getOpcode(); + + // Define the array that holds FMA opcodes in groups + // of 3 opcodes (132, 213, 231) in each group.
+ static const unsigned OpcodeGroups[][3] = { + { X86::VFMADDSSr132r, X86::VFMADDSSr213r, X86::VFMADDSSr231r }, + { X86::VFMADDSDr132r, X86::VFMADDSDr213r, X86::VFMADDSDr231r }, + { X86::VFMADDPSr132r, X86::VFMADDPSr213r, X86::VFMADDPSr231r }, + { X86::VFMADDPDr132r, X86::VFMADDPDr213r, X86::VFMADDPDr231r }, + { X86::VFMADDPSr132rY, X86::VFMADDPSr213rY, X86::VFMADDPSr231rY }, + { X86::VFMADDPDr132rY, X86::VFMADDPDr213rY, X86::VFMADDPDr231rY }, + { X86::VFMADDSSr132m, X86::VFMADDSSr213m, X86::VFMADDSSr231m }, + { X86::VFMADDSDr132m, X86::VFMADDSDr213m, X86::VFMADDSDr231m }, + { X86::VFMADDPSr132m, X86::VFMADDPSr213m, X86::VFMADDPSr231m }, + { X86::VFMADDPDr132m, X86::VFMADDPDr213m, X86::VFMADDPDr231m }, + { X86::VFMADDPSr132mY, X86::VFMADDPSr213mY, X86::VFMADDPSr231mY }, + { X86::VFMADDPDr132mY, X86::VFMADDPDr213mY, X86::VFMADDPDr231mY }, + + { X86::VFMSUBSSr132r, X86::VFMSUBSSr213r, X86::VFMSUBSSr231r }, + { X86::VFMSUBSDr132r, X86::VFMSUBSDr213r, X86::VFMSUBSDr231r }, + { X86::VFMSUBPSr132r, X86::VFMSUBPSr213r, X86::VFMSUBPSr231r }, + { X86::VFMSUBPDr132r, X86::VFMSUBPDr213r, X86::VFMSUBPDr231r }, + { X86::VFMSUBPSr132rY, X86::VFMSUBPSr213rY, X86::VFMSUBPSr231rY }, + { X86::VFMSUBPDr132rY, X86::VFMSUBPDr213rY, X86::VFMSUBPDr231rY }, + { X86::VFMSUBSSr132m, X86::VFMSUBSSr213m, X86::VFMSUBSSr231m }, + { X86::VFMSUBSDr132m, X86::VFMSUBSDr213m, X86::VFMSUBSDr231m }, + { X86::VFMSUBPSr132m, X86::VFMSUBPSr213m, X86::VFMSUBPSr231m }, + { X86::VFMSUBPDr132m, X86::VFMSUBPDr213m, X86::VFMSUBPDr231m }, + { X86::VFMSUBPSr132mY, X86::VFMSUBPSr213mY, X86::VFMSUBPSr231mY }, + { X86::VFMSUBPDr132mY, X86::VFMSUBPDr213mY, X86::VFMSUBPDr231mY }, + + { X86::VFNMADDSSr132r, X86::VFNMADDSSr213r, X86::VFNMADDSSr231r }, + { X86::VFNMADDSDr132r, X86::VFNMADDSDr213r, X86::VFNMADDSDr231r }, + { X86::VFNMADDPSr132r, X86::VFNMADDPSr213r, X86::VFNMADDPSr231r }, + { X86::VFNMADDPDr132r, X86::VFNMADDPDr213r, X86::VFNMADDPDr231r }, + { X86::VFNMADDPSr132rY, X86::VFNMADDPSr213rY, 
X86::VFNMADDPSr231rY }, + { X86::VFNMADDPDr132rY, X86::VFNMADDPDr213rY, X86::VFNMADDPDr231rY }, + { X86::VFNMADDSSr132m, X86::VFNMADDSSr213m, X86::VFNMADDSSr231m }, + { X86::VFNMADDSDr132m, X86::VFNMADDSDr213m, X86::VFNMADDSDr231m }, + { X86::VFNMADDPSr132m, X86::VFNMADDPSr213m, X86::VFNMADDPSr231m }, + { X86::VFNMADDPDr132m, X86::VFNMADDPDr213m, X86::VFNMADDPDr231m }, + { X86::VFNMADDPSr132mY, X86::VFNMADDPSr213mY, X86::VFNMADDPSr231mY }, + { X86::VFNMADDPDr132mY, X86::VFNMADDPDr213mY, X86::VFNMADDPDr231mY }, + + { X86::VFNMSUBSSr132r, X86::VFNMSUBSSr213r, X86::VFNMSUBSSr231r }, + { X86::VFNMSUBSDr132r, X86::VFNMSUBSDr213r, X86::VFNMSUBSDr231r }, + { X86::VFNMSUBPSr132r, X86::VFNMSUBPSr213r, X86::VFNMSUBPSr231r }, + { X86::VFNMSUBPDr132r, X86::VFNMSUBPDr213r, X86::VFNMSUBPDr231r }, + { X86::VFNMSUBPSr132rY, X86::VFNMSUBPSr213rY, X86::VFNMSUBPSr231rY }, + { X86::VFNMSUBPDr132rY, X86::VFNMSUBPDr213rY, X86::VFNMSUBPDr231rY }, + { X86::VFNMSUBSSr132m, X86::VFNMSUBSSr213m, X86::VFNMSUBSSr231m }, + { X86::VFNMSUBSDr132m, X86::VFNMSUBSDr213m, X86::VFNMSUBSDr231m }, + { X86::VFNMSUBPSr132m, X86::VFNMSUBPSr213m, X86::VFNMSUBPSr231m }, + { X86::VFNMSUBPDr132m, X86::VFNMSUBPDr213m, X86::VFNMSUBPDr231m }, + { X86::VFNMSUBPSr132mY, X86::VFNMSUBPSr213mY, X86::VFNMSUBPSr231mY }, + { X86::VFNMSUBPDr132mY, X86::VFNMSUBPDr213mY, X86::VFNMSUBPDr231mY }, + + { X86::VFMADDSUBPSr132r, X86::VFMADDSUBPSr213r, X86::VFMADDSUBPSr231r }, + { X86::VFMADDSUBPDr132r, X86::VFMADDSUBPDr213r, X86::VFMADDSUBPDr231r }, + { X86::VFMADDSUBPSr132rY, X86::VFMADDSUBPSr213rY, X86::VFMADDSUBPSr231rY }, + { X86::VFMADDSUBPDr132rY, X86::VFMADDSUBPDr213rY, X86::VFMADDSUBPDr231rY }, + { X86::VFMADDSUBPSr132m, X86::VFMADDSUBPSr213m, X86::VFMADDSUBPSr231m }, + { X86::VFMADDSUBPDr132m, X86::VFMADDSUBPDr213m, X86::VFMADDSUBPDr231m }, + { X86::VFMADDSUBPSr132mY, X86::VFMADDSUBPSr213mY, X86::VFMADDSUBPSr231mY }, + { X86::VFMADDSUBPDr132mY, X86::VFMADDSUBPDr213mY, X86::VFMADDSUBPDr231mY }, + + { 
X86::VFMSUBADDPSr132r, X86::VFMSUBADDPSr213r, X86::VFMSUBADDPSr231r }, + { X86::VFMSUBADDPDr132r, X86::VFMSUBADDPDr213r, X86::VFMSUBADDPDr231r }, + { X86::VFMSUBADDPSr132rY, X86::VFMSUBADDPSr213rY, X86::VFMSUBADDPSr231rY }, + { X86::VFMSUBADDPDr132rY, X86::VFMSUBADDPDr213rY, X86::VFMSUBADDPDr231rY }, + { X86::VFMSUBADDPSr132m, X86::VFMSUBADDPSr213m, X86::VFMSUBADDPSr231m }, + { X86::VFMSUBADDPDr132m, X86::VFMSUBADDPDr213m, X86::VFMSUBADDPDr231m }, + { X86::VFMSUBADDPSr132mY, X86::VFMSUBADDPSr213mY, X86::VFMSUBADDPSr231mY }, + { X86::VFMSUBADDPDr132mY, X86::VFMSUBADDPDr213mY, X86::VFMSUBADDPDr231mY } + }; + const unsigned Form132Index = 0; + const unsigned Form213Index = 1; + const unsigned Form231Index = 2; + const unsigned LastFormIndex = 2; + + unsigned OpcodeGroupsNum = sizeof(OpcodeGroups) / sizeof(OpcodeGroups[0]); + unsigned GroupIndex, FormIndex; + for (GroupIndex = 0; GroupIndex < OpcodeGroupsNum; GroupIndex++) { + for (FormIndex = 0; FormIndex < 3; FormIndex++) { + if (OpcodeGroups[GroupIndex][FormIndex] == Opc) + break; + } + if (FormIndex <= LastFormIndex) + // Found the input opcode in the table. + break; + } + + // Input opcode does not match with any of the opcodes from the table. + if (FormIndex > LastFormIndex) + return 0; + + // Put the lowest index to SrcOpIdx1 to simplify the checks below. + if (SrcOpIdx1 > SrcOpIdx2) { + std::swap(SrcOpIdx1, SrcOpIdx2); + } + + // Find the adjusted FMA opcode to preserve the operation semantics after + // commuting the operands. 
+ if (SrcOpIdx1 == 1 && SrcOpIdx2 == 2) { + if (FormIndex == Form132Index) + // (A * b + C) ==> (A * b + C); + // FMA132 A, C, b; ==> FMA231 C, A, b; + RetOpc = OpcodeGroups[GroupIndex][Form231Index]; + else if (FormIndex == Form213Index) + // (A * B + c) ==> (B * A + c); + // FMA213 B, A, c; ==> FMA213 A, B, c; + RetOpc = Opc; + else + // (A * b + C) ==> (A * b + C); + // FMA231 C, A, b; ==> FMA132 A, C, b; + RetOpc = OpcodeGroups[GroupIndex][Form132Index]; + } else if (SrcOpIdx1 == 1 && SrcOpIdx2 == 3) { + if (FormIndex == Form132Index) + // (A * B + c) ==> (B * A + c); + // FMA132 A, c, B; ==> FMA132 B, c, A; + RetOpc = Opc; + else if (FormIndex == Form213Index) + // (a * B + C) ==> (a * B + C); + // FMA213 B, a, C; ==> FMA231 C, a, B; + RetOpc = OpcodeGroups[GroupIndex][Form231Index]; + else + // (a * B + C) ==> (a * B + C); + // FMA231 C, a, B; ==> FMA213 B, a, C; + RetOpc = OpcodeGroups[GroupIndex][Form213Index]; + } else if (SrcOpIdx1 == 2 && SrcOpIdx2 == 3) { + if (FormIndex == Form132Index) + // (a * B + C) ==> (B * a + C); + // FMA132 a, C, B; ==> FMA213 a, B, C; + RetOpc = OpcodeGroups[GroupIndex][Form213Index]; + else if (FormIndex == Form213Index) + // (A * b + C) ==> (b * A + C); + // FMA213 b, A, C; ==> FMA132 b, C, A; + RetOpc = OpcodeGroups[GroupIndex][Form132Index]; + else + // (A * B + c) ==> (B * A + c); + // FMA231 c, A, B; ==> FMA231 c, B, A; + RetOpc = Opc; + } + + return RetOpc; +} + bool X86InstrInfo::findCommutedOpIndices(MachineInstr *MI, unsigned &SrcOpIdx1, unsigned &SrcOpIdx2) const { @@ -3155,34 +3492,9 @@ } return false; } - case X86::VFMADDPDr231r: - case X86::VFMADDPSr231r: - case X86::VFMADDSDr231r: - case X86::VFMADDSSr231r: - case X86::VFMSUBPDr231r: - case X86::VFMSUBPSr231r: - case X86::VFMSUBSDr231r: - case X86::VFMSUBSSr231r: - case X86::VFNMADDPDr231r: - case X86::VFNMADDPSr231r: - case X86::VFNMADDSDr231r: - case X86::VFNMADDSSr231r: - case X86::VFNMSUBPDr231r: - case X86::VFNMSUBPSr231r: - case X86::VFNMSUBSDr231r: - case
X86::VFNMSUBSSr231r: - case X86::VFMADDPDr231rY: - case X86::VFMADDPSr231rY: - case X86::VFMSUBPDr231rY: - case X86::VFMSUBPSr231rY: - case X86::VFNMADDPDr231rY: - case X86::VFNMADDPSr231rY: - case X86::VFNMSUBPDr231rY: - case X86::VFNMSUBPSr231rY: - // The indices of the commutable operands are 2 and 3. - // Assign them to the returned operand indices here. - return fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, 2, 3); default: + if (isFMA3(MI->getOpcode())) + return findFMA3CommutedOpIndices(MI, SrcOpIdx1, SrcOpIdx2); return TargetInstrInfo::findCommutedOpIndices(MI, SrcOpIdx1, SrcOpIdx2); } return false; Index: llvm/test/CodeGen/X86/fma-commute-x86.ll =================================================================== --- llvm/test/CodeGen/X86/fma-commute-x86.ll +++ llvm/test/CodeGen/X86/fma-commute-x86.ll @@ -0,0 +1,312 @@ +; RUN: llc < %s -mtriple=x86_64-pc-win32 -mcpu=core-avx2 | FileCheck %s +; RUN: llc < %s -mtriple=x86_64-pc-win32 -mattr=+fma,+fma4 | FileCheck %s +; RUN: llc < %s -mcpu=bdver2 -mtriple=x86_64-pc-win32 -mattr=-fma4 | FileCheck %s + +declare <4 x float> @llvm.x86.fma.vfmadd.ps(<4 x float>, <4 x float>, <4 x float>) nounwind readnone +define <4 x float> @test_x86_fmadd_baa_ps(<4 x float> %a, <4 x float> %b) { +; CHECK: fmadd132ps {{.*%r.*}}, %xmm0, %xmm0 + %res = call <4 x float> @llvm.x86.fma.vfmadd.ps(<4 x float> %b, <4 x float> %a, <4 x float> %a) nounwind + ret <4 x float> %res +} + +define <4 x float> @test_x86_fmadd_aba_ps(<4 x float> %a, <4 x float> %b) { +; CHECK: fmadd231ps {{.*%r.*}}, %xmm0, %xmm0 + %res = call <4 x float> @llvm.x86.fma.vfmadd.ps(<4 x float> %a, <4 x float> %b, <4 x float> %a) nounwind + ret <4 x float> %res +} + +define <4 x float> @test_x86_fmadd_bba_ps(<4 x float> %a, <4 x float> %b) { +; CHECK: fmadd213ps {{.*%r.*}}, %xmm0, %xmm0 + %res = call <4 x float> @llvm.x86.fma.vfmadd.ps(<4 x float> %b, <4 x float> %b, <4 x float> %a) nounwind + ret <4 x float> %res +} + +declare <8 x float> @llvm.x86.fma.vfmadd.ps.256(<8 x 
float>, <8 x float>, <8 x float>) nounwind readnone +define <8 x float> @test_x86_fmadd_baa_ps_y(<8 x float> %a, <8 x float> %b) { +; CHECK: fmadd132ps {{.*%r.*}}, %ymm0, %ymm0 + %res = call <8 x float> @llvm.x86.fma.vfmadd.ps.256(<8 x float> %b, <8 x float> %a, <8 x float> %a) nounwind + ret <8 x float> %res +} + +define <8 x float> @test_x86_fmadd_aba_ps_y(<8 x float> %a, <8 x float> %b) { +; CHECK: fmadd231ps {{.*%r.*}}, %ymm0, %ymm0 + %res = call <8 x float> @llvm.x86.fma.vfmadd.ps.256(<8 x float> %a, <8 x float> %b, <8 x float> %a) nounwind + ret <8 x float> %res +} + +define <8 x float> @test_x86_fmadd_bba_ps_y(<8 x float> %a, <8 x float> %b) { +; CHECK: fmadd213ps {{.*%r.*}}, %ymm0, %ymm0 + %res = call <8 x float> @llvm.x86.fma.vfmadd.ps.256(<8 x float> %b, <8 x float> %b, <8 x float> %a) nounwind + ret <8 x float> %res +} + +declare <2 x double> @llvm.x86.fma.vfmadd.pd(<2 x double>, <2 x double>, <2 x double>) nounwind readnone +define <2 x double> @test_x86_fmadd_baa_pd(<2 x double> %a, <2 x double> %b) { +; CHECK: fmadd132pd {{.*%r.*}}, %xmm0, %xmm0 + %res = call <2 x double> @llvm.x86.fma.vfmadd.pd(<2 x double> %b, <2 x double> %a, <2 x double> %a) nounwind + ret <2 x double> %res +} + +define <2 x double> @test_x86_fmadd_aba_pd(<2 x double> %a, <2 x double> %b) { +; CHECK: fmadd231pd {{.*%r.*}}, %xmm0, %xmm0 + %res = call <2 x double> @llvm.x86.fma.vfmadd.pd(<2 x double> %a, <2 x double> %b, <2 x double> %a) nounwind + ret <2 x double> %res +} + +define <2 x double> @test_x86_fmadd_bba_pd(<2 x double> %a, <2 x double> %b) { +; CHECK: fmadd213pd {{.*%r.*}}, %xmm0, %xmm0 + %res = call <2 x double> @llvm.x86.fma.vfmadd.pd(<2 x double> %b, <2 x double> %b, <2 x double> %a) nounwind + ret <2 x double> %res +} + +declare <4 x double> @llvm.x86.fma.vfmadd.pd.256(<4 x double>, <4 x double>, <4 x double>) nounwind readnone +define <4 x double> @test_x86_fmadd_baa_pd_y(<4 x double> %a, <4 x double> %b) { +; CHECK: fmadd132pd {{.*%r.*}}, %ymm0, %ymm0 + %res = call 
<4 x double> @llvm.x86.fma.vfmadd.pd.256(<4 x double> %b, <4 x double> %a, <4 x double> %a) nounwind + ret <4 x double> %res +} + +define <4 x double> @test_x86_fmadd_aba_pd_y(<4 x double> %a, <4 x double> %b) { +; CHECK: fmadd231pd {{.*%r.*}}, %ymm0, %ymm0 + %res = call <4 x double> @llvm.x86.fma.vfmadd.pd.256(<4 x double> %a, <4 x double> %b, <4 x double> %a) nounwind + ret <4 x double> %res +} + +define <4 x double> @test_x86_fmadd_bba_pd_y(<4 x double> %a, <4 x double> %b) { +; CHECK: fmadd213pd {{.*%r.*}}, %ymm0, %ymm0 + %res = call <4 x double> @llvm.x86.fma.vfmadd.pd.256(<4 x double> %b, <4 x double> %b, <4 x double> %a) nounwind + ret <4 x double> %res +} + + + +declare <4 x float> @llvm.x86.fma.vfnmadd.ps(<4 x float>, <4 x float>, <4 x float>) nounwind readnone +define <4 x float> @test_x86_fnmadd_baa_ps(<4 x float> %a, <4 x float> %b) { +; CHECK: fnmadd132ps {{.*%r.*}}, %xmm0, %xmm0 + %res = call <4 x float> @llvm.x86.fma.vfnmadd.ps(<4 x float> %b, <4 x float> %a, <4 x float> %a) nounwind + ret <4 x float> %res +} + +define <4 x float> @test_x86_fnmadd_aba_ps(<4 x float> %a, <4 x float> %b) { +; CHECK: fnmadd231ps {{.*%r.*}}, %xmm0, %xmm0 + %res = call <4 x float> @llvm.x86.fma.vfnmadd.ps(<4 x float> %a, <4 x float> %b, <4 x float> %a) nounwind + ret <4 x float> %res +} + +define <4 x float> @test_x86_fnmadd_bba_ps(<4 x float> %a, <4 x float> %b) { +; CHECK: fnmadd213ps {{.*%r.*}}, %xmm0, %xmm0 + %res = call <4 x float> @llvm.x86.fma.vfnmadd.ps(<4 x float> %b, <4 x float> %b, <4 x float> %a) nounwind + ret <4 x float> %res +} + +declare <8 x float> @llvm.x86.fma.vfnmadd.ps.256(<8 x float>, <8 x float>, <8 x float>) nounwind readnone +define <8 x float> @test_x86_fnmadd_baa_ps_y(<8 x float> %a, <8 x float> %b) { +; CHECK: fnmadd132ps {{.*%r.*}}, %ymm0, %ymm0 + %res = call <8 x float> @llvm.x86.fma.vfnmadd.ps.256(<8 x float> %b, <8 x float> %a, <8 x float> %a) nounwind + ret <8 x float> %res +} + +define <8 x float> @test_x86_fnmadd_aba_ps_y(<8 x float> %a, 
<8 x float> %b) { +; CHECK: fnmadd231ps {{.*%r.*}}, %ymm0, %ymm0 + %res = call <8 x float> @llvm.x86.fma.vfnmadd.ps.256(<8 x float> %a, <8 x float> %b, <8 x float> %a) nounwind + ret <8 x float> %res +} + +define <8 x float> @test_x86_fnmadd_bba_ps_y(<8 x float> %a, <8 x float> %b) { +; CHECK: fnmadd213ps {{.*%r.*}}, %ymm0, %ymm0 + %res = call <8 x float> @llvm.x86.fma.vfnmadd.ps.256(<8 x float> %b, <8 x float> %b, <8 x float> %a) nounwind + ret <8 x float> %res +} + +declare <2 x double> @llvm.x86.fma.vfnmadd.pd(<2 x double>, <2 x double>, <2 x double>) nounwind readnone +define <2 x double> @test_x86_fnmadd_baa_pd(<2 x double> %a, <2 x double> %b) { +; CHECK: fnmadd132pd {{.*%r.*}}, %xmm0, %xmm0 + %res = call <2 x double> @llvm.x86.fma.vfnmadd.pd(<2 x double> %b, <2 x double> %a, <2 x double> %a) nounwind + ret <2 x double> %res +} + +define <2 x double> @test_x86_fnmadd_aba_pd(<2 x double> %a, <2 x double> %b) { +; CHECK: fnmadd231pd {{.*%r.*}}, %xmm0, %xmm0 + %res = call <2 x double> @llvm.x86.fma.vfnmadd.pd(<2 x double> %a, <2 x double> %b, <2 x double> %a) nounwind + ret <2 x double> %res +} + +define <2 x double> @test_x86_fnmadd_bba_pd(<2 x double> %a, <2 x double> %b) { +; CHECK: fnmadd213pd {{.*%r.*}}, %xmm0, %xmm0 + %res = call <2 x double> @llvm.x86.fma.vfnmadd.pd(<2 x double> %b, <2 x double> %b, <2 x double> %a) nounwind + ret <2 x double> %res +} + +declare <4 x double> @llvm.x86.fma.vfnmadd.pd.256(<4 x double>, <4 x double>, <4 x double>) nounwind readnone +define <4 x double> @test_x86_fnmadd_baa_pd_y(<4 x double> %a, <4 x double> %b) { +; CHECK: fnmadd132pd {{.*%r.*}}, %ymm0, %ymm0 + %res = call <4 x double> @llvm.x86.fma.vfnmadd.pd.256(<4 x double> %b, <4 x double> %a, <4 x double> %a) nounwind + ret <4 x double> %res +} + +define <4 x double> @test_x86_fnmadd_aba_pd_y(<4 x double> %a, <4 x double> %b) { +; CHECK: fnmadd231pd {{.*%r.*}}, %ymm0, %ymm0 + %res = call <4 x double> @llvm.x86.fma.vfnmadd.pd.256(<4 x double> %a, <4 x double> %b, <4 x 
double> %a) nounwind + ret <4 x double> %res +} + +define <4 x double> @test_x86_fnmadd_bba_pd_y(<4 x double> %a, <4 x double> %b) { +; CHECK: fnmadd213pd {{.*%r.*}}, %ymm0, %ymm0 + %res = call <4 x double> @llvm.x86.fma.vfnmadd.pd.256(<4 x double> %b, <4 x double> %b, <4 x double> %a) nounwind + ret <4 x double> %res +} + + +declare <4 x float> @llvm.x86.fma.vfmsub.ps(<4 x float>, <4 x float>, <4 x float>) nounwind readnone +define <4 x float> @test_x86_fmsub_baa_ps(<4 x float> %a, <4 x float> %b) { +; CHECK: fmsub132ps {{.*%r.*}}, %xmm0, %xmm0 + %res = call <4 x float> @llvm.x86.fma.vfmsub.ps(<4 x float> %b, <4 x float> %a, <4 x float> %a) nounwind + ret <4 x float> %res +} + +define <4 x float> @test_x86_fmsub_aba_ps(<4 x float> %a, <4 x float> %b) { +; CHECK: fmsub231ps {{.*%r.*}}, %xmm0, %xmm0 + %res = call <4 x float> @llvm.x86.fma.vfmsub.ps(<4 x float> %a, <4 x float> %b, <4 x float> %a) nounwind + ret <4 x float> %res +} + +define <4 x float> @test_x86_fmsub_bba_ps(<4 x float> %a, <4 x float> %b) { +; CHECK: fmsub213ps {{.*%r.*}}, %xmm0, %xmm0 + %res = call <4 x float> @llvm.x86.fma.vfmsub.ps(<4 x float> %b, <4 x float> %b, <4 x float> %a) nounwind + ret <4 x float> %res +} + +declare <8 x float> @llvm.x86.fma.vfmsub.ps.256(<8 x float>, <8 x float>, <8 x float>) nounwind readnone +define <8 x float> @test_x86_fmsub_baa_ps_y(<8 x float> %a, <8 x float> %b) { +; CHECK: fmsub132ps {{.*%r.*}}, %ymm0, %ymm0 + %res = call <8 x float> @llvm.x86.fma.vfmsub.ps.256(<8 x float> %b, <8 x float> %a, <8 x float> %a) nounwind + ret <8 x float> %res +} + +define <8 x float> @test_x86_fmsub_aba_ps_y(<8 x float> %a, <8 x float> %b) { +; CHECK: fmsub231ps {{.*%r.*}}, %ymm0, %ymm0 + %res = call <8 x float> @llvm.x86.fma.vfmsub.ps.256(<8 x float> %a, <8 x float> %b, <8 x float> %a) nounwind + ret <8 x float> %res +} + +define <8 x float> @test_x86_fmsub_bba_ps_y(<8 x float> %a, <8 x float> %b) { +; CHECK: fmsub213ps {{.*%r.*}}, %ymm0, %ymm0 + %res = call <8 x float> 
@llvm.x86.fma.vfmsub.ps.256(<8 x float> %b, <8 x float> %b, <8 x float> %a) nounwind + ret <8 x float> %res +} + +declare <2 x double> @llvm.x86.fma.vfmsub.pd(<2 x double>, <2 x double>, <2 x double>) nounwind readnone +define <2 x double> @test_x86_fmsub_baa_pd(<2 x double> %a, <2 x double> %b) { +; CHECK: fmsub132pd {{.*%r.*}}, %xmm0, %xmm0 + %res = call <2 x double> @llvm.x86.fma.vfmsub.pd(<2 x double> %b, <2 x double> %a, <2 x double> %a) nounwind + ret <2 x double> %res +} + +define <2 x double> @test_x86_fmsub_aba_pd(<2 x double> %a, <2 x double> %b) { +; CHECK: fmsub231pd {{.*%r.*}}, %xmm0, %xmm0 + %res = call <2 x double> @llvm.x86.fma.vfmsub.pd(<2 x double> %a, <2 x double> %b, <2 x double> %a) nounwind + ret <2 x double> %res +} + +define <2 x double> @test_x86_fmsub_bba_pd(<2 x double> %a, <2 x double> %b) { +; CHECK: fmsub213pd {{.*%r.*}}, %xmm0, %xmm0 + %res = call <2 x double> @llvm.x86.fma.vfmsub.pd(<2 x double> %b, <2 x double> %b, <2 x double> %a) nounwind + ret <2 x double> %res +} + +declare <4 x double> @llvm.x86.fma.vfmsub.pd.256(<4 x double>, <4 x double>, <4 x double>) nounwind readnone +define <4 x double> @test_x86_fmsub_baa_pd_y(<4 x double> %a, <4 x double> %b) { +; CHECK: fmsub132pd {{.*%r.*}}, %ymm0, %ymm0 + %res = call <4 x double> @llvm.x86.fma.vfmsub.pd.256(<4 x double> %b, <4 x double> %a, <4 x double> %a) nounwind + ret <4 x double> %res +} + +define <4 x double> @test_x86_fmsub_aba_pd_y(<4 x double> %a, <4 x double> %b) { +; CHECK: fmsub231pd {{.*%r.*}}, %ymm0, %ymm0 + %res = call <4 x double> @llvm.x86.fma.vfmsub.pd.256(<4 x double> %a, <4 x double> %b, <4 x double> %a) nounwind + ret <4 x double> %res +} + +define <4 x double> @test_x86_fmsub_bba_pd_y(<4 x double> %a, <4 x double> %b) { +; CHECK: fmsub213pd {{.*%r.*}}, %ymm0, %ymm0 + %res = call <4 x double> @llvm.x86.fma.vfmsub.pd.256(<4 x double> %b, <4 x double> %b, <4 x double> %a) nounwind + ret <4 x double> %res +} + + +declare <4 x float> @llvm.x86.fma.vfnmsub.ps(<4 x 
float>, <4 x float>, <4 x float>) nounwind readnone +define <4 x float> @test_x86_fnmsub_baa_ps(<4 x float> %a, <4 x float> %b) { +; CHECK: fnmsub132ps {{.*%r.*}}, %xmm0, %xmm0 + %res = call <4 x float> @llvm.x86.fma.vfnmsub.ps(<4 x float> %b, <4 x float> %a, <4 x float> %a) nounwind + ret <4 x float> %res +} + +define <4 x float> @test_x86_fnmsub_aba_ps(<4 x float> %a, <4 x float> %b) { +; CHECK: fnmsub231ps {{.*%r.*}}, %xmm0, %xmm0 + %res = call <4 x float> @llvm.x86.fma.vfnmsub.ps(<4 x float> %a, <4 x float> %b, <4 x float> %a) nounwind + ret <4 x float> %res +} + +define <4 x float> @test_x86_fnmsub_bba_ps(<4 x float> %a, <4 x float> %b) { +; CHECK: fnmsub213ps {{.*%r.*}}, %xmm0, %xmm0 + %res = call <4 x float> @llvm.x86.fma.vfnmsub.ps(<4 x float> %b, <4 x float> %b, <4 x float> %a) nounwind + ret <4 x float> %res +} + +declare <8 x float> @llvm.x86.fma.vfnmsub.ps.256(<8 x float>, <8 x float>, <8 x float>) nounwind readnone +define <8 x float> @test_x86_fnmsub_baa_ps_y(<8 x float> %a, <8 x float> %b) { +; CHECK: fnmsub132ps {{.*%r.*}}, %ymm0, %ymm0 + %res = call <8 x float> @llvm.x86.fma.vfnmsub.ps.256(<8 x float> %b, <8 x float> %a, <8 x float> %a) nounwind + ret <8 x float> %res +} + +define <8 x float> @test_x86_fnmsub_aba_ps_y(<8 x float> %a, <8 x float> %b) { +; CHECK: fnmsub231ps {{.*%r.*}}, %ymm0, %ymm0 + %res = call <8 x float> @llvm.x86.fma.vfnmsub.ps.256(<8 x float> %a, <8 x float> %b, <8 x float> %a) nounwind + ret <8 x float> %res +} + +define <8 x float> @test_x86_fnmsub_bba_ps_y(<8 x float> %a, <8 x float> %b) { +; CHECK: fnmsub213ps {{.*%r.*}}, %ymm0, %ymm0 + %res = call <8 x float> @llvm.x86.fma.vfnmsub.ps.256(<8 x float> %b, <8 x float> %b, <8 x float> %a) nounwind + ret <8 x float> %res +} + +declare <2 x double> @llvm.x86.fma.vfnmsub.pd(<2 x double>, <2 x double>, <2 x double>) nounwind readnone +define <2 x double> @test_x86_fnmsub_baa_pd(<2 x double> %a, <2 x double> %b) { +; CHECK: fnmsub132pd {{.*%r.*}}, %xmm0, %xmm0 + %res = call <2 x 
double> @llvm.x86.fma.vfnmsub.pd(<2 x double> %b, <2 x double> %a, <2 x double> %a) nounwind + ret <2 x double> %res +} + +define <2 x double> @test_x86_fnmsub_aba_pd(<2 x double> %a, <2 x double> %b) { +; CHECK: fnmsub231pd {{.*%r.*}}, %xmm0, %xmm0 + %res = call <2 x double> @llvm.x86.fma.vfnmsub.pd(<2 x double> %a, <2 x double> %b, <2 x double> %a) nounwind + ret <2 x double> %res +} + +define <2 x double> @test_x86_fnmsub_bba_pd(<2 x double> %a, <2 x double> %b) { +; CHECK: fnmsub213pd {{.*%r.*}}, %xmm0, %xmm0 + %res = call <2 x double> @llvm.x86.fma.vfnmsub.pd(<2 x double> %b, <2 x double> %b, <2 x double> %a) nounwind + ret <2 x double> %res +} + +declare <4 x double> @llvm.x86.fma.vfnmsub.pd.256(<4 x double>, <4 x double>, <4 x double>) nounwind readnone +define <4 x double> @test_x86_fnmsub_baa_pd_y(<4 x double> %a, <4 x double> %b) { +; CHECK: fnmsub132pd {{.*%r.*}}, %ymm0, %ymm0 + %res = call <4 x double> @llvm.x86.fma.vfnmsub.pd.256(<4 x double> %b, <4 x double> %a, <4 x double> %a) nounwind + ret <4 x double> %res +} + +define <4 x double> @test_x86_fnmsub_aba_pd_y(<4 x double> %a, <4 x double> %b) { +; CHECK: fnmsub231pd {{.*%r.*}}, %ymm0, %ymm0 + %res = call <4 x double> @llvm.x86.fma.vfnmsub.pd.256(<4 x double> %a, <4 x double> %b, <4 x double> %a) nounwind + ret <4 x double> %res +} + +define <4 x double> @test_x86_fnmsub_bba_pd_y(<4 x double> %a, <4 x double> %b) { +; CHECK: fnmsub213pd {{.*%r.*}}, %ymm0, %ymm0 + %res = call <4 x double> @llvm.x86.fma.vfnmsub.pd.256(<4 x double> %b, <4 x double> %b, <4 x double> %a) nounwind + ret <4 x double> %res +} + Index: llvm/test/CodeGen/X86/fma_patterns.ll =================================================================== --- llvm/test/CodeGen/X86/fma_patterns.ll +++ llvm/test/CodeGen/X86/fma_patterns.ll @@ -237,8 +237,7 @@ define <4 x float> @test_x86_fmadd_ps_load(<4 x float>* %a0, <4 x float> %a1, <4 x float> %a2) { ; CHECK-LABEL: test_x86_fmadd_ps_load: ; CHECK: # BB#0: -; CHECK-NEXT: vmovaps (%rdi), 
%xmm2 -; CHECK-NEXT: vfmadd213ps %xmm1, %xmm2, %xmm0 +; CHECK-NEXT: vfmadd132ps (%rdi), %xmm1, %xmm0 ; CHECK-NEXT: retq ; ; CHECK_FMA4-LABEL: test_x86_fmadd_ps_load: @@ -254,8 +253,7 @@ define <4 x float> @test_x86_fmsub_ps_load(<4 x float>* %a0, <4 x float> %a1, <4 x float> %a2) { ; CHECK-LABEL: test_x86_fmsub_ps_load: ; CHECK: # BB#0: -; CHECK-NEXT: vmovaps (%rdi), %xmm2 -; CHECK-NEXT: vfmsub213ps %xmm1, %xmm2, %xmm0 +; CHECK-NEXT: vfmsub132ps (%rdi), %xmm1, %xmm0 ; CHECK-NEXT: retq ; ; CHECK_FMA4-LABEL: test_x86_fmsub_ps_load: