[SPIR-V] Select G_SELECT directly to a single OpSelect instruction.

Replace the OpSelect{S,V}{F,I}{S,V}Cond TableGen variants with one OpSelect def,
select G_SELECT through a shared buildSelect helper (the bool-extension path that
used to be called selectSelect is renamed selectBoolSelect), and mark G_SELECT
legal for scalar/vector types in the legalizer.

NOTE(review): this copy of the patch has lost its line breaks and some
angle-bracketed template arguments (e.g. `class TernOpTyped opCode, ...`), so it
will not apply as-is; regenerate it from the original commit before applying.

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -38,10 +38,6 @@ : Op; -class TernOpTyped opCode, RegisterClass CCond, RegisterClass CID, SDNode node> - : Op; - multiclass BinOpTypedGen opCode, SDNode node, bit genF = 0, bit genV = 0> { if genF then def S: BinOpTyped; @@ -55,27 +51,6 @@ } } -multiclass TernOpTypedGen opCode, SDNode node, bit genI = 1, bit genF = 0, bit genV = 0> { - if genF then { - def SFSCond: TernOpTyped; - def SFVCond: TernOpTyped; - } - if genI then { - def SISCond: TernOpTyped; - def SIVCond: TernOpTyped; - } - if genV then { - if genF then { - def VFSCond: TernOpTyped; - def VFVCond: TernOpTyped; - } - if genI then { - def VISCond: TernOpTyped; - def VIVCond: TernOpTyped; - } - } -} - class UnOp opCode, list pattern=[]> : Op; @@ -531,7 +506,8 @@ def OpLogicalAnd: BinOp<"OpLogicalAnd", 167>; def OpLogicalNot: UnOp<"OpLogicalNot", 168>; -defm OpSelect: TernOpTypedGen<"OpSelect", 169, select, 1, 1, 1>; +def OpSelect: Op<169, (outs ID:$dst), (ins TYPE:$src_ty, ID:$cond, ID:$src1, ID:$src2), + "$dst = OpSelect $src_ty $cond $src1 $src2">; def OpIEqual: BinOp<"OpIEqual", 170>; def OpINotEqual: BinOp<"OpINotEqual", 171>; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -127,8 +127,10 @@ bool selectConst(Register ResVReg, const SPIRVType *ResType, const APInt &Imm, MachineInstr &I) const; - bool selectSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, - bool IsSigned) const; + bool selectBoolSelect(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I, bool IsSigned) const; + bool selectSelect(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; bool selectIToF(Register ResVReg, const 
SPIRVType *ResType, MachineInstr &I, bool IsSigned, unsigned Opcode) const; bool selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, @@ -178,6 +180,8 @@ Register buildZerosVal(const SPIRVType *ResType, MachineInstr &I) const; Register buildOnesVal(bool AllOnes, const SPIRVType *ResType, MachineInstr &I) const; + bool buildSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, + Register TrueValue, Register FalseValue) const; }; } // end anonymous namespace @@ -319,6 +323,9 @@ case TargetOpcode::G_PHI: return selectPhi(ResVReg, ResType, I); + case TargetOpcode::G_SELECT: + return selectSelect(ResVReg, ResType, I); + case TargetOpcode::G_FPTOSI: return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertFToS); case TargetOpcode::G_FPTOUI: @@ -1087,23 +1094,34 @@ return GR.getOrCreateConstInt(One.getZExtValue(), I, ResType, TII); } -bool SPIRVInstructionSelector::selectSelect(Register ResVReg, - const SPIRVType *ResType, - MachineInstr &I, - bool IsSigned) const { +bool SPIRVInstructionSelector::selectBoolSelect(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, + bool IsSigned) const { // To extend a bool, we need to use OpSelect between constants. Register ZeroReg = buildZerosVal(ResType, I); Register OneReg = buildOnesVal(IsSigned, ResType, I); - bool IsScalarBool = - GR.isScalarOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool); - unsigned Opcode = - IsScalarBool ? 
SPIRV::OpSelectSISCond : SPIRV::OpSelectSIVCond; - return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode)) + return buildSelect(ResVReg, ResType, I, OneReg, ZeroReg); +} + +bool SPIRVInstructionSelector::selectSelect(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + return buildSelect(ResVReg, ResType, I, I.getOperand(2).getReg(), + I.getOperand(3).getReg()); +} + +bool SPIRVInstructionSelector::buildSelect(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, Register TrueValue, + Register FalseValue) const { + // Emit OpSelect choosing TrueValue or FalseValue by the condition in operand 1 of I. + return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpSelect)) .addDef(ResVReg) .addUse(GR.getSPIRVTypeID(ResType)) .addUse(I.getOperand(1).getReg()) - .addUse(OneReg) - .addUse(ZeroReg) + .addUse(TrueValue) + .addUse(FalseValue) .constrainAllUses(TII, TRI, RBI); } @@ -1122,7 +1140,7 @@ TmpType = GR.getOrCreateSPIRVVectorType(TmpType, NumElts, I, TII); } SrcReg = MRI->createVirtualRegister(&SPIRV::IDRegClass); - selectSelect(SrcReg, TmpType, I, false); + selectBoolSelect(SrcReg, TmpType, I, false); } return selectUnOpWithSrc(ResVReg, ResType, I, SrcReg, Opcode); } @@ -1131,7 +1149,7 @@ const SPIRVType *ResType, MachineInstr &I, bool IsSigned) const { if (GR.isScalarOrVectorOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool)) - return selectSelect(ResVReg, ResType, I, IsSigned); + return selectBoolSelect(ResVReg, ResType, I, IsSigned); unsigned Opcode = IsSigned ? 
SPIRV::OpSConvert : SPIRV::OpUConvert; return selectUnOp(ResVReg, ResType, I, Opcode); } diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -46,7 +46,6 @@ TargetOpcode::G_SHL, TargetOpcode::G_ASHR, TargetOpcode::G_LSHR, - TargetOpcode::G_SELECT, TargetOpcode::G_EXTRACT_VECTOR_ELT, }; @@ -199,6 +198,9 @@ all(typeInSet(0, allBoolScalarsAndVectors), typeInSet(1, allFloatScalarsAndVectors))); + getActionDefinitionsBuilder(G_SELECT).legalIf(all( + typeInSet(0, allScalarsAndVectors), typeInSet(1, allScalarsAndVectors))); + getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND, G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,