diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h @@ -191,6 +191,7 @@ struct RISCVMaskedPseudoInfo { uint16_t MaskedPseudo; uint16_t UnmaskedPseudo; + uint16_t UnmaskedTUPseudo; uint8_t MaskOpIdx; }; diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -2265,33 +2265,52 @@ const MCInstrDesc &MaskedMCID = TII->get(N->getMachineOpcode()); + bool IsTA = true; if (RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags)) { - // The last operand of the pseudo is the policy op, but we're expecting a - // Glue operand last. We may also have a chain. + // The last operand of the pseudo is the policy op, but we might have a + // Glue operand last. We might also have a chain. TailPolicyOpIdx = N->getNumOperands() - 1; if (N->getOperand(*TailPolicyOpIdx).getValueType() == MVT::Glue) (*TailPolicyOpIdx)--; if (N->getOperand(*TailPolicyOpIdx).getValueType() == MVT::Other) (*TailPolicyOpIdx)--; - // If the policy isn't TAIL_AGNOSTIC we can't perform this optimization. - if (N->getConstantOperandVal(*TailPolicyOpIdx) != RISCVII::TAIL_AGNOSTIC) - return false; + if (!(N->getConstantOperandVal(*TailPolicyOpIdx) & + RISCVII::TAIL_AGNOSTIC)) { + // Keep the true-masked instruction when there is no unmasked TU + // instruction + if (I->UnmaskedTUPseudo == I->MaskedPseudo && !N->getOperand(0).isUndef()) + return false; + // We can't use TA if the tie-operand is not IMPLICIT_DEF + if (!N->getOperand(0).isUndef()) + IsTA = false; + } } - const MCInstrDesc &UnmaskedMCID = TII->get(I->UnmaskedPseudo); + if (IsTA) { + uint64_t TSFlags = TII->get(I->UnmaskedPseudo).TSFlags; - // Check that we're dropping the merge operand, the mask operand, and any - // policy operand when we transform to this unmasked pseudo. - assert(!RISCVII::hasMergeOp(UnmaskedMCID.TSFlags) && - RISCVII::hasDummyMaskOp(UnmaskedMCID.TSFlags) && - !RISCVII::hasVecPolicyOp(UnmaskedMCID.TSFlags) && - "Unexpected pseudo to transform to"); - (void)UnmaskedMCID; + // Check that we're dropping the merge operand, the mask operand, and any + // policy operand when we transform to this unmasked pseudo. + assert(!RISCVII::hasMergeOp(TSFlags) && RISCVII::hasDummyMaskOp(TSFlags) && + !RISCVII::hasVecPolicyOp(TSFlags) && + "Unexpected pseudo to transform to"); + (void)TSFlags; + } else { + uint64_t TSFlags = TII->get(I->UnmaskedTUPseudo).TSFlags; + + // Check that we're dropping the mask operand, and any policy operand + // when we transform to this unmasked tu pseudo. + assert(RISCVII::hasMergeOp(TSFlags) && RISCVII::hasDummyMaskOp(TSFlags) && + !RISCVII::hasVecPolicyOp(TSFlags) && + "Unexpected pseudo to transform to"); + (void)TSFlags; + } + unsigned Opc = IsTA ? I->UnmaskedPseudo : I->UnmaskedTUPseudo; SmallVector Ops; - // Skip the merge operand at index 0. - for (unsigned I = 1, E = N->getNumOperands(); I != E; I++) { + // Skip the merge operand at index 0 if IsTA + for (unsigned I = IsTA, E = N->getNumOperands(); I != E; I++) { // Skip the mask, the policy, and the Glue. SDValue Op = N->getOperand(I); if (I == MaskOpIdx || I == TailPolicyOpIdx || @@ -2304,8 +2323,7 @@ if (auto *TGlued = Glued->getGluedNode()) Ops.push_back(SDValue(TGlued, TGlued->getNumValues() - 1)); - SDNode *Result = - CurDAG->getMachineNode(I->UnmaskedPseudo, SDLoc(N), N->getVTList(), Ops); + SDNode *Result = CurDAG->getMachineNode(Opc, SDLoc(N), N->getVTList(), Ops); ReplaceUses(N, Result); return true; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td @@ -423,16 +423,17 @@ let PrimaryKeyName = "getRISCVVIntrinsicInfo"; } -class RISCVMaskedPseudo MaskIdx> { +class RISCVMaskedPseudo MaskIdx, bit HasTU = true> { Pseudo MaskedPseudo = !cast(NAME); Pseudo UnmaskedPseudo = !cast(!subst("_MASK", "", NAME)); + Pseudo UnmaskedTUPseudo = !if(HasTU, !cast(!subst("_MASK", "", NAME # "_TU")), MaskedPseudo); bits<4> MaskOpIdx = MaskIdx; } def RISCVMaskedPseudosTable : GenericTable { let FilterClass = "RISCVMaskedPseudo"; let CppTypeName = "RISCVMaskedPseudoInfo"; - let Fields = ["MaskedPseudo", "UnmaskedPseudo", "MaskOpIdx"]; + let Fields = ["MaskedPseudo", "UnmaskedPseudo", "UnmaskedTUPseudo", "MaskOpIdx"]; let PrimaryKey = ["MaskedPseudo"]; let PrimaryKeyName = "getMaskedPseudoInfo"; } @@ -1770,7 +1771,7 @@ let ForceTailAgnostic = true in def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMOutMask, - RISCVMaskedPseudo; + RISCVMaskedPseudo; } } diff --git a/llvm/test/CodeGen/RISCV/rvv/allone-masked-to-unmasked.ll b/llvm/test/CodeGen/RISCV/rvv/allone-masked-to-unmasked.ll --- a/llvm/test/CodeGen/RISCV/rvv/allone-masked-to-unmasked.ll +++ b/llvm/test/CodeGen/RISCV/rvv/allone-masked-to-unmasked.ll @@ -31,14 +31,12 @@ ret %a } -; FIXME: Use an unmasked TAIL_AGNOSTIC instruction if the tie operand is IMPLICIT_DEF +; Use an unmasked TAIL_AGNOSTIC instruction if the tie operand is IMPLICIT_DEF define @test1( %0, %1, iXLen %2) nounwind { ; CHECK-LABEL: test1: ; CHECK: # %bb.0: # %entry ; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, mu -; CHECK-NEXT: vmset.m v0 -; CHECK-NEXT: vsetvli zero, zero, e8, mf8, tu, mu -; CHECK-NEXT: vadd.vv v8, v8, v9, v0.t +; CHECK-NEXT: vadd.vv v8, v8, v9 ; CHECK-NEXT: ret entry: %allone = call @llvm.riscv.vmset.nxv1i1( @@ -53,14 +51,12 @@ ret %a } -; FIXME: Use an unmasked TU instruction because of the policy operand +; Use an unmasked TU instruction because of the policy operand define @test2( %0, %1, %2, iXLen %3) nounwind { ; CHECK-LABEL: test2: ; CHECK: # %bb.0: # %entry -; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, mu -; CHECK-NEXT: vmset.m v0 -; CHECK-NEXT: vsetvli zero, zero, e8, mf8, tu, mu -; CHECK-NEXT: vadd.vv v8, v9, v10, v0.t +; CHECK-NEXT: vsetvli zero, a0, e8, mf8, tu, mu +; CHECK-NEXT: vadd.vv v8, v9, v10 ; CHECK-NEXT: ret entry: %allone = call @llvm.riscv.vmset.nxv1i1(