Index: llvm/include/llvm/IR/IntrinsicsNVVM.td =================================================================== --- llvm/include/llvm/IR/IntrinsicsNVVM.td +++ llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -3869,4 +3869,150 @@ Intrinsic<[llvm_i64_ty, llvm_i1_ty], [llvm_i32_ty, llvm_i64_ty], [IntrNoMem, IntrConvergent], "llvm.nvvm.match.all.sync.i64p">; +// +// WMMA instructions +// + +// WMMA.LOAD +class NVVM_WMMA_LD_ALSTS + : Intrinsic; + +multiclass NVVM_WMMA_LD_ALST { + def _stride: NVVM_WMMA_LD_ALSTS; + def NAME : NVVM_WMMA_LD_ALSTS; +} + +multiclass NVVM_WMMA_LD_ALT { + defm _global: NVVM_WMMA_LD_ALST; + defm _shared: NVVM_WMMA_LD_ALST; + defm NAME: NVVM_WMMA_LD_ALST; +} + +multiclass NVVM_WMMA_LD_AT { + defm _row: NVVM_WMMA_LD_ALT; + defm _col: NVVM_WMMA_LD_ALT; +} + +// For some reason ReadOnly and NoCapture confuses tblgen if they are +// passed to Intrinsic<> form inside of a multiclass. Setting them globally +// outside of the multiclass works. +let IntrProperties = [IntrReadMem, IntrArgMemOnly, + ReadOnly<0>, NoCapture<0>] in { + defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>; + defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>; + defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>; + defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>; +} + +// WMMA.STORE.D +class NVVM_WMMA_STD_LSTSEmpty=[]> + : Intrinsic<[], + !listconcat( + [llvm_ptr_ty], + !if(!eq(Type,"f16"), + [regty, regty, regty, regty], + [regty, regty, regty, regty, + regty, regty, regty, regty]), + !if(WithStride, [llvm_i32_ty], Empty)), + [], // Properties must be set during instantiation. 
+ "llvm.nvvm.wmma.store.d.sync."#Layout + #".m16n16k16"#Space + #!if(WithStride,".stride","") + #"."#Type>; + +multiclass NVVM_WMMA_STD_LST { + def _stride: NVVM_WMMA_STD_LSTS; + def NAME: NVVM_WMMA_STD_LSTS; +} + +multiclass NVVM_WMMA_STD_LT { + defm _global: NVVM_WMMA_STD_LST; + defm _shared: NVVM_WMMA_STD_LST; + defm NAME: NVVM_WMMA_STD_LST; +} + +multiclass NVVM_WMMA_STD_T { + defm _row: NVVM_WMMA_STD_LT<"row", Type, regty>; + defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>; +} + +let IntrProperties = [IntrWriteMem, IntrArgMemOnly, + WriteOnly<0>, NoCapture<0>] in { + defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>; + defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>; +} + +// WMMA.MMA +class NVVM_WMMA_MMA_ABDCS + : Intrinsic; + +multiclass NVVM_WMMA_MMA_ABDC { + def NAME : NVVM_WMMA_MMA_ABDCS; + def _satfinite: NVVM_WMMA_MMA_ABDCS; +} + +multiclass NVVM_WMMA_MMA_ABD { + defm _f16: NVVM_WMMA_MMA_ABDC; + defm _f32: NVVM_WMMA_MMA_ABDC; +} + +multiclass NVVM_WMMA_MMA_AB { + defm _f16: NVVM_WMMA_MMA_ABD; + defm _f32: NVVM_WMMA_MMA_ABD; +} + +multiclass NVVM_WMMA_MMA_A { + defm _col: NVVM_WMMA_MMA_AB; + defm _row: NVVM_WMMA_MMA_AB; +} + +defm int_nvvm_wmma_mma_sync_col: NVVM_WMMA_MMA_A<"col">; +defm int_nvvm_wmma_mma_sync_row: NVVM_WMMA_MMA_A<"row">; + } // let TargetPrefix = "nvvm" Index: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h =================================================================== --- llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h +++ llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h @@ -74,6 +74,8 @@ bool tryConstantFP16(SDNode *N); bool SelectSETP_F16X2(SDNode *N); bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N); + bool tryWMMA_LDST(SDNode *N); + bool tryWMMA_MMA(SDNode *N); inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) { return CurDAG->getTargetConstant(Imm, DL, MVT::i32); Index: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp =================================================================== --- 
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -496,8 +496,315 @@ SelectCode(N); } +enum WmmaVariant { + WMMA_VARIANT_ARI64, + WMMA_VARIANT_ARI64_STRIDE, + WMMA_VARIANT_AVAR, + WMMA_VARIANT_AVAR_STRIDE, +}; + +// clang-format off +#define WMMA_VARIANTS(base) \ + {{ base##_ari64, base##_ari64_stride, base##_avar, base##_avar_stride }} +// clang-format on + +static unsigned getWmmaLdVariant(WmmaVariant Variant, bool Stride, + const std::array Variants) { + if (Stride) { + if (Variant == WMMA_VARIANT_ARI64) + Variant = WMMA_VARIANT_ARI64_STRIDE; + else if (Variant == WMMA_VARIANT_AVAR) + Variant = WMMA_VARIANT_AVAR_STRIDE; + } + return Variants[Variant]; +} + +static Optional +getWmmaLdStOpcode(unsigned IntrinsicID, + WmmaVariant Variant = WMMA_VARIANT_ARI64) { + switch (IntrinsicID) { + default: + return None; + // + // WMMA_LOAD_A f16 + // + case Intrinsic::nvvm_wmma_load_a_f16_col: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col)); + case Intrinsic::nvvm_wmma_load_a_f16_row: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row)); + case Intrinsic::nvvm_wmma_load_a_f16_col_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col)); + case Intrinsic::nvvm_wmma_load_a_f16_row_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row)); + case Intrinsic::nvvm_wmma_load_a_f16_col_shared: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared)); + case Intrinsic::nvvm_wmma_load_a_f16_row_shared: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared)); + case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared)); + case 
Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared)); + case Intrinsic::nvvm_wmma_load_a_f16_col_global: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global)); + case Intrinsic::nvvm_wmma_load_a_f16_row_global: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global)); + case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global)); + case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global)); + + // + // WMMA_LOAD_B f16 + // + case Intrinsic::nvvm_wmma_load_b_f16_col: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col)); + case Intrinsic::nvvm_wmma_load_b_f16_row: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row)); + case Intrinsic::nvvm_wmma_load_b_f16_col_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col)); + case Intrinsic::nvvm_wmma_load_b_f16_row_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row)); + case Intrinsic::nvvm_wmma_load_b_f16_col_shared: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared)); + case Intrinsic::nvvm_wmma_load_b_f16_row_shared: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared)); + case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared)); + case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, 
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared)); + case Intrinsic::nvvm_wmma_load_b_f16_col_global: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global)); + case Intrinsic::nvvm_wmma_load_b_f16_row_global: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global)); + case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global)); + case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global)); + + // + // WMMA_LOAD_C f16 + // + case Intrinsic::nvvm_wmma_load_c_f16_col: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col)); + case Intrinsic::nvvm_wmma_load_c_f16_row: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row)); + case Intrinsic::nvvm_wmma_load_c_f16_col_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col)); + case Intrinsic::nvvm_wmma_load_c_f16_row_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row)); + case Intrinsic::nvvm_wmma_load_c_f16_col_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared)); + case Intrinsic::nvvm_wmma_load_c_f16_row_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared)); + case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared)); + case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared)); + 
case Intrinsic::nvvm_wmma_load_c_f16_col_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global)); + case Intrinsic::nvvm_wmma_load_c_f16_row_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global)); + case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global)); + case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global)); + + // + // WMMA_LOAD_C f32 + // + case Intrinsic::nvvm_wmma_load_c_f32_col: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col)); + case Intrinsic::nvvm_wmma_load_c_f32_row: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row)); + case Intrinsic::nvvm_wmma_load_c_f32_col_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col)); + case Intrinsic::nvvm_wmma_load_c_f32_row_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row)); + case Intrinsic::nvvm_wmma_load_c_f32_col_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared)); + case Intrinsic::nvvm_wmma_load_c_f32_row_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared)); + case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared)); + case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared)); + case 
Intrinsic::nvvm_wmma_load_c_f32_col_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global)); + case Intrinsic::nvvm_wmma_load_c_f32_row_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global)); + case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global)); + case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global)); + + // + // WMMA_STORE_D f16 + // + case Intrinsic::nvvm_wmma_store_d_f16_col: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col)); + case Intrinsic::nvvm_wmma_store_d_f16_row: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row)); + case Intrinsic::nvvm_wmma_store_d_f16_col_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col)); + case Intrinsic::nvvm_wmma_store_d_f16_row_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row)); + case Intrinsic::nvvm_wmma_store_d_f16_col_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared)); + case Intrinsic::nvvm_wmma_store_d_f16_row_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared)); + case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared)); + case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared)); + case 
Intrinsic::nvvm_wmma_store_d_f16_col_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global)); + case Intrinsic::nvvm_wmma_store_d_f16_row_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global)); + case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global)); + case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global)); + + // + // WMMA_STORE_D f32 + // + case Intrinsic::nvvm_wmma_store_d_f32_col: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col)); + case Intrinsic::nvvm_wmma_store_d_f32_row: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row)); + case Intrinsic::nvvm_wmma_store_d_f32_col_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col)); + case Intrinsic::nvvm_wmma_store_d_f32_row_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row)); + case Intrinsic::nvvm_wmma_store_d_f32_col_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared)); + case Intrinsic::nvvm_wmma_store_d_f32_row_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared)); + case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared)); + case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared)); + case 
Intrinsic::nvvm_wmma_store_d_f32_col_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global)); + case Intrinsic::nvvm_wmma_store_d_f32_row_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global)); + case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global)); + case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global)); + } +} +#undef WMMA_VARIANTS + bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) { unsigned IID = cast(N->getOperand(1))->getZExtValue(); + if (getWmmaLdStOpcode(IID)) + return tryWMMA_LDST(N); + switch (IID) { default: return false; @@ -719,6 +1026,39 @@ case Intrinsic::nvvm_match_all_sync_i64p: SelectMatchAll(N); return true; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16: + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32: + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16: + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32: + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16: + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32: + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16: + case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32: + case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16: 
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32: + case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16: + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32: + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite: + return tryWMMA_MMA(N); } } @@ -3725,3 +4065,172 @@ } } } + +bool NVPTXDAGToDAGISel::tryWMMA_LDST(SDNode *N) { + SDValue Chain = N->getOperand(0); + unsigned IID = cast(N->getOperand(1))->getZExtValue(); + SDValue Op1 = N->getOperand(2); + SDValue Addr, Offset, Base; + Optional Opcode; + SDLoc DL(N); + MemSDNode *MemSD = cast(N); + WmmaVariant Variant; + SmallVector Ops; + bool isStore = N->getNumValues() == 1; // Store ops only return a chain. + + if (SelectDirectAddr(Op1, Addr)) { + Variant = WMMA_VARIANT_AVAR; + Ops.push_back(Addr); + } else if (SelectADDRsi64(Op1.getNode(), Op1, Base, Offset) || + SelectADDRri64(Op1.getNode(), Op1, Base, Offset)) { + Variant = WMMA_VARIANT_ARI64; + Ops.push_back(Base); + Ops.push_back(Offset); + } else { + Variant = WMMA_VARIANT_AVAR; + Ops.push_back(Op1); + } + unsigned NumOps = N->getNumOperands(); + // Pass through the rest of the operands to the machine node. 
+ for (unsigned i = 3; i < NumOps; ++i) + Ops.push_back(N->getOperand(i)); + Ops.push_back(Chain); + + Opcode = getWmmaLdStOpcode(IID, Variant); + // tryIntrinsicChain only dispatches here after getWmmaLdStOpcode(IID) + // succeeded, and every WMMA case supplies opcodes for all four address + // variants, so a missing opcode is an internal invariant violation. + // Assert in debug builds instead of printing to stderr during ISel. + assert(Opcode && "tryWMMA_LDST: no opcode for WMMA load/store intrinsic"); + if (!Opcode) + return false; + + EVT MemVT = MemSD->getMemoryVT(); + assert(MemVT.isVector() && "Expected vector return type."); + + SDNode *MN; + if (isStore) { + MN = CurDAG->getMachineNode(Opcode.getValue(), DL, MVT::Other, Ops); + } else { + SmallVector InstVTs(MemVT.getVectorNumElements(), + MemSD->getValueType(0)); + InstVTs.push_back(MVT::Other); + MN = CurDAG->getMachineNode(Opcode.getValue(), DL, InstVTs, Ops); + } + + ReplaceNode(N, MN); + return true; +} + +bool NVPTXDAGToDAGISel::tryWMMA_MMA(SDNode *N) { + unsigned IID = cast(N->getOperand(0))->getZExtValue(); + SDLoc DL(N); + unsigned Opc; + + switch (IID) { + default: + return false; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16: + Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite: + Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32: + Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16: + Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite: + Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32: + Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16: + Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite: + Opc = 
NVPTX::INT_WMMA_MMA_col_row_f16_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32: + Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16: + Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite: + Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32: + Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16: + Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32: + Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16: + Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32: + Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16: + Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32: + 
Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16: + Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32: + Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32_satfinite; + break; + } + + SmallVector Ops; + // Pass through operands and return value types to the machine node. + for (unsigned i = 1; i < N->getNumOperands(); ++i) + Ops.push_back(N->getOperand(i)); + SmallVector InstVTs(N->getNumValues(), N->getValueType(0)); + SDNode *MN = CurDAG->getMachineNode(Opc, DL, InstVTs, Ops); + ReplaceNode(N, MN); + return true; +} Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -3321,6 +3321,132 @@ switch (Intrinsic) { default: return false; + case Intrinsic::nvvm_wmma_load_a_f16_col: + case Intrinsic::nvvm_wmma_load_a_f16_row: + case Intrinsic::nvvm_wmma_load_a_f16_col_stride: + case Intrinsic::nvvm_wmma_load_a_f16_row_stride: + case Intrinsic::nvvm_wmma_load_a_f16_col_shared: + case Intrinsic::nvvm_wmma_load_a_f16_row_shared: + case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride: + case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride: + case Intrinsic::nvvm_wmma_load_a_f16_col_global: + case Intrinsic::nvvm_wmma_load_a_f16_row_global: + case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride: + case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride: + case Intrinsic::nvvm_wmma_load_b_f16_col: + case Intrinsic::nvvm_wmma_load_b_f16_row: + 
case Intrinsic::nvvm_wmma_load_b_f16_col_stride: + case Intrinsic::nvvm_wmma_load_b_f16_row_stride: + case Intrinsic::nvvm_wmma_load_b_f16_col_shared: + case Intrinsic::nvvm_wmma_load_b_f16_row_shared: + case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride: + case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride: + case Intrinsic::nvvm_wmma_load_b_f16_col_global: + case Intrinsic::nvvm_wmma_load_b_f16_row_global: + case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride: + case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::v8f16; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.vol = false; + Info.readMem = true; + Info.writeMem = false; + Info.align = 16; + return true; + } + + case Intrinsic::nvvm_wmma_load_c_f16_col: + case Intrinsic::nvvm_wmma_load_c_f16_row: + case Intrinsic::nvvm_wmma_load_c_f16_col_stride: + case Intrinsic::nvvm_wmma_load_c_f16_row_stride: + case Intrinsic::nvvm_wmma_load_c_f16_col_shared: + case Intrinsic::nvvm_wmma_load_c_f16_row_shared: + case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride: + case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride: + case Intrinsic::nvvm_wmma_load_c_f16_col_global: + case Intrinsic::nvvm_wmma_load_c_f16_row_global: + case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride: + case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::v4f16; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.vol = false; + Info.readMem = true; + Info.writeMem = false; + Info.align = 16; + return true; + } + + case Intrinsic::nvvm_wmma_load_c_f32_col: + case Intrinsic::nvvm_wmma_load_c_f32_row: + case Intrinsic::nvvm_wmma_load_c_f32_col_stride: + case Intrinsic::nvvm_wmma_load_c_f32_row_stride: + case Intrinsic::nvvm_wmma_load_c_f32_col_shared: + case Intrinsic::nvvm_wmma_load_c_f32_row_shared: + case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride: + case 
Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride: + case Intrinsic::nvvm_wmma_load_c_f32_col_global: + case Intrinsic::nvvm_wmma_load_c_f32_row_global: + case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride: + case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::v8f32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.vol = false; + Info.readMem = true; + Info.writeMem = false; + Info.align = 16; + return true; + } + + case Intrinsic::nvvm_wmma_store_d_f16_col: + case Intrinsic::nvvm_wmma_store_d_f16_row: + case Intrinsic::nvvm_wmma_store_d_f16_col_stride: + case Intrinsic::nvvm_wmma_store_d_f16_row_stride: + case Intrinsic::nvvm_wmma_store_d_f16_col_shared: + case Intrinsic::nvvm_wmma_store_d_f16_row_shared: + case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride: + case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride: + case Intrinsic::nvvm_wmma_store_d_f16_col_global: + case Intrinsic::nvvm_wmma_store_d_f16_row_global: + case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride: + case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::v4f16; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.vol = false; + Info.readMem = false; + Info.writeMem = true; + Info.align = 16; + return true; + } + + case Intrinsic::nvvm_wmma_store_d_f32_col: + case Intrinsic::nvvm_wmma_store_d_f32_row: + case Intrinsic::nvvm_wmma_store_d_f32_col_stride: + case Intrinsic::nvvm_wmma_store_d_f32_row_stride: + case Intrinsic::nvvm_wmma_store_d_f32_col_shared: + case Intrinsic::nvvm_wmma_store_d_f32_row_shared: + case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride: + case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride: + case Intrinsic::nvvm_wmma_store_d_f32_col_global: + case Intrinsic::nvvm_wmma_store_d_f32_row_global: + case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride: + case 
Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::v8f32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.vol = false; + Info.readMem = false; + Info.writeMem = true; + Info.align = 16; + return true; + } case Intrinsic::nvvm_atomic_load_add_f32: case Intrinsic::nvvm_atomic_load_inc_32: Index: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td =================================================================== --- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -7368,3 +7368,208 @@ def INT_PTX_SREG_WARPSIZE : NVPTXInst<(outs Int32Regs:$dst), (ins), "mov.u32 \t$dst, WARP_SZ;", [(set Int32Regs:$dst, (int_nvvm_read_ptx_sreg_warpsize))]>; + +// +// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] +// +class WMMA_LOAD_ALSTOS + : NVPTXInst, + Requires<[hasPTX60, hasSM70]>; + +multiclass WMMA_LOAD_ALSTO { + def _stride: WMMA_LOAD_ALSTOS; + def NAME: WMMA_LOAD_ALSTOS; +} + +multiclass WMMA_LOAD_ALST { + defm _avar: WMMA_LOAD_ALSTO; + defm _ari64: WMMA_LOAD_ALSTO; +} + +multiclass WMMA_LOAD_ALT { + defm _global: WMMA_LOAD_ALST; + defm _shared: WMMA_LOAD_ALST; + defm NAME: WMMA_LOAD_ALST; +} + +multiclass WMMA_LOAD_AT { + defm _row: WMMA_LOAD_ALT; + defm _col: WMMA_LOAD_ALT; +} + +defm INT_WMMA_LOAD_A: WMMA_LOAD_AT<"a", "f16", Float16x2Regs>; +defm INT_WMMA_LOAD_B: WMMA_LOAD_AT<"b", "f16", Float16x2Regs>; +defm INT_WMMA_LOAD_C_f16: WMMA_LOAD_AT<"c", "f16", Float16x2Regs>; +defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"c", "f32", Float32Regs>; + +// +// wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] +// +class WMMA_STORE_D_LSTOS + : NVPTXInst<(outs), + !if(!eq(Type,"f16"), + !if(WithStride, + !if(WithOffset, + (ins DstOp:$src, i32imm:$offset, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, + Int32Regs:$ldm), + (ins DstOp:$src, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, + Int32Regs:$ldm)), + !if(WithOffset, + (ins 
DstOp:$src, i32imm:$offset, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3), + (ins DstOp:$src, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3))), + !if(WithStride, + !if(WithOffset, + (ins DstOp:$src, i32imm:$offset, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, + regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7, + Int32Regs:$ldm), + (ins DstOp:$src, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, + regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7, + Int32Regs:$ldm)), + !if(WithOffset, + (ins DstOp:$src, i32imm:$offset, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, + regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7), + (ins DstOp:$src, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, + regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)))), + "wmma.store.d.sync."#Layout#".m16n16k16"#Space#"." #Type# " \t" + #!if(WithOffset,"[$src+$offset], ", "[$src], ") + #!if(!eq(Type,"f16"), + "{{$r0, $r1, $r2, $r3}}", + "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}") + #!if(WithStride, ", $ldm", "") + #";", + []>, + Requires<[hasPTX60, hasSM70]>; + +multiclass WMMA_STORE_D_LSTO { + def _stride: WMMA_STORE_D_LSTOS; + def NAME: WMMA_STORE_D_LSTOS; +} + +multiclass WMMA_STORE_D_LST { + defm _avar: WMMA_STORE_D_LSTO; + defm _ari64: WMMA_STORE_D_LSTO; +} + +multiclass WMMA_STORE_D_LT { + defm _global: WMMA_STORE_D_LST; + defm _shared: WMMA_STORE_D_LST; + defm NAME: WMMA_STORE_D_LST; +} + +multiclass WMMA_STORE_D_T { + defm _row: WMMA_STORE_D_LT<"row", Type, regclass>; + defm _col: WMMA_STORE_D_LT<"col", Type, regclass>; +} + +defm INT_WMMA_STORE_D_f16: WMMA_STORE_D_T<"f16", Float16x2Regs>; +defm INT_WMMA_STORE_D_f32: WMMA_STORE_D_T<"f32", Float32Regs>; + +// WMMA.MMA +class WMMA_MMA_ABDCS + : NVPTXInst, + Requires<[hasPTX60, hasSM70]>; + +multiclass WMMA_MMA_ABDC { + def _satfinite: WMMA_MMA_ABDCS; + def NAME: WMMA_MMA_ABDCS; +} + +multiclass WMMA_MMA_ABD { + defm _f16: WMMA_MMA_ABDC; + defm _f32: 
WMMA_MMA_ABDC; +} + +multiclass WMMA_MMA_AB { + defm _f16: WMMA_MMA_ABD; + defm _f32: WMMA_MMA_ABD; +} + +multiclass WMMA_MMA_A { + defm _col: WMMA_MMA_AB; + defm _row: WMMA_MMA_AB; +} + +defm INT_WMMA_MMA_col: WMMA_MMA_A<"col">; +defm INT_WMMA_MMA_row: WMMA_MMA_A<"row">; + Index: llvm/test/CodeGen/NVPTX/wmma.py =================================================================== --- /dev/null +++ llvm/test/CodeGen/NVPTX/wmma.py @@ -0,0 +1,201 @@ +# This test generates all variants of wmma intrinsics and verifies that LLVM +# generates correct instructions for them. + +# RUN: python %s > %t.ll +# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 | FileCheck %t.ll + +from itertools import product +from string import Template + +def make_wmma_slice_ty(abcd, itype): + elt_ty = "<2 x half>" if itype == "f16" else "float" + num_elts = 4 if abcd in "cd" and itype == "f16" else 8; + return [elt_ty] * num_elts + +def make_wmma_ld_ret_ty(abc, itype): + return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype)) + +# Convenient test patterns. 
+check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8) +check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4) +check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8) + +def gen_wmma_load_tests(): + load_template = """ +declare ${ret_ty} @llvm.nvvm.wmma.load.$intrinsic_suffix(i8* %src ${extra_args}); + +; CHECK-LABEL: .func {{.*}}test_wmma_load_${function_suffix}( +define ${ret_ty} @test_wmma_load_${function_suffix}(i8* %src ${extra_args}) { +; CHECK: wmma.load.${instruction_suffix} +; CHECK: {${check_result}} +; CHECK: [%rd{{[0-9]+}}]${stride_pattern} + %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src ${extra_args}); + ret ${ret_ty} %v0; +} + +; CHECK-LABEL: .func{{.*}}test_wmma_load_${function_suffix}_o( +define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8* %src ${extra_args}) { +; CHECK: wmma.load.${instruction_suffix} +; CHECK: {${check_result}} +; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern} + %src1 = getelementptr i8, i8* %src, i32 128; + %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src1 ${extra_args}); + ret ${ret_ty} %v0; +} +""" + suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}" + instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}" # PTX mnemonic suffix: stride is an extra operand, not part of the name + + for abc, layout, space, stride, itype in product( + "abc", + ["row","col"], + ["",".shared",".global"], + ["", ".stride"], + ["f16", "f32"]): + + params = { + "abc" : abc, + "layout" : layout, + "space" : space, + "stride" : stride, + "itype" : itype + } + + if itype == "f32" and abc != "c": + continue + + test_params = params + test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params) + test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_") + test_params["instruction_suffix"] = Template(instruction_template).substitute(params) + test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype) + if abc == "c" : + test_params["check_result"] = check_f16_4 if itype == "f16" else check_f32_8 +
else: + test_params["check_result"] = check_f16_8 + + if stride: + test_params["extra_args"] = ", i32 %stride"; + test_params["stride_pattern"] = ", %r{{[0-9]+}}" + else: + test_params["extra_args"] = "" + test_params["stride_pattern"] = "" + + print(Template(load_template).substitute(test_params)) + +def make_wmma_slice_args(itype, abcd, prefix="v"): + return ", ".join(["%s %%%s%d" % (t, prefix, i) for i,t + in enumerate(make_wmma_slice_ty(abcd, itype))]) + +def gen_wmma_store_tests(): + store_template = """ +declare void @llvm.nvvm.wmma.store.$intrinsic_suffix(i8* %src, ${args}${extra_args}); + +; CHECK-LABEL: .func {{.*}}test_wmma_store_${function_suffix}( +define void @test_wmma_store_${function_suffix}(i8* %src, ${args}${extra_args}) { +; CHECK: wmma.store.${instruction_suffix} {{.*}}[%rd{{[0-9]+}}] +; CHECK: {${check_args}} +; CHECK: ${stride_pattern} + call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src, ${args} ${extra_args}); + ret void +} + +; CHECK-LABEL: .func{{.*}}test_wmma_store_${function_suffix}_o( +define void @test_wmma_store_${function_suffix}_o(i8* %src, ${args}${extra_args}) { +; CHECK: wmma.store.${instruction_suffix} {{.*}}[%rd{{[0-9]+}}+128] +; CHECK: {${check_args}} +; CHECK: ${stride_pattern} + %src1 = getelementptr i8, i8* %src, i32 128; + call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src1, ${args}${extra_args}); + ret void +} +""" + suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}" + instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}" # PTX mnemonic suffix: stride is printed as the trailing $ldm operand + + for abc, layout, space, stride, itype in product( + "d", + ["row","col"], + ["",".shared",".global"], + ["", ".stride"], + ["f16", "f32"]): + + params = { + "abc" : abc, + "layout" : layout, + "space" : space, + "stride" : stride, + "itype" : itype + } + + test_params = params + test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params) + test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_") +
test_params["instruction_suffix"] = Template(instruction_template).substitute(params) + test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype) + test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8 + if stride: + test_params["extra_args"] = ", i32 %stride"; + test_params["stride_pattern"] = ", %r{{[0-9]+}};" + else: + test_params["extra_args"] = "" + test_params["stride_pattern"] = ";" + test_params["args"] = make_wmma_slice_args(itype, "d"); + + print(Template(store_template).substitute(test_params)) + +def gen_wmma_mma_tests(): + mma_template = """ +declare ${ret_ty} @llvm.nvvm.wmma.mma.sync.$intrinsic_suffix( + ${args}); + +; CHECK-LABEL: .func {{.*}}test_wmma_mma_${function_suffix}( +define ${ret_ty} @test_wmma_mma_${function_suffix}( + ${args}) { +; CHECK: wmma.mma.sync.${intrinsic_suffix} +; CHECK: {${check_d}} +; CHECK: {${check_ab}} +; CHECK: {${check_ab}} +; CHECK: {${check_c}} + %r = call ${ret_ty} @llvm.nvvm.wmma.mma.sync.${intrinsic_suffix}( + ${args}); + ret ${ret_ty} %r; +} +""" + suffix_template = "${alayout}.${blayout}.m16n16k16.${dtype}.${ctype}${satf}" # also checked as the mnemonic text after "wmma.mma.sync." + + for alayout, blayout, ctype, dtype, satf in product( + ["row","col"], + ["row","col"], + ["f16", "f32"], + ["f16", "f32"], + [".satfinite", ""]): + + params = { + "alayout" : alayout, + "blayout" : blayout, + "ctype" : ctype, + "dtype" : dtype, + "satf" : satf + } + + test_params = params + test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params) + test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".", "_") + test_params["ret_ty"] = make_wmma_ld_ret_ty("d", dtype) + test_params["check_ab"] = check_f16_8 + test_params["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8 + test_params["check_d"] = check_f16_4 if dtype == "f16" else check_f32_8 + args = ",\n ".join(make_wmma_slice_args(t, abcd, prefix=abcd) + for abcd, t in (("a", "f16"), + ("b", "f16"), + ("c", ctype))) + test_params["args"] = args +
print(Template(mma_template).substitute(test_params)) + +def main(): + gen_wmma_load_tests() + gen_wmma_store_tests() + gen_wmma_mma_tests() + +main()