diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -43,7 +43,7 @@ // Helper class that represents a 'fragment' of an NVPTX *MMA instruction. // Geom: mnk. E.g. m8n32k16 -// Frag: [abcd] +// Frag: [a|b|c|d] ([x1|x2|x4] for ldmatrix) // PtxEltType: PTX type for the element. class WMMA_REGS { string geom = Geom; @@ -190,6 +190,11 @@ !eq(gft,"m16n8k256:b:b1") : !listsplat(llvm_i32_ty, 2), !eq(gft,"m16n8k256:c:s32") : !listsplat(llvm_i32_ty, 4), !eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4), + + // ldmatrix b16 -> s32 @ m8n8 + !eq(gft,"m8n8:x1:b16") : !listsplat(llvm_i32_ty, 1), + !eq(gft,"m8n8:x2:b16") : !listsplat(llvm_i32_ty, 2), + !eq(gft,"m8n8:x4:b16") : !listsplat(llvm_i32_ty, 4), ); } @@ -256,6 +261,17 @@ !subst("llvm.", "int_", llvm)); } +class LDMATRIX_NAME { + string intr = "llvm.nvvm.ldmatrix.sync.aligned" + # "." # Frag.geom + # "." # Frag.frag + # !if(Trans, ".trans", "") + # "." # Frag.ptx_elt_type + ; + string record = !subst(".", "_", + !subst("llvm.", "int_", intr)); +} + // Generates list of 4-tuples of WMMA_REGS representing a valid MMA op. // Geom: list of supported geometries. // TypeN: PTX type of the corresponding fragment's element. @@ -286,6 +302,16 @@ list ops = !foreach(x, ret, x.gft); } +class LDMATRIX_OPS Geom, list Frags, list Types> { + list ret = + !foldl([], Geom, t1, geom, !listconcat(t1, + !foldl([], Frags, t2, frag, !listconcat(t2, + !foldl([], Types, t3, type, !listconcat(t3, + [WMMA_REGS])))))); + // Debugging aid for readable representation of the list above. + list ops = !foreach(x, ret, x.gft); +} + // Creates list of valid combinations of fragments. This is the master list that // drives generation of corresponding intrinsics and instructions. class NVVM_MMA_OPS { @@ -370,11 +396,14 @@ // Separate A/B/C fragments (loads) from D (stores). 
list all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d")); list all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d")); + + list ldmatrix_b16_ops = LDMATRIX_OPS< + ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret; + list all_ldmatrix_ops = ldmatrix_b16_ops; } def NVVM_MMA_OPS : NVVM_MMA_OPS; - // Returns true if this combination of fragment and layout for WMMA load/store // ops is supported; false otherwise. // E.g. @@ -489,6 +518,23 @@ ); } +// Returns true if the fragment is valid for ldmatrix ops; +// false otherwise. +// E.g. +// if NVVM_LDMATRIX_SUPPORTED<...>.ret then +// def : FOO<>; // The record will only be defined for supported ops. +// +class NVVM_LDMATRIX_SUPPORTED { + string g = frag.geom; + string t = frag.ptx_elt_type; + + bit ret = !cond( + // Only currently support m8n8 and b16 + !and(!eq(g, "m8n8"), !eq(t, "b16")): true, + true: false + ); +} + class SHFL_INFO { string Suffix = !if(sync, "sync_", "") # mode # "_" @@ -4519,4 +4565,20 @@ } // layout_b } // layout_a +// LDMATRIX +class NVVM_LDMATRIX + : Intrinsic>, + NoCapture>], + LDMATRIX_NAME.intr>; + +foreach transposed = [0, 1] in { + foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in { + if NVVM_LDMATRIX_SUPPORTED.ret then { + def LDMATRIX_NAME.record + : NVVM_LDMATRIX; + } + } +} + } // let TargetPrefix = "nvvm" diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -3547,7 +3547,9 @@ case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col: case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride: case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row: - case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride: { + case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride: + case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16: + case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: { Info.opc = ISD::INTRINSIC_W_CHAIN; Info.memVT
= MVT::v4i32; Info.ptrVal = I.getArgOperand(0); @@ -3585,7 +3587,9 @@ case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col: case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride: case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride: - case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col: { + case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col: + case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16: + case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: { Info.opc = ISD::INTRINSIC_W_CHAIN; Info.memVT = MVT::i32; Info.ptrVal = I.getArgOperand(0); @@ -3679,7 +3683,9 @@ case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col: case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride: case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row: - case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride: { + case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride: + case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16: + case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: { Info.opc = ISD::INTRINSIC_W_CHAIN; Info.memVT = MVT::v2i32; Info.ptrVal = I.getArgOperand(0); diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -7578,6 +7578,7 @@ !eq(ptx_elt_type, "bf16") : Int32Regs, !eq(ptx_elt_type, "tf32") : Int32Regs, !eq(ptx_elt_type, "s32") : Int32Regs, + !eq(ptx_elt_type, "b16") : Int32Regs, !eq(ptx_elt_type, "s8") : Int32Regs, !eq(ptx_elt_type, "u8") : Int32Regs, !eq(ptx_elt_type, "s4") : Int32Regs, @@ -7661,7 +7662,11 @@ !eq(geom, "m16n8k64"), !eq(geom, "m8n8k128"), !eq(geom, "m16n8k128"), - !eq(geom, "m16n8k256"))) : [hasSM80, hasPTX70]); + !eq(geom, "m16n8k256"))) : [hasSM80, hasPTX70], + + !and(!eq(op,"ldmatrix"), + !eq(ptx_elt_type,"b16"), + !eq(geom, "m8n8")) : [hasSM75, hasPTX65]); // template DAGs for instruction inputs/output. 
dag Outs = !dag(outs, ptx_regs, reg_names); @@ -7910,6 +7915,44 @@ } // layout_a } // defset +// +// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16 +// +class LDMATRIX + : WMMA_INSTR.record, [(ins SrcOp:$src)]>, + Requires { + // Build PatFrag that only matches particular address space. + PatFrag IntrFrag = PatFrag<(ops node:$src), (Intr node:$src), + !cond(!eq(Space, ".shared"): AS_match.shared, + true: AS_match.generic)>; + // Build AS-constrained pattern. + let IntrinsicPattern = BuildPatternPF.ret; + + let OutOperandList = Frag.Outs; + let InOperandList = !con(Args, (ins MmaCode:$ptx)); + let AsmString = "ldmatrix.sync.aligned." + # Frag.geom + # "." # Frag.frag + # !if(Transposed, ".trans", "") + # Space + # "." # Frag.ptx_elt_type + # " " # Frag.regstring # ", [$src];"; +} + +// Create all ldmatrix variants +defset list LDMATRIXs = { + foreach transposed = [false, true] in { + foreach space = [".shared", ""] in { + foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in { + foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in + if NVVM_LDMATRIX_SUPPORTED.ret then + def : LDMATRIX, transposed, space, + addr>; + } // addr + } // space + } // transposed +} // defset // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with @@ -7921,5 +7964,5 @@ Requires; // Build intrinsic->instruction patterns for all MMA instructions. 
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs) in +foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in def : MMA_PAT; diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py --- a/llvm/test/CodeGen/NVPTX/wmma.py +++ b/llvm/test/CodeGen/NVPTX/wmma.py @@ -6,7 +6,7 @@ # RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \ # RUN: --check-prefixes=INTRINSICS,M16N16 # RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \ -# RUN: --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT +# RUN: --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX # RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \ # RUN: | FileCheck %t-ptx60-sm_70.ll @@ -15,7 +15,7 @@ # RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \ # RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM # RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \ -# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT +# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX # RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \ # RUN: | FileCheck %t-ptx61-sm_70.ll @@ -24,7 +24,7 @@ # RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \ # RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT # RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \ -# RUN: --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT +# RUN: --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX # RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \ # RUN: | FileCheck %t-ptx63-sm_72.ll @@ -33,7 +33,7 @@ # RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \ # RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT # RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \ -# RUN: --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT +# RUN: --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX # RUN: 
llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \ # RUN: | FileCheck %t-ptx63-sm_75.ll @@ -42,14 +42,14 @@ # RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \ # RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,MMA # RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \ -# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT +# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT,NOLDMATRIX # RUN: llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \ # RUN: | FileCheck %t-ptx64-sm_70.ll # Check all variants of instructions supported by PTX65 on SM75+ # RUN: %python %s --ptx=65 --gpu-arch=75 > %t-ptx65-sm_75.ll # RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \ -# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA +# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA,PTX65LDMATRIX # RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \ # RUN: --check-prefixes=INTRINSICS # RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \ @@ -58,7 +58,7 @@ # Check all variants of instructions supported by PTX71 on SM80+ # RUN: %python %s --ptx=71 --gpu-arch=80 > %t-ptx71-sm_80.ll # RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \ -# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX71MMA +# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX65LDMATRIX,PTX71MMA # RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \ # RUN: --check-prefixes=INTRINSICS # RUN: llc < %t-ptx71-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 \ @@ -78,6 +78,7 @@ "f32" : "float", "f64" : "double", "s32" : "i32", + "b16" : "i32", "s8" : "i32", "u8" : "i32", "s4" : "i32", @@ -232,6 +233,11 @@ "m16n8k16:d:f16": 2, "m16n8k16:c:f32": 4, "m16n8k16:d:f32": 4, + + # ldmatrix + "m8n8:x1:b16": 1, + "m8n8:x2:b16": 2, + "m8n8:x4:b16": 4, }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), { # All other FP 
shape/fragment/type combinations have the same size "a:f16" : 8, @@ -272,6 +278,10 @@ return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type) in product(geoms, frags, types)] +def make_ldmatrix_ops(geoms, frags, types): + return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type) + in product(geoms, frags, types)] + def get_wmma_ops(): return (make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], []) + @@ -317,6 +327,9 @@ make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"])) return [ x for x in ldst_ops if (x.frag == "d") == (kind == "store")] +def get_ldmatrix_ops(): + return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) + def is_wmma_geom_supported(geom): # geometries for FP and ints. if geom in ["m8n32k16", "m32n8k16"]: @@ -343,11 +356,18 @@ return ptx_version >= 70 assert(False) # Unexpected geometry. +def is_ldmatrix_geom_supported(geom): + if geom in ["m8n8"]: + return ptx_version >= 65 and gpu_arch >= 75 + assert(False) # Unexpected geometry. + def is_type_supported(ptx_type): if ptx_type in ["s8", "u8", "s32"]: return ptx_version >= 63 and gpu_arch >= 72 if ptx_type in ["s4", "u4", "b1"]: return ptx_version >= 63 and gpu_arch >= 75 + if ptx_type == "b16": + return ptx_version >= 65 and gpu_arch >= 75 if ptx_type in ["bf16", "tf32", "f64"]: return ptx_version >= 70 return ptx_version >= 60 and gpu_arch >= 70 @@ -413,6 +433,12 @@ or frag.frag in ["c", "d"]) return True +def is_ldmatrix_variant_supported(frag): + if not (is_type_supported(frag.mma_type.ptx_type) + and is_ldmatrix_geom_supported(frag.geom)): + return False + return frag.frag in ["x1", "x2", "x4"] + def make_wmma_slice_ty(frag): return [frag.mma_type.llvm_type] * frag.nregs @@ -584,6 +610,66 @@ return generated_items +def gen_ldmatrix_tests(): + ldmatrix_template = """ +declare ${ret_ty} @${intrinsic}(i8 ${as}* %src); + +; CHECK-LABEL: .func {{.*}}test_${function}( +define ${ret_ty} @test_${function}(i8 ${as}* %src) { +; CHECK: ${instruction} +; CHECK: {${check_result}} +; 
CHECK: [%rd{{[0-9]+}}] + %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src); + ret ${ret_ty} %v0; +} + +; CHECK-LABEL: .func{{.*}}test_${function}_o( +define ${ret_ty} @test_${function}_o(i8 ${as}* %src) { +; CHECK: ${instruction} +; CHECK: {${check_result}} +; CHECK: [%rd{{[0-9]+}}+128] + %src1 = getelementptr i8, i8 ${as}* %src, i32 128; + %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1); + ret ${ret_ty} %v0; +} +""" + intrinsic_template = "llvm.nvvm.ldmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}" + instruction_template = "ldmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}" + + generated_items = [] + + for frag, space, trans in product( + get_ldmatrix_ops(), + ["",".shared"], + ["",".trans"], + ): + if not is_ldmatrix_variant_supported(frag): + continue + + params = { + "frag" : frag.frag, + "space" : space, + "trans" : trans, + "itype" : frag.mma_type.ptx_type, + "pspace" : get_pspace(space), + "as" : "addrspace(%d)" % get_aspace(space), + "geom" : frag.geom, + } + + test_params = params + test_params["intrinsic"] = Template(intrinsic_template).substitute(params) + test_params["function"] = test_params["intrinsic"].replace(".","_") + test_params["instruction"] = Template(instruction_template).substitute(params) + test_params["ret_ty"] = make_wmma_ld_ret_ty(frag) + test_params["check_result"] = check_pattern(frag) + + print(Template(ldmatrix_template).substitute(test_params)) + + generated_items.append((test_params["intrinsic"], + test_params["instruction"])) + + return generated_items + def mma_signature(op): if op.a.mma_type.ptx_type == "f16": # FP16 ops identified by accumulator & result type. @@ -744,6 +830,7 @@ ; NOMMA-NOT: .m8n8k4. 
; NOALTFLOAT-NOT: .{{bf16|tf32}} ; NODOUBLE-NOT: .f64 +; NOLDMATRIX-NOT: ldmatrix.sync.aligned ; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p ; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p @@ -819,6 +906,19 @@ ; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4 ; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4 +; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.b16 +; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.b16 +; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.b16 +; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.b16 +; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.b16 +; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.b16 +; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.shared.b16 +; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.shared.b16 +; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.shared.b16 +; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 +; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 +; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 + ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32 @@ -861,6 +961,7 @@ def gen_tests(): items = gen_wmma_load_tests() items += gen_wmma_store_tests() + items += gen_ldmatrix_tests() items += gen_wmma_mma_tests() items += gen_mma_tests() gen_check_unsupported_ops(items)