[AMDGPU] Derive isDGEMM() from the MAIInstInfo searchable table

Replace the hard-coded list of V_MFMA_F64_* opcodes in
GCNHazardRecognizer.cpp's isDGEMM() with a lookup of a new is_dgemm
flag carried by every MAIInst record:

- VOP3PInstructions.td: add `bit is_dgemm = 0` to class MAIInst, and
  wrap the V_MFMA_F64_16X16X4F64 and V_MFMA_F64_4X4X4F64 defms in
  `let is_dgemm = 1 in { ... }`. Add "is_dgemm" to the MAIInstInfo
  table's Fields list.
- AMDGPUBaseInfo.h: add `bool is_dgemm` to struct MAIInstInfo, between
  Opcode and is_gfx940_xdl, and declare getMAIIsDGEMM(unsigned Opc).
- AMDGPUBaseInfo.cpp: define getMAIIsDGEMM(). It looks the opcode up
  with getMAIInstInfoHelper() and returns the entry's is_dgemm flag,
  or false when the opcode has no MAIInstInfo entry. This mirrors the
  existing getMAIIsGFX940XDL().

Notes for review:
- The order of the GenericTable Fields list sets the order of the
  values in the generated C++ initializers. It must therefore match
  the member order of struct MAIInstInfo. Here both are Opcode,
  is_dgemm, is_gfx940_xdl.
- The removed list explicitly named the _vgprcd_e64, _mac_e64 and
  _mac_vgprcd_e64 variants. Those variants are presumably produced by
  the MAIInst multiclass that the two defms instantiate. If so, they
  inherit is_dgemm = 1 from the enclosing `let`. TODO confirm this
  against the generated MAIInstInfoTable, so that no variant
  previously treated as DGEMM silently drops out.
- isDGEMM() now performs a table lookup instead of six opcode
  compares. The cost depends on getMAIInstInfoHelper(), whose body is
  not part of this patch.
- NOTE(review): this copy of the patch appears damaged by text
  extraction. All newlines and runs of spaces are collapsed onto two
  lines. The context lines `class MAIInst : VOP3InstBase {` and
  `Instruction Opcode = !cast(NAME);` seem to have lost their
  angle-bracket template arguments; TableGen's !cast needs an
  explicit target type. The patch will not apply as-is, so
  regenerate it from the source tree rather than hand-repairing it.

diff --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp --- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp +++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp @@ -116,12 +116,7 @@ } static bool isDGEMM(unsigned Opcode) { - return Opcode == AMDGPU::V_MFMA_F64_4X4X4F64_e64 || - Opcode == AMDGPU::V_MFMA_F64_4X4X4F64_vgprcd_e64 || - Opcode == AMDGPU::V_MFMA_F64_16X16X4F64_e64 || - Opcode == AMDGPU::V_MFMA_F64_16X16X4F64_vgprcd_e64 || - Opcode == AMDGPU::V_MFMA_F64_16X16X4F64_mac_e64 || - Opcode == AMDGPU::V_MFMA_F64_16X16X4F64_mac_vgprcd_e64; + return AMDGPU::getMAIIsDGEMM(Opcode); } static bool isXDL(const GCNSubtarget &ST, const MachineInstr &MI) { diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h @@ -70,6 +70,7 @@ struct MAIInstInfo { uint16_t Opcode; + bool is_dgemm; bool is_gfx940_xdl; }; @@ -450,6 +451,10 @@ LLVM_READONLY bool getVOP3IsSingle(unsigned Opc); +/// Returns true if MAI operation is a double precision GEMM. +LLVM_READONLY +bool getMAIIsDGEMM(unsigned Opc); + LLVM_READONLY bool getMAIIsGFX940XDL(unsigned Opc); diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp @@ -343,6 +343,11 @@ return Info ? Info->IsSingle : false; } +bool getMAIIsDGEMM(unsigned Opc) { + const MAIInstInfo *Info = getMAIInstInfoHelper(Opc); + return Info ? Info->is_dgemm : false; +} + bool getMAIIsGFX940XDL(unsigned Opc) { const MAIInstInfo *Info = getMAIInstInfoHelper(Opc); return Info ? 
Info->is_gfx940_xdl : false; diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td --- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td @@ -487,6 +487,7 @@ class MAIInst : VOP3InstBase { Instruction Opcode = !cast(NAME); + bit is_dgemm = 0; bit is_gfx940_xdl = 0; } @@ -559,8 +560,10 @@ defm V_MFMA_F32_16X16X16BF16_1K : MAIInst<"v_mfma_f32_16x16x16bf16_1k", "F32_V4I16_X4", int_amdgcn_mfma_f32_16x16x16bf16_1k>; } + let is_dgemm = 1 in { defm V_MFMA_F64_16X16X4F64 : MAIInst<"v_mfma_f64_16x16x4f64", "F64_16X16X4F64", int_amdgcn_mfma_f64_16x16x4f64>; defm V_MFMA_F64_4X4X4F64 : MAIInst<"v_mfma_f64_4x4x4f64", "F64_4X4X4F64", int_amdgcn_mfma_f64_4x4x4f64>; + } } // End Predicates = [isGFX90APlus] let Predicates = [isGFX940Plus], is_gfx940_xdl = 1 in { @@ -590,7 +593,7 @@ let FilterClass = "MAIInst"; let CppTypeName = "MAIInstInfo"; let Fields = [ - "Opcode", "is_gfx940_xdl" + "Opcode", "is_dgemm", "is_gfx940_xdl" ]; let PrimaryKey = ["Opcode"];