Index: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h =================================================================== --- llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h +++ llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h @@ -89,13 +89,13 @@ SDValue &Offset); bool SelectADDRri64(SDNode *OpNode, SDValue Addr, SDValue &Base, SDValue &Offset); - bool SelectADDRsi_imp(SDNode *OpNode, SDValue Addr, SDValue &Base, SDValue &Offset, MVT mvt); bool SelectADDRsi(SDNode *OpNode, SDValue Addr, SDValue &Base, SDValue &Offset); bool SelectADDRsi64(SDNode *OpNode, SDValue Addr, SDValue &Base, SDValue &Offset); + bool SelectADDRvar(SDNode *OpNode, SDValue Addr, SDValue &Value); bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const; Index: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -496,318 +496,8 @@ SelectCode(N); } -// Each instruction has four addressing variants. WMMA_VARIANTS() macro below -// constructs an array indexed by WmmaVariant which getWmmaLdVariant() uses to -// look up the intrinsic ID of particular variant. -enum WmmaVariant { - WMMA_VARIANT_ARI64, - WMMA_VARIANT_ARI64_STRIDE, - WMMA_VARIANT_AVAR, - WMMA_VARIANT_AVAR_STRIDE, -}; - -// clang-format off -#define WMMA_VARIANTS(base) \ - {{ base##_ari64, base##_ari64_stride, base##_avar, base##_avar_stride }} -// clang-format on - -static unsigned getWmmaLdVariant(WmmaVariant Variant, bool Stride, - const std::array Variants) { - if (Stride) { - if (Variant == WMMA_VARIANT_ARI64) - Variant = WMMA_VARIANT_ARI64_STRIDE; - else if (Variant == WMMA_VARIANT_AVAR) - Variant = WMMA_VARIANT_AVAR_STRIDE; - } - return Variants[Variant]; -} - -static Optional -getWmmaLdStOpcode(unsigned IntrinsicID, - WmmaVariant Variant = WMMA_VARIANT_ARI64) { - switch (IntrinsicID) { - default: - return None; - // - // WMMA_LOAD_A f16 - // - case Intrinsic::nvvm_wmma_load_a_f16_col: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col)); - case Intrinsic::nvvm_wmma_load_a_f16_row: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row)); - case Intrinsic::nvvm_wmma_load_a_f16_col_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col)); - case Intrinsic::nvvm_wmma_load_a_f16_row_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row)); - case Intrinsic::nvvm_wmma_load_a_f16_col_shared: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared)); - case Intrinsic::nvvm_wmma_load_a_f16_row_shared: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared)); - case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared)); - case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared)); - case Intrinsic::nvvm_wmma_load_a_f16_col_global: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global)); - case Intrinsic::nvvm_wmma_load_a_f16_row_global: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global)); - case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global)); - case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global)); - - // - // WMMA_LOAD_B f16 - // - case Intrinsic::nvvm_wmma_load_b_f16_col: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col)); - case Intrinsic::nvvm_wmma_load_b_f16_row: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row)); - case Intrinsic::nvvm_wmma_load_b_f16_col_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col)); - case Intrinsic::nvvm_wmma_load_b_f16_row_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row)); - case Intrinsic::nvvm_wmma_load_b_f16_col_shared: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared)); - case Intrinsic::nvvm_wmma_load_b_f16_row_shared: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared)); - case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared)); - case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared)); - case Intrinsic::nvvm_wmma_load_b_f16_col_global: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global)); - case Intrinsic::nvvm_wmma_load_b_f16_row_global: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global)); - case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global)); - case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global)); - - // - // WMMA_LOAD_C f16 - // - case Intrinsic::nvvm_wmma_load_c_f16_col: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col)); - case Intrinsic::nvvm_wmma_load_c_f16_row: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row)); - case Intrinsic::nvvm_wmma_load_c_f16_col_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col)); - case Intrinsic::nvvm_wmma_load_c_f16_row_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row)); - case Intrinsic::nvvm_wmma_load_c_f16_col_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared)); - case Intrinsic::nvvm_wmma_load_c_f16_row_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared)); - case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared)); - case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared)); - case Intrinsic::nvvm_wmma_load_c_f16_col_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global)); - case Intrinsic::nvvm_wmma_load_c_f16_row_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global)); - case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global)); - case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global)); - - // - // WMMA_LOAD_C f32 - // - case Intrinsic::nvvm_wmma_load_c_f32_col: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col)); - case Intrinsic::nvvm_wmma_load_c_f32_row: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row)); - case Intrinsic::nvvm_wmma_load_c_f32_col_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col)); - case Intrinsic::nvvm_wmma_load_c_f32_row_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row)); - case Intrinsic::nvvm_wmma_load_c_f32_col_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared)); - case Intrinsic::nvvm_wmma_load_c_f32_row_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared)); - case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared)); - case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared)); - case Intrinsic::nvvm_wmma_load_c_f32_col_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global)); - case Intrinsic::nvvm_wmma_load_c_f32_row_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global)); - case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global)); - case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global)); - - // - // WMMA_STORE_D f16 - // - case Intrinsic::nvvm_wmma_store_d_f16_col: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col)); - case Intrinsic::nvvm_wmma_store_d_f16_row: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row)); - case Intrinsic::nvvm_wmma_store_d_f16_col_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col)); - case Intrinsic::nvvm_wmma_store_d_f16_row_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row)); - case Intrinsic::nvvm_wmma_store_d_f16_col_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared)); - case Intrinsic::nvvm_wmma_store_d_f16_row_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared)); - case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared)); - case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared)); - case Intrinsic::nvvm_wmma_store_d_f16_col_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global)); - case Intrinsic::nvvm_wmma_store_d_f16_row_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global)); - case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global)); - case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global)); - - // - // WMMA_STORE_D f32 - // - case Intrinsic::nvvm_wmma_store_d_f32_col: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col)); - case Intrinsic::nvvm_wmma_store_d_f32_row: - return getWmmaLdVariant(Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row)); - case Intrinsic::nvvm_wmma_store_d_f32_col_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col)); - case Intrinsic::nvvm_wmma_store_d_f32_row_stride: - return getWmmaLdVariant(Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row)); - case Intrinsic::nvvm_wmma_store_d_f32_col_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared)); - case Intrinsic::nvvm_wmma_store_d_f32_row_shared: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared)); - case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared)); - case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared)); - case Intrinsic::nvvm_wmma_store_d_f32_col_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global)); - case Intrinsic::nvvm_wmma_store_d_f32_row_global: - return getWmmaLdVariant( - Variant, /*Stride=*/false, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global)); - case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global)); - case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: - return getWmmaLdVariant( - Variant, /*Stride=*/true, - WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global)); - } -} -#undef WMMA_VARIANTS - bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) { unsigned IID = cast(N->getOperand(1))->getZExtValue(); - if (getWmmaLdStOpcode(IID)) - return tryWMMA_LDST(N); - switch (IID) { default: return false; @@ -1026,39 +716,6 @@ case Intrinsic::nvvm_texsurf_handle_internal: SelectTexSurfHandle(N); return true; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16: - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32: - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16: - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32: - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32: - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32: - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32: - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite: - return tryWMMA_MMA(N); } } @@ -3946,6 +3603,12 @@ return SelectADDRri_imp(OpNode, Addr, Base, Offset, MVT::i64); } +// symbol +bool NVPTXDAGToDAGISel::SelectADDRvar(SDNode *OpNode, SDValue Addr, + SDValue &Value) { + return SelectDirectAddr(Addr, Value); +} + bool NVPTXDAGToDAGISel::ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const { const Value *Src = nullptr; @@ -4038,172 +3701,3 @@ } } } - -bool NVPTXDAGToDAGISel::tryWMMA_LDST(SDNode *N) { - SDValue Chain = N->getOperand(0); - unsigned IID = cast(N->getOperand(1))->getZExtValue(); - SDValue Op1 = N->getOperand(2); - SDValue Addr, Offset, Base; - Optional Opcode; - SDLoc DL(N); - MemSDNode *MemSD = cast(N); - WmmaVariant Variant; - SmallVector Ops; - bool isStore = N->getNumValues() == 1; // Store ops only return a chain. - - if (SelectDirectAddr(Op1, Addr)) { - Variant = WMMA_VARIANT_AVAR; - Ops.push_back(Addr); - } else if (SelectADDRsi64(Op1.getNode(), Op1, Base, Offset) || - SelectADDRri64(Op1.getNode(), Op1, Base, Offset)) { - Variant = WMMA_VARIANT_ARI64; - Ops.push_back(Base); - Ops.push_back(Offset); - } else { - Variant = WMMA_VARIANT_AVAR; - Ops.push_back(Op1); - } - unsigned NumOps = N->getNumOperands(); - // Pass through the rest of the operands to the machine node. - for (unsigned i = 3; i < NumOps; ++i) - Ops.push_back(N->getOperand(i)); - Ops.push_back(Chain); - - Opcode = getWmmaLdStOpcode(IID, Variant); - if (!Opcode) { - llvm::errs() << "tryWMMALD - no Opcode.\n"; - return false; - } - - EVT MemVT = MemSD->getMemoryVT(); - assert(MemVT.isVector() && "Expected vector return type."); - - SDNode *MN; - if (isStore) { - MN = CurDAG->getMachineNode(Opcode.getValue(), DL, MVT::Other, Ops); - } else { - SmallVector InstVTs(MemVT.getVectorNumElements(), - MemSD->getValueType(0)); - InstVTs.push_back(MVT::Other); - MN = CurDAG->getMachineNode(Opcode.getValue(), DL, InstVTs, Ops); - } - - ReplaceNode(N, MN); - return true; -} - -bool NVPTXDAGToDAGISel::tryWMMA_MMA(SDNode *N) { - unsigned IID = cast(N->getOperand(0))->getZExtValue(); - SDLoc DL(N); - unsigned Opc; - - switch (IID) { - default: - return false; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16: - Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32: - Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16: - Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32: - Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16: - Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32: - Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16: - Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32: - Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16: - Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32: - Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16: - Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32: - Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16: - Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32: - Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16: - Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16_satfinite; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32: - Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32; - break; - case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite: - Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32_satfinite; - break; - } - - SmallVector Ops; - // Pass through operands and return value types to the machine node. - for (unsigned i = 1; i < N->getNumOperands(); ++i) - Ops.push_back(N->getOperand(i)); - SmallVector InstVTs(N->getNumValues(), N->getValueType(0)); - SDNode *MN = CurDAG->getMachineNode(Opc, DL, InstVTs, Ops); - ReplaceNode(N, MN); - return true; -} Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -3410,7 +3410,7 @@ case Intrinsic::nvvm_wmma_store_d_f16_row_global: case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride: case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: { - Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.opc = ISD::INTRINSIC_VOID; Info.memVT = MVT::v4f16; Info.ptrVal = I.getArgOperand(0); Info.offset = 0; @@ -3431,7 +3431,7 @@ case Intrinsic::nvvm_wmma_store_d_f32_row_global: case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride: case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: { - Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.opc = ISD::INTRINSIC_VOID; Info.memVT = MVT::v8f32; Info.ptrVal = I.getArgOperand(0); Info.offset = 0; Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td =================================================================== --- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -1527,6 +1527,7 @@ [SDNPWantRoot]>; def ADDRri64 : ComplexPattern; +def ADDRvar : ComplexPattern; def MEMri : Operand { let PrintMethod = "printMemOperand"; Index: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td =================================================================== --- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -7372,44 +7372,73 @@ // // wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] // + +class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>; + class WMMA_LOAD_ALSTOS - : NVPTXInst + : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> { + // Intrinsic that matches this instruction. + Intrinsic Intr = !cast("int_nvvm_wmma_load_" + # Abc + # "_" # Type + # "_" # Layout + # !subst(".","_",Space) + # !if(WithStride,"_stride", "")); + dag OutsR03 = (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3); + dag OutsR47 = (outs regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7); + dag Outs = !if(!eq(Abc#Type,"cf16"), OutsR03, !con(OutsR03, OutsR47)); + + dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins)); + dag Ins = !con((ins SrcOp:$src), StrideArg); + + // Build a dag pattern that matches the intrinsic call. + // We want a dag that looks like this: + // (set , (intrinsic )) where input and + // output arguments are named patterns that would match corresponding + // input/output arguments of the instruction. + // + // First we construct (set ) from instruction's outs dag by + // replacing dag operator 'outs' with 'set'. + dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp)); + // Similarly, construct (intrinsic ) sub-dag from + // instruction's input arguments, only now we also need to replace operands + // with patterns that would match them and the operator 'ins' with the + // intrinsic. + dag PatArgs = !foreach(tmp, Ins, + !subst(imem, ADDRvar, + !subst(MEMri64, ADDRri64, + !subst(MEMri, ADDRri, + !subst(ins, Intr, tmp))))); + // Finally, consatenate both parts together. !con() requires both dags to have + // the same operator, so we wrap PatArgs in a (set ...) dag. + let Pattern = [!con(PatOuts, (set PatArgs))]; + let OutOperandList = Outs; + let InOperandList = Ins; + let AsmString = "wmma.load."#Abc#".sync."#Layout#".m16n16k16"#Space#"." #Type# " \t" #!if(!eq(Abc#Type,"cf16"), "{{$r0, $r1, $r2, $r3}}", "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}") - #", " - #!if(WithOffset,"[$src+$offset]", "[$src]") + #", [$src]" #!if(WithStride, ", $ldm", "") - #";", - []>, - Requires<[hasPTX60, hasSM70]>; + #";"; +} multiclass WMMA_LOAD_ALSTO { - def _stride: WMMA_LOAD_ALSTOS; - def NAME: WMMA_LOAD_ALSTOS; + DAGOperand SrcOp> { + def _stride: WMMA_LOAD_ALSTOS; + def NAME: WMMA_LOAD_ALSTOS; } multiclass WMMA_LOAD_ALST { - defm _avar: WMMA_LOAD_ALSTO; - defm _ari64: WMMA_LOAD_ALSTO; + defm _avar: WMMA_LOAD_ALSTO; + defm _areg: WMMA_LOAD_ALSTO; + defm _areg64: WMMA_LOAD_ALSTO; + defm _ari: WMMA_LOAD_ALSTO; + defm _ari64: WMMA_LOAD_ALSTO; } multiclass WMMA_LOAD_ALT - : NVPTXInst<(outs), - !if(!eq(Type,"f16"), - !if(WithStride, - !if(WithOffset, - (ins DstOp:$src, i32imm:$offset, - regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, - Int32Regs:$ldm), - (ins DstOp:$src, - regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, - Int32Regs:$ldm)), - !if(WithOffset, - (ins DstOp:$src, i32imm:$offset, - regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3), - (ins DstOp:$src, - regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3))), - !if(WithStride, - !if(WithOffset, - (ins DstOp:$src, i32imm:$offset, - regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, - regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7, - Int32Regs:$ldm), - (ins DstOp:$src, - regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, - regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7, - Int32Regs:$ldm)), - !if(WithOffset, - (ins DstOp:$src, i32imm:$offset, - regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, - regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7), - (ins DstOp:$src, - regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, - regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)))), - "wmma.store.d.sync."#Layout#".m16n16k16"#Space#"." #Type# " \t" - #!if(WithOffset,"[$src+$offset], ", "[$src], ") - #!if(!eq(Type,"f16"), - "{{$r0, $r1, $r2, $r3}}", - "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}") - #!if(WithStride, ", $ldm", "") - #";", - []>, - Requires<[hasPTX60, hasSM70]>; + DAGOperand DstOp, bit WithStride> + : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> { + Intrinsic Intr = !cast("int_nvvm_wmma_store_d_" + # Type + # "_" # Layout + # !subst(".","_",Space) + # !if(WithStride,"_stride", "")); + + dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3); + dag InsR47 = (ins regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7); + dag InsR = !if(!eq(Type,"f16"), InsR03, !con(InsR03, InsR47)); + dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins)); + dag Ins = !con(InsR, StrideArg); + + // Construct the pattern to match corresponding intrinsic call. See the + // details in the comments in WMMA_LOAD_ALSTOS. + dag PatArgs = !foreach(tmp, Ins, + !subst(imem, ADDRvar, + !subst(MEMri64, ADDRri64, + !subst(MEMri, ADDRri, + !subst(ins, Intr, tmp))))); + let Pattern = [PatArgs]; + let OutOperandList = (outs); + let InOperandList = Ins; + let AsmString = "wmma.store.d.sync." + # Layout + # ".m16n16k16" + # Space + # "." # Type + # " \t[$src]," + # !if(!eq(Type,"f16"), + "{{$r0, $r1, $r2, $r3}}", + "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}") + # !if(WithStride, ", $ldm", "") + # ";"; + +} multiclass WMMA_STORE_D_LSTO { - def _stride: WMMA_STORE_D_LSTOS; - def NAME: WMMA_STORE_D_LSTOS; + DAGOperand DstOp> { + def _stride: WMMA_STORE_D_LSTOS; + def NAME: WMMA_STORE_D_LSTOS; } multiclass WMMA_STORE_D_LST { - defm _avar: WMMA_STORE_D_LSTO; - defm _ari64: WMMA_STORE_D_LSTO; + defm _avar: WMMA_STORE_D_LSTO; + defm _areg: WMMA_STORE_D_LSTO; + defm _areg64: WMMA_STORE_D_LSTO; + defm _ari: WMMA_STORE_D_LSTO; + defm _ari64: WMMA_STORE_D_LSTO; } multiclass WMMA_STORE_D_LT { - defm _row: WMMA_STORE_D_LT<"row", Type, regclass>; - defm _col: WMMA_STORE_D_LT<"col", Type, regclass>; + defm _row: WMMA_STORE_D_LT<"row", Type, regclass>; + defm _col: WMMA_STORE_D_LT<"col", Type, regclass>; } defm INT_WMMA_STORE_D_f16: WMMA_STORE_D_T<"f16", Float16x2Regs>; @@ -7513,35 +7538,50 @@ string CType, NVPTXRegClass c_reg, NVPTXRegClass ab_reg, string Satfinite = ""> - : NVPTXInst, - Requires<[hasPTX60, hasSM70]>; + : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> { + Intrinsic Intr = !cast("int_nvvm_wmma_mma_sync_" + # ALayout + # "_" # BLayout + # "_" # DType + # "_" # CType + # !subst(".","_",Satfinite)); + dag Outs = !if(!eq(DType,"f16"), + (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3), + (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3, + d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7)); + dag InsExtraCArgs = !if(!eq(CType,"f16"), + (ins), + (ins c_reg:$c4, c_reg:$c5, c_reg:$c6, c_reg:$c7)); + dag Ins = !con((ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3, + ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7, + ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3, + ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7, + c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3), + InsExtraCArgs); + + // Construct the pattern to match corresponding intrinsic call. See the + // details in the comments in WMMA_LOAD_ALSTOS. + dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp)); + dag PatArgs = !foreach(tmp, Ins, !subst(ins, Intr, tmp)); + let Pattern = [!con(PatOuts, (set PatArgs))]; + let OutOperandList = Outs; + let InOperandList = Ins; + let AsmString = "wmma.mma.sync." + # ALayout + # "." # BLayout + # ".m16n16k16" + # "." # DType + # "." # CType + # Satfinite # "\n\t\t" + # !if(!eq(DType,"f16"), + "{{$d0, $d1, $d2, $d3}}, \n\t\t", + "{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t") + # "{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t" + # "{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t" + # !if(!eq(CType,"f16"), + "{{$c0, $c1, $c2, $c3}};", + "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};"); +} multiclass WMMA_MMA_ABDC