Index: lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp =================================================================== --- lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -2062,61 +2062,134 @@ // // i32,ch = load t0, t7, undef:i64 // - // Since we load an i8 value, the matching logic above will have selected an - // LDG instruction that reads i8 and stores it in an i16 register (NVPTX does - // not expose 8-bit registers): - // - // i16,ch = INT_PTX_LDG_GLOBAL_i8areg64 t7, t0 - // - // To get the correct type in this case, truncate back to i8 and then extend - // to the original load type. + // In this case, the matching logic above will select a load for the original + // memory type (in this case, i8) and our types will not match (the node needs + // to return an i32 in this case). Our LDG/LDU nodes do not support the + // concept of sign-/zero-extension, so emulate it here by adding an explicit + // CVT instruction. Ptxas should clean up any redundancies here. + EVT OrigType = N->getValueType(0); - LoadSDNode *LDSD = dyn_cast(N); - if (LDSD && EltVT == MVT::i8 && OrigType.getScalarSizeInBits() >= 32) { + LoadSDNode *LdNode = dyn_cast(N); + + if (OrigType != EltVT && LdNode) { + // We have an extending-load. The instruction we selected operates on the + // smaller type, but the SDNode we are replacing has the larger type. We + // need to emit a CVT to make the types match. unsigned CvtOpc = 0; - switch (LDSD->getExtensionType()) { - default: - llvm_unreachable("An extension is required for i8 loads"); - break; - case ISD::SEXTLOAD: - switch (OrigType.getSimpleVT().SimpleTy) { + if (EltVT == MVT::i8) { + switch (LdNode->getExtensionType()) { default: - llvm_unreachable("Unhandled integer load type"); + llvm_unreachable("Unknown load extension"); break; - case MVT::i32: - CvtOpc = NVPTX::CVT_s32_s8; + case ISD::SEXTLOAD: + switch (OrigType.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("Unhandled integer load type"); + break; + case MVT::i16: + CvtOpc = NVPTX::CVT_s16_s8; + break; + case MVT::i32: + CvtOpc = NVPTX::CVT_s32_s8; + break; + case MVT::i64: + CvtOpc = NVPTX::CVT_s64_s8; + break; + } break; - case MVT::i64: - CvtOpc = NVPTX::CVT_s64_s8; + case ISD::EXTLOAD: + case ISD::ZEXTLOAD: + switch (OrigType.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("Unhandled integer load type"); + break; + case MVT::i16: + CvtOpc = NVPTX::CVT_u16_u8; + break; + case MVT::i32: + CvtOpc = NVPTX::CVT_u32_u8; + break; + case MVT::i64: + CvtOpc = NVPTX::CVT_u64_u8; + break; + } break; } - break; - case ISD::EXTLOAD: - case ISD::ZEXTLOAD: - switch (OrigType.getSimpleVT().SimpleTy) { + } else if (EltVT == MVT::i16) { + switch (LdNode->getExtensionType()) { default: - llvm_unreachable("Unhandled integer load type"); + llvm_unreachable("Unknown load extension"); break; - case MVT::i32: - CvtOpc = NVPTX::CVT_u32_u8; + case ISD::SEXTLOAD: + switch (OrigType.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("Unhandled integer load type"); + break; + case MVT::i32: + CvtOpc = NVPTX::CVT_s32_s16; + break; + case MVT::i64: + CvtOpc = NVPTX::CVT_s64_s16; + break; + } break; - case MVT::i64: - CvtOpc = NVPTX::CVT_u64_u8; + case ISD::EXTLOAD: + case ISD::ZEXTLOAD: + switch (OrigType.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("Unhandled integer load type"); + break; + case MVT::i32: + CvtOpc = NVPTX::CVT_u32_u16; + break; + case MVT::i64: + CvtOpc = NVPTX::CVT_u64_u16; + break; + } break; } - break; + } else if (EltVT == MVT::i32) { + switch (LdNode->getExtensionType()) { + default: + llvm_unreachable("Unknown load extension"); + break; + case ISD::SEXTLOAD: + switch (OrigType.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("Unhandled integer load type"); + break; + case MVT::i64: + CvtOpc = NVPTX::CVT_s64_s32; + break; + } + break; + case ISD::EXTLOAD: + case ISD::ZEXTLOAD: + switch (OrigType.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("Unhandled integer load type"); + break; + case MVT::i64: + CvtOpc = NVPTX::CVT_u64_u32; + break; + } + break; + } + } else { + llvm_unreachable("Extending load of invalid base type"); } - // For each output value, truncate to i8 (since the upper 8 bits are - // undefined) and then extend to the desired type. + // For each output value, apply the manual sign/zero-extension and make sure + // all users of the load go through that CVT. for (unsigned i = 0; i != NumElts; ++i) { SDValue Res(LD, i); SDValue OrigVal(N, i); SDNode *CvtNode = CurDAG->getMachineNode(CvtOpc, DL, OrigType, Res, - CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32)); + CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, + DL, MVT::i32)); ReplaceUses(OrigVal, SDValue(CvtNode, 0)); } }