diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -595,15 +595,9 @@
                                   MachineInstr &MI, MachineBasicBlock *BB) const;
   MachineBasicBlock *EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const;
-  MachineBasicBlock *EmitMopa(unsigned Opc, unsigned BaseReg, MachineInstr &MI,
-                              MachineBasicBlock *BB) const;
-  MachineBasicBlock *EmitInsertVectorToTile(unsigned Opc, unsigned BaseReg,
-                                            MachineInstr &MI,
-                                            MachineBasicBlock *BB) const;
+  MachineBasicBlock *EmitZAInstr(unsigned Opc, unsigned BaseReg,
+                                 MachineInstr &MI, MachineBasicBlock *BB) const;
   MachineBasicBlock *EmitZero(MachineInstr &MI, MachineBasicBlock *BB) const;
-  MachineBasicBlock *EmitAddVectorToTile(unsigned Opc, unsigned BaseReg,
-                                         MachineInstr &MI,
-                                         MachineBasicBlock *BB) const;

   MachineBasicBlock *
   EmitInstrWithCustomInserter(MachineInstr &MI,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2676,35 +2676,16 @@
 }

 MachineBasicBlock *
-AArch64TargetLowering::EmitMopa(unsigned Opc, unsigned BaseReg,
-                                MachineInstr &MI, MachineBasicBlock *BB) const {
+AArch64TargetLowering::EmitZAInstr(unsigned Opc, unsigned BaseReg,
+                                   MachineInstr &MI,
+                                   MachineBasicBlock *BB) const {
   const TargetInstrInfo *TII = Subtarget->getInstrInfo();
   MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc));

   MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define);
   MIB.addReg(BaseReg + MI.getOperand(0).getImm());
-  MIB.add(MI.getOperand(1)); // pn
-  MIB.add(MI.getOperand(2)); // pm
-  MIB.add(MI.getOperand(3)); // zn
-  MIB.add(MI.getOperand(4)); // zm
-
-  MI.eraseFromParent(); // The pseudo is gone now.
-  return BB;
-}
-
-MachineBasicBlock *
-AArch64TargetLowering::EmitInsertVectorToTile(unsigned Opc, unsigned BaseReg,
-                                              MachineInstr &MI,
-                                              MachineBasicBlock *BB) const {
-  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
-  MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc));
-
-  MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define);
-  MIB.addReg(BaseReg + MI.getOperand(0).getImm());
-  MIB.add(MI.getOperand(1)); // Slice index register
-  MIB.add(MI.getOperand(2)); // Slice index offset
-  MIB.add(MI.getOperand(3)); // pg
-  MIB.add(MI.getOperand(4)); // zn
+  for (unsigned I = 1; I < MI.getNumOperands(); ++I)
+    MIB.add(MI.getOperand(I));

   MI.eraseFromParent(); // The pseudo is gone now.
   return BB;
@@ -2727,25 +2708,28 @@
   return BB;
 }

-MachineBasicBlock *
-AArch64TargetLowering::EmitAddVectorToTile(unsigned Opc, unsigned BaseReg,
-                                           MachineInstr &MI,
-                                           MachineBasicBlock *BB) const {
-  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
-  MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc));
-
-  MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define);
-  MIB.addReg(BaseReg + MI.getOperand(0).getImm());
-  MIB.add(MI.getOperand(1)); // pn
-  MIB.add(MI.getOperand(2)); // pm
-  MIB.add(MI.getOperand(3)); // zn
-
-  MI.eraseFromParent(); // The pseudo is gone now.
-  return BB;
-}
-
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {
+
+  int SMEOrigInstr = AArch64::getSMEPseudoMap(MI.getOpcode());
+  if (SMEOrigInstr != -1) {
+    const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+    uint64_t SMEMatrixType =
+        TII->get(MI.getOpcode()).TSFlags & AArch64::SMEMatrixTypeMask;
+    switch (SMEMatrixType) {
+    case (AArch64::SMEMatrixTileB):
+      return EmitZAInstr(SMEOrigInstr, AArch64::ZAB0, MI, BB);
+    case (AArch64::SMEMatrixTileH):
+      return EmitZAInstr(SMEOrigInstr, AArch64::ZAH0, MI, BB);
+    case (AArch64::SMEMatrixTileS):
+      return EmitZAInstr(SMEOrigInstr, AArch64::ZAS0, MI, BB);
+    case (AArch64::SMEMatrixTileD):
+      return EmitZAInstr(SMEOrigInstr, AArch64::ZAD0, MI, BB);
+    case (AArch64::SMEMatrixTileQ):
+      return EmitZAInstr(SMEOrigInstr, AArch64::ZAQ0, MI, BB);
+    }
+  }
+
   switch (MI.getOpcode()) {
   default:
 #ifndef NDEBUG
@@ -2795,94 +2779,8 @@
     return EmitTileLoad(AArch64::LD1_MXIPXX_V_Q, AArch64::ZAQ0, MI, BB);
   case AArch64::LDR_ZA_PSEUDO:
     return EmitFill(MI, BB);
-  case AArch64::BFMOPA_MPPZZ_PSEUDO:
-    return EmitMopa(AArch64::BFMOPA_MPPZZ, AArch64::ZAS0, MI, BB);
-  case AArch64::BFMOPS_MPPZZ_PSEUDO:
-    return EmitMopa(AArch64::BFMOPS_MPPZZ, AArch64::ZAS0, MI, BB);
-  case AArch64::FMOPAL_MPPZZ_PSEUDO:
-    return EmitMopa(AArch64::FMOPAL_MPPZZ, AArch64::ZAS0, MI, BB);
-  case AArch64::FMOPSL_MPPZZ_PSEUDO:
-    return EmitMopa(AArch64::FMOPSL_MPPZZ, AArch64::ZAS0, MI, BB);
-  case AArch64::FMOPA_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::FMOPA_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::FMOPS_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::FMOPS_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::FMOPA_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::FMOPA_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::FMOPS_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::FMOPS_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::SMOPA_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::SMOPA_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::SMOPS_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::SMOPS_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::UMOPA_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::UMOPA_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::UMOPS_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::UMOPS_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::SUMOPA_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::SUMOPA_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::SUMOPS_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::SUMOPS_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::USMOPA_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::USMOPA_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::USMOPS_MPPZZ_S_PSEUDO:
-    return EmitMopa(AArch64::USMOPS_MPPZZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::SMOPA_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::SMOPA_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::SMOPS_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::SMOPS_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::UMOPA_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::UMOPA_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::UMOPS_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::UMOPS_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::SUMOPA_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::SUMOPA_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::SUMOPS_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::SUMOPS_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::USMOPA_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::USMOPA_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::USMOPS_MPPZZ_D_PSEUDO:
-    return EmitMopa(AArch64::USMOPS_MPPZZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::INSERT_MXIPZ_H_PSEUDO_B:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_B, AArch64::ZAB0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_H_PSEUDO_H:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_H, AArch64::ZAH0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_H_PSEUDO_S:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_S, AArch64::ZAS0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_H_PSEUDO_D:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_D, AArch64::ZAD0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_H_PSEUDO_Q:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_Q, AArch64::ZAQ0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_V_PSEUDO_B:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_B, AArch64::ZAB0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_V_PSEUDO_H:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_H, AArch64::ZAH0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_V_PSEUDO_S:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_S, AArch64::ZAS0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_V_PSEUDO_D:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_D, AArch64::ZAD0, MI,
-                                  BB);
-  case AArch64::INSERT_MXIPZ_V_PSEUDO_Q:
-    return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_Q, AArch64::ZAQ0, MI,
-                                  BB);
   case AArch64::ZERO_M_PSEUDO:
     return EmitZero(MI, BB);
-  case AArch64::ADDHA_MPPZ_PSEUDO_S:
-    return EmitAddVectorToTile(AArch64::ADDHA_MPPZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::ADDVA_MPPZ_PSEUDO_S:
-    return EmitAddVectorToTile(AArch64::ADDVA_MPPZ_S, AArch64::ZAS0, MI, BB);
-  case AArch64::ADDHA_MPPZ_PSEUDO_D:
-    return EmitAddVectorToTile(AArch64::ADDHA_MPPZ_D, AArch64::ZAD0, MI, BB);
-  case AArch64::ADDVA_MPPZ_PSEUDO_D:
-    return EmitAddVectorToTile(AArch64::ADDVA_MPPZ_D, AArch64::ZAD0, MI, BB);
   }
 }
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -45,6 +45,17 @@
 def FalseLanesZero  : FalseLanesEnum<1>;
 def FalseLanesUndef : FalseLanesEnum<2>;

+class SMEMatrixTypeEnum<bits<3> val> {
+  bits<3> Value = val;
+}
+def SMEMatrixNone  : SMEMatrixTypeEnum<0>;
+def SMEMatrixTileB : SMEMatrixTypeEnum<1>;
+def SMEMatrixTileH : SMEMatrixTypeEnum<2>;
+def SMEMatrixTileS : SMEMatrixTypeEnum<3>;
+def SMEMatrixTileD : SMEMatrixTypeEnum<4>;
+def SMEMatrixTileQ : SMEMatrixTypeEnum<5>;
+def SMEMatrixArray : SMEMatrixTypeEnum<6>;
+
 // AArch64 Instruction Format
 class AArch64Inst<Format f, string cstr> : Instruction {
   field bits<32> Inst; // Instruction encoding.
@@ -65,16 +76,18 @@
   bit isPTestLike = 0;
   FalseLanesEnum FalseLanes = FalseLanesNone;
   DestructiveInstTypeEnum DestructiveInstType = NotDestructive;
+  SMEMatrixTypeEnum SMEMatrixType = SMEMatrixNone;
   ElementSizeEnum ElementSize = ElementSizeNone;

-  let TSFlags{10}  = isPTestLike;
-  let TSFlags{9}   = isWhile;
-  let TSFlags{8-7} = FalseLanes.Value;
-  let TSFlags{6-3} = DestructiveInstType.Value;
-  let TSFlags{2-0} = ElementSize.Value;
+  let TSFlags{13-11} = SMEMatrixType.Value;
+  let TSFlags{10}    = isPTestLike;
+  let TSFlags{9}     = isWhile;
+  let TSFlags{8-7}   = FalseLanes.Value;
+  let TSFlags{6-3}   = DestructiveInstType.Value;
+  let TSFlags{2-0}   = ElementSize.Value;

-  let Pattern = [];
-  let Constraints = cstr;
+  let Pattern     = [];
+  let Constraints = cstr;
 }

 class InstSubst<string Asm, dag Result, bit EmitPriority = 0>
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -539,10 +539,11 @@
 }

 // struct TSFlags {
-#define TSFLAG_ELEMENT_SIZE_TYPE(X)      (X)       // 3-bits
-#define TSFLAG_DESTRUCTIVE_INST_TYPE(X) ((X) << 3) // 4-bits
-#define TSFLAG_FALSE_LANE_TYPE(X)       ((X) << 7) // 2-bits
-#define TSFLAG_INSTR_FLAGS(X)           ((X) << 9) // 2-bits
+#define TSFLAG_ELEMENT_SIZE_TYPE(X)      (X)        // 3-bits
+#define TSFLAG_DESTRUCTIVE_INST_TYPE(X) ((X) << 3)  // 4-bits
+#define TSFLAG_FALSE_LANE_TYPE(X)       ((X) << 7)  // 2-bits
+#define TSFLAG_INSTR_FLAGS(X)           ((X) << 9)  // 2-bits
+#define TSFLAG_SME_MATRIX_TYPE(X)       ((X) << 11) // 3-bits
 // }

 namespace AArch64 {
@@ -580,14 +581,28 @@
 static const uint64_t InstrFlagIsWhile     = TSFLAG_INSTR_FLAGS(0x1);
 static const uint64_t InstrFlagIsPTestLike = TSFLAG_INSTR_FLAGS(0x2);

+enum SMEMatrixType {
+  SMEMatrixTypeMask = TSFLAG_SME_MATRIX_TYPE(0x7),
+  SMEMatrixNone     = TSFLAG_SME_MATRIX_TYPE(0x0),
+  SMEMatrixTileB    = TSFLAG_SME_MATRIX_TYPE(0x1),
+  SMEMatrixTileH    = TSFLAG_SME_MATRIX_TYPE(0x2),
+  SMEMatrixTileS    = TSFLAG_SME_MATRIX_TYPE(0x3),
+  SMEMatrixTileD    = TSFLAG_SME_MATRIX_TYPE(0x4),
+  SMEMatrixTileQ    = TSFLAG_SME_MATRIX_TYPE(0x5),
+  SMEMatrixArray    = TSFLAG_SME_MATRIX_TYPE(0x6),
+};
+
 #undef TSFLAG_ELEMENT_SIZE_TYPE
 #undef TSFLAG_DESTRUCTIVE_INST_TYPE
 #undef TSFLAG_FALSE_LANE_TYPE
 #undef TSFLAG_INSTR_FLAGS
+#undef TSFLAG_SME_MATRIX_TYPE

 int getSVEPseudoMap(uint16_t Opcode);
 int getSVERevInstr(uint16_t Opcode);
 int getSVENonRevInstr(uint16_t Opcode);
+
+int getSMEPseudoMap(uint16_t Opcode);
+
 }
 } // end namespace llvm
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -49,15 +49,15 @@
 def ADDSPL_XXI : sve_int_arith_vl<0b1, "addspl", /*streaming_sve=*/0b1>;
 def ADDSVL_XXI : sve_int_arith_vl<0b0, "addsvl", /*streaming_sve=*/0b1>;

-def ADDHA_MPPZ_S : sme_add_vector_to_tile_u32<0b0, "addha">;
-def ADDVA_MPPZ_S : sme_add_vector_to_tile_u32<0b1, "addva">;
+defm ADDHA_MPPZ_S : sme_add_vector_to_tile_u32<0b0, "addha", int_aarch64_sme_addha>;
+defm ADDVA_MPPZ_S : sme_add_vector_to_tile_u32<0b1, "addva", int_aarch64_sme_addva>;

 def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>;
 }

 let Predicates = [HasSMEI16I64] in {
-def ADDHA_MPPZ_D : sme_add_vector_to_tile_u64<0b0, "addha">;
-def ADDVA_MPPZ_D : sme_add_vector_to_tile_u64<0b1, "addva">;
+defm ADDHA_MPPZ_D : sme_add_vector_to_tile_u64<0b0, "addha", int_aarch64_sme_addha>;
+defm ADDVA_MPPZ_D : sme_add_vector_to_tile_u64<0b1, "addva", int_aarch64_sme_addva>;
 }
 let Predicates = [HasSME] in {
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -25,17 +25,35 @@
 def am_sme_indexed_b4 :ComplexPattern<i64, 1, "SelectAddrModeIndexedSVE<0, 15>", [], [SDNPWantRoot]>;

 //===----------------------------------------------------------------------===//
-// SME Outer Products
+// SME Pseudo Classes
 //===----------------------------------------------------------------------===//

-class sme_outer_product_pseudo<ZPRRegOp zpr_ty>
+def getSMEPseudoMap : InstrMapping {
+  let FilterClass = "SMEPseudo2Instr";
+  let RowFields = ["PseudoName"];
+  let ColFields = ["IsInstr"];
+  let KeyCol = ["0"];
+  let ValueCols = [["1"]];
+}
+
+class SMEPseudo2Instr<string name, bit instr> {
+  string PseudoName = name;
+  bit IsInstr = instr;
+}
+
+class sme_outer_product_pseudo<ZPRRegOp zpr_ty, SMEMatrixTypeEnum za_flag>
     : Pseudo<(outs), (ins i32imm:$tile, PPR3bAny:$pn, PPR3bAny:$pm,
                           zpr_ty:$zn, zpr_ty:$zm), []>,
       Sched<[]> {
   // Translated to the actual instructions in AArch64ISelLowering.cpp
+  let SMEMatrixType = za_flag;
   let usesCustomInserter = 1;
 }

+//===----------------------------------------------------------------------===//
+// SME Outer Products
+//===----------------------------------------------------------------------===//
+
 class sme_fp_outer_product_inst<bit S, bits<2> sz, bit op, MatrixTileOperand za_ty,
                                 ZPRRegOp zpr_ty, string mnemonic>
     : I<(outs za_ty:$ZAda),
@@ -62,13 +80,13 @@
 }

 multiclass sme_outer_product_fp32<bit S, string mnemonic, SDPatternOperator op> {
-  def NAME : sme_fp_outer_product_inst<S, 0b00, 0b0, TileOp32, ZPR32, mnemonic> {
+  def NAME : sme_fp_outer_product_inst<S, 0b00, 0b0, TileOp32, ZPR32, mnemonic>, SMEPseudo2Instr<NAME, 1> {
     bits<2> ZAda;
     let Inst{1-0} = ZAda;
     let Inst{2}   = 0b0;
   }

-  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR32>;
+  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR32, SMEMatrixTileS>, SMEPseudo2Instr<NAME, 0>;

   def : Pat<(op timm32_0_3:$tile, (nxv4i1 PPR3bAny:$pn), (nxv4i1 PPR3bAny:$pm),
                 (nxv4f32 ZPR32:$zn), (nxv4f32 ZPR32:$zm)),
@@ -76,12 +94,12 @@
 }

 multiclass sme_outer_product_fp64<bit S, string mnemonic, SDPatternOperator op> {
-  def NAME : sme_fp_outer_product_inst<S, 0b11, 0b0, TileOp64, ZPR64, mnemonic> {
+  def NAME : sme_fp_outer_product_inst<S, 0b11, 0b0, TileOp64, ZPR64, mnemonic>, SMEPseudo2Instr<NAME, 1> {
     bits<3> ZAda;
     let Inst{2-0} = ZAda;
   }

-  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR64>;
+  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR64, SMEMatrixTileD>, SMEPseudo2Instr<NAME, 0>;

   def : Pat<(op timm32_0_7:$tile, (nxv2i1 PPR3bAny:$pn), (nxv2i1 PPR3bAny:$pm),
                 (nxv2f64 ZPR64:$zn), (nxv2f64 ZPR64:$zm)),
@@ -126,13 +144,13 @@
 multiclass sme_int_outer_product_i32<bits<3> opc, string mnemonic,
                                      SDPatternOperator op> {
   def NAME : sme_int_outer_product_inst<opc, 0b0, TileOp32,
-                                        ZPR8, mnemonic> {
+                                        ZPR8, mnemonic>, SMEPseudo2Instr<NAME, 1> {
     bits<2> ZAda;
     let Inst{1-0} = ZAda;
     let Inst{2}   = 0b0;
   }

-  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR8>;
+  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR8, SMEMatrixTileS>, SMEPseudo2Instr<NAME, 0>;

   def : Pat<(op timm32_0_3:$tile, (nxv16i1 PPR3bAny:$pn), (nxv16i1 PPR3bAny:$pm),
                 (nxv16i8 ZPR8:$zn), (nxv16i8 ZPR8:$zm)),
@@ -142,12 +160,12 @@
 multiclass sme_int_outer_product_i64<bits<3> opc, string mnemonic,
                                      SDPatternOperator op> {
   def NAME : sme_int_outer_product_inst<opc, 0b1, TileOp64,
-                                        ZPR16, mnemonic> {
+                                        ZPR16, mnemonic>, SMEPseudo2Instr<NAME, 1> {
     bits<3> ZAda;
     let Inst{2-0} = ZAda;
   }

-  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR16>;
+  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR16, SMEMatrixTileD>, SMEPseudo2Instr<NAME, 0>;

   def : Pat<(op timm32_0_7:$tile, (nxv8i1 PPR3bAny:$pn), (nxv8i1 PPR3bAny:$pm),
                 (nxv8i16 ZPR16:$zn), (nxv8i16 ZPR16:$zm)),
@@ -182,9 +200,9 @@
 }

 multiclass sme_bf16_outer_product<bits<3> opc, string mnemonic, SDPatternOperator op> {
-  def NAME : sme_outer_product_widening_inst<opc, ZPR16, mnemonic>;
+  def NAME : sme_outer_product_widening_inst<opc, ZPR16, mnemonic>, SMEPseudo2Instr<NAME, 1>;

-  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR16>;
+  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR16, SMEMatrixTileS>, SMEPseudo2Instr<NAME, 0>;
   def : Pat<(op timm32_0_3:$tile, (nxv8i1 PPR3bAny:$pn), (nxv8i1 PPR3bAny:$pm),
                 (nxv8bf16 ZPR16:$zn), (nxv8bf16 ZPR16:$zm)),
@@ -192,9 +210,9 @@
 }

 multiclass sme_f16_outer_product<bits<3> opc, string mnemonic, SDPatternOperator op> {
-  def NAME : sme_outer_product_widening_inst<opc, ZPR16, mnemonic>;
+  def NAME : sme_outer_product_widening_inst<opc, ZPR16, mnemonic>, SMEPseudo2Instr<NAME, 1>;

-  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR16>;
+  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR16, SMEMatrixTileS>, SMEPseudo2Instr<NAME, 0>;

   def : Pat<(op timm32_0_3:$tile, (nxv8i1 PPR3bAny:$pn), (nxv8i1 PPR3bAny:$pm),
                 (nxv8f16 ZPR16:$zn), (nxv8f16 ZPR16:$zm)),
@@ -226,51 +244,42 @@
   let Constraints = "$ZAda = $_ZAda";
 }

-class sme_add_vector_to_tile_u32<bit V, string mnemonic>
-  : sme_add_vector_to_tile_inst<0b0, V, TileOp32, ZPR32, mnemonic> {
-  bits<2> ZAda;
-  let Inst{2}   = 0b0;
-  let Inst{1-0} = ZAda;
-}
-
-class sme_add_vector_to_tile_u64<bit V, string mnemonic>
-  : sme_add_vector_to_tile_inst<0b1, V, TileOp64, ZPR64, mnemonic> {
-  bits<3> ZAda;
-  let Inst{2-0} = ZAda;
-}
-
-class sme_add_vector_to_tile_pseudo<ZPRRegOp zpr_ty>
+class sme_add_vector_to_tile_pseudo<ZPRRegOp zpr_ty, SMEMatrixTypeEnum za_flag>
     : Pseudo<(outs),
              (ins i32imm:$tile, PPR3bAny:$Pn, PPR3bAny:$Pm, zpr_ty:$Zn), []>,
       Sched<[]> {
   // Translated to the actual instructions in AArch64ISelLowering.cpp
+  let SMEMatrixType = za_flag;
   let usesCustomInserter = 1;
 }

-def ADDHA_MPPZ_PSEUDO_S : sme_add_vector_to_tile_pseudo<ZPR32>;
-def ADDVA_MPPZ_PSEUDO_S : sme_add_vector_to_tile_pseudo<ZPR32>;
+multiclass sme_add_vector_to_tile_u32<bit V, string mnemonic, SDPatternOperator op> {
+  def NAME : sme_add_vector_to_tile_inst<0b0, V, TileOp32, ZPR32, mnemonic>, SMEPseudo2Instr<NAME, 1> {
+    bits<2> ZAda;
+    let Inst{2}   = 0b0;
+    let Inst{1-0} = ZAda;
+  }
+
+  def _PSEUDO_S : sme_add_vector_to_tile_pseudo<ZPR32, SMEMatrixTileS>, SMEPseudo2Instr<NAME, 0>;

-def : Pat<(int_aarch64_sme_addha
-              timm32_0_3:$tile, (nxv4i1 PPR3bAny:$pn), (nxv4i1 PPR3bAny:$pm),
-              (nxv4i32 ZPR32:$zn)),
-          (ADDHA_MPPZ_PSEUDO_S timm32_0_3:$tile, $pn, $pm, $zn)>;
-def : Pat<(int_aarch64_sme_addva
-              timm32_0_3:$tile, (nxv4i1 PPR3bAny:$pn), (nxv4i1 PPR3bAny:$pm),
+  def : Pat<(op timm32_0_3:$tile, (nxv4i1 PPR3bAny:$pn), (nxv4i1 PPR3bAny:$pm),
                 (nxv4i32 ZPR32:$zn)),
-          (ADDVA_MPPZ_PSEUDO_S timm32_0_3:$tile, $pn, $pm, $zn)>;
+            (!cast<Instruction>(NAME # _PSEUDO_S) timm32_0_3:$tile, $pn, $pm, $zn)>;
+}
+
+multiclass sme_add_vector_to_tile_u64<bit V, string mnemonic, SDPatternOperator op> {
+  def NAME : sme_add_vector_to_tile_inst<0b1, V, TileOp64, ZPR64, mnemonic>, SMEPseudo2Instr<NAME, 1> {
+    bits<3> ZAda;
+    let Inst{2-0} = ZAda;
+  }

-let Predicates = [HasSMEI16I64] in {
-def ADDHA_MPPZ_PSEUDO_D : sme_add_vector_to_tile_pseudo<ZPR64>;
-def ADDVA_MPPZ_PSEUDO_D : sme_add_vector_to_tile_pseudo<ZPR64>;
+  def _PSEUDO_D : sme_add_vector_to_tile_pseudo<ZPR64, SMEMatrixTileD>, SMEPseudo2Instr<NAME, 0>;

-def : Pat<(int_aarch64_sme_addha
-              timm32_0_7:$tile, (nxv2i1 PPR3bAny:$pn), (nxv2i1 PPR3bAny:$pm),
-              (nxv2i64 ZPR64:$zn)),
-          (ADDHA_MPPZ_PSEUDO_D timm32_0_7:$tile, $pn, $pm, $zn)>;
-def : Pat<(int_aarch64_sme_addva
-              timm32_0_7:$tile, (nxv2i1 PPR3bAny:$pn), (nxv2i1 PPR3bAny:$pm),
-              (nxv2i64 ZPR64:$zn)),
-          (ADDVA_MPPZ_PSEUDO_D timm32_0_7:$tile, $pn, $pm, $zn)>;
+  let Predicates = [HasSMEI16I64] in {
+  def : Pat<(op timm32_0_7:$tile, (nxv2i1 PPR3bAny:$pn), (nxv2i1 PPR3bAny:$pm),
+                (nxv2i64 ZPR64:$zn)),
+            (!cast<Instruction>(NAME # _PSEUDO_D) timm32_0_7:$tile, $pn, $pm, $zn)>;
+  }
 }

 //===----------------------------------------------------------------------===//
@@ -711,24 +720,27 @@
   }
 }

-class sme_mova_insert_pseudo
+class sme_mova_insert_pseudo<SMEMatrixTypeEnum za_flag>
     : Pseudo<(outs), (ins i32imm:$tile, MatrixIndexGPR32Op12_15:$idx,
                           i32imm:$imm, PPR3bAny:$pg, ZPRAny:$zn), []>,
      Sched<[]> {
   // Translated to the actual instructions in AArch64ISelLowering.cpp
+  let SMEMatrixType = za_flag;
   let usesCustomInserter = 1;
 }

 multiclass sme_vector_v_to_tile<string mnemonic, bit is_col> {
   def _B : sme_vector_to_tile_inst<0b0, 0b00, !if(is_col, TileVectorOpV8,
                                                   TileVectorOpH8),
-                                   is_col, sme_elm_idx0_15, ZPR8, mnemonic> {
+                                   is_col, sme_elm_idx0_15, ZPR8, mnemonic>,
+           SMEPseudo2Instr<NAME # _B, 1> {
     bits<4> imm;
     let Inst{3-0} = imm;
   }
   def _H : sme_vector_to_tile_inst<0b0, 0b01, !if(is_col, TileVectorOpV16,
                                                   TileVectorOpH16),
-                                   is_col, sme_elm_idx0_7, ZPR16, mnemonic> {
+                                   is_col, sme_elm_idx0_7, ZPR16, mnemonic>,
+           SMEPseudo2Instr<NAME # _H, 1> {
     bits<1> ZAd;
     bits<3> imm;
     let Inst{3}   = ZAd;
@@ -736,7 +748,8 @@
   }
   def _S : sme_vector_to_tile_inst<0b0, 0b10, !if(is_col, TileVectorOpV32,
                                                   TileVectorOpH32),
-                                   is_col, sme_elm_idx0_3, ZPR32, mnemonic> {
+                                   is_col, sme_elm_idx0_3, ZPR32, mnemonic>,
+           SMEPseudo2Instr<NAME # _S, 1> {
     bits<2> ZAd;
     bits<2> imm;
     let Inst{3-2} = ZAd;
@@ -744,7 +757,8 @@
   }
   def _D : sme_vector_to_tile_inst<0b0, 0b11, !if(is_col, TileVectorOpV64,
                                                   TileVectorOpH64),
-                                   is_col, sme_elm_idx0_1, ZPR64, mnemonic> {
+                                   is_col, sme_elm_idx0_1, ZPR64, mnemonic>,
+           SMEPseudo2Instr<NAME # _D, 1> {
     bits<3> ZAd;
     bits<1> imm;
     let Inst{3-1} = ZAd;
@@ -752,7 +766,8 @@
   }
   def _Q : sme_vector_to_tile_inst<0b1, 0b11, !if(is_col, TileVectorOpV128,
                                                   TileVectorOpH128),
-                                   is_col, sme_elm_idx0_0, ZPR128, mnemonic> {
+                                   is_col, sme_elm_idx0_0, ZPR128, mnemonic>,
+           SMEPseudo2Instr<NAME # _Q, 1> {
     bits<4> ZAd;
     bits<1> imm;
     let Inst{3-0} = ZAd;
@@ -760,11 +775,11 @@

   // Pseudo instructions for lowering intrinsics, using immediates instead of
   // tile registers.
-  def _PSEUDO_B : sme_mova_insert_pseudo;
-  def _PSEUDO_H : sme_mova_insert_pseudo;
-  def _PSEUDO_S : sme_mova_insert_pseudo;
-  def _PSEUDO_D : sme_mova_insert_pseudo;
-  def _PSEUDO_Q : sme_mova_insert_pseudo;
+  def _PSEUDO_B : sme_mova_insert_pseudo<SMEMatrixTileB>, SMEPseudo2Instr<NAME # _B, 0>;
+  def _PSEUDO_H : sme_mova_insert_pseudo<SMEMatrixTileH>, SMEPseudo2Instr<NAME # _H, 0>;
+  def _PSEUDO_S : sme_mova_insert_pseudo<SMEMatrixTileS>, SMEPseudo2Instr<NAME # _S, 0>;
+  def _PSEUDO_D : sme_mova_insert_pseudo<SMEMatrixTileD>, SMEPseudo2Instr<NAME # _D, 0>;
+  def _PSEUDO_Q : sme_mova_insert_pseudo<SMEMatrixTileQ>, SMEPseudo2Instr<NAME # _Q, 0>;

   defm : sme_vector_to_tile_aliases<!cast<Instruction>(NAME # _B),
                                     !if(is_col, TileVectorOpV8,