Replace the recursive RepLLVMType<N, T> helper with TableGen's builtin
!listsplat(T, N), which builds the same N-element list [T, T, ..., T].
Every RepLLVMType<N, T>.ret used by WMMA_REGS.regs below is rewritten as
!listsplat(T, N) (note the swapped argument order: value first, count second),
and the helper class is deleted.
NOTE(review): confirm no other .td file still instantiates RepLLVMType
before landing; only IntrinsicsNVVM.td is visible in this patch.

Index: llvm/include/llvm/IR/IntrinsicsNVVM.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -37,11 +37,6 @@
 // MISC
 //
 
-// Helper class for construction of n-element list [t,t,...,t]
-class RepLLVMType<int N, LLVMType T> {
-  list<LLVMType> ret = !if(N, !listconcat(RepLLVMType<!add(N,-1), T>.ret, [T]), []);
-}
-
 // Helper class that represents a 'fragment' of an NVPTX *MMA instruction.
 // Geom: mnk. E.g. m8n32k16
 // Frag: [abcd]
@@ -54,40 +49,40 @@
   string ft = frag#":"#ptx_elt_type;
   list<LLVMType> regs = !cond(
     // mma.sync.m8n8k4 uses smaller a/b fragments than wmma fp ops
-    !eq(gft,"m8n8k4:a:f16") : RepLLVMType<2, llvm_v2f16_ty>.ret,
-    !eq(gft,"m8n8k4:b:f16") : RepLLVMType<2, llvm_v2f16_ty>.ret,
+    !eq(gft,"m8n8k4:a:f16") : !listsplat(llvm_v2f16_ty, 2),
+    !eq(gft,"m8n8k4:b:f16") : !listsplat(llvm_v2f16_ty, 2),
 
     // fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
     // All currently supported geometries use the same fragment format,
     // so we only need to consider {fragment, type}.
-    !eq(ft,"a:f16") : RepLLVMType<8, llvm_v2f16_ty>.ret,
-    !eq(ft,"b:f16") : RepLLVMType<8, llvm_v2f16_ty>.ret,
-    !eq(ft,"c:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
-    !eq(ft,"d:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
-    !eq(ft,"c:f32") : RepLLVMType<8, llvm_float_ty>.ret,
-    !eq(ft,"d:f32") : RepLLVMType<8, llvm_float_ty>.ret,
+    !eq(ft,"a:f16") : !listsplat(llvm_v2f16_ty, 8),
+    !eq(ft,"b:f16") : !listsplat(llvm_v2f16_ty, 8),
+    !eq(ft,"c:f16") : !listsplat(llvm_v2f16_ty, 4),
+    !eq(ft,"d:f16") : !listsplat(llvm_v2f16_ty, 4),
+    !eq(ft,"c:f32") : !listsplat(llvm_float_ty, 8),
+    !eq(ft,"d:f32") : !listsplat(llvm_float_ty, 8),
 
     // u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
-    !eq(gft,"m16n16k16:a:u8") : RepLLVMType<2, llvm_i32_ty>.ret,
-    !eq(gft,"m16n16k16:a:s8") : RepLLVMType<2, llvm_i32_ty>.ret,
-    !eq(gft,"m16n16k16:b:u8") : RepLLVMType<2, llvm_i32_ty>.ret,
-    !eq(gft,"m16n16k16:b:s8") : RepLLVMType<2, llvm_i32_ty>.ret,
-    !eq(gft,"m16n16k16:c:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
-    !eq(gft,"m16n16k16:d:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
+    !eq(gft,"m16n16k16:a:u8") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n16k16:a:s8") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n16k16:b:u8") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n16k16:b:s8") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n16k16:c:s32") : !listsplat(llvm_i32_ty, 8),
+    !eq(gft,"m16n16k16:d:s32") : !listsplat(llvm_i32_ty, 8),
 
     !eq(gft,"m8n32k16:a:u8") : [llvm_i32_ty],
     !eq(gft,"m8n32k16:a:s8") : [llvm_i32_ty],
-    !eq(gft,"m8n32k16:b:u8") : RepLLVMType<4, llvm_i32_ty>.ret,
-    !eq(gft,"m8n32k16:b:s8") : RepLLVMType<4, llvm_i32_ty>.ret,
-    !eq(gft,"m8n32k16:c:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
-    !eq(gft,"m8n32k16:d:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
+    !eq(gft,"m8n32k16:b:u8") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m8n32k16:b:s8") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m8n32k16:c:s32") : !listsplat(llvm_i32_ty, 8),
+    !eq(gft,"m8n32k16:d:s32") : !listsplat(llvm_i32_ty, 8),
 
-    !eq(gft,"m32n8k16:a:u8") : RepLLVMType<4, llvm_i32_ty>.ret,
-    !eq(gft,"m32n8k16:a:s8") : RepLLVMType<4, llvm_i32_ty>.ret,
+    !eq(gft,"m32n8k16:a:u8") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m32n8k16:a:s8") : !listsplat(llvm_i32_ty, 4),
     !eq(gft,"m32n8k16:b:u8") : [llvm_i32_ty],
     !eq(gft,"m32n8k16:b:s8") : [llvm_i32_ty],
-    !eq(gft,"m32n8k16:c:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
-    !eq(gft,"m32n8k16:d:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
+    !eq(gft,"m32n8k16:c:s32") : !listsplat(llvm_i32_ty, 8),
+    !eq(gft,"m32n8k16:d:s32") : !listsplat(llvm_i32_ty, 8),
 
     // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
     !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
@@ -96,10 +91,10 @@
     !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
     !eq(gft,"m8n8k32:b:u4") : [llvm_i32_ty],
     !eq(gft,"m8n8k32:b:s4") : [llvm_i32_ty],
-    !eq(gft,"m8n8k128:c:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
-    !eq(gft,"m8n8k128:d:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
-    !eq(gft,"m8n8k32:c:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
-    !eq(gft,"m8n8k32:d:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
+    !eq(gft,"m8n8k128:c:s32") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m8n8k128:d:s32") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m8n8k32:c:s32") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m8n8k32:d:s32") : !listsplat(llvm_i32_ty, 2),
   );
 }
 