Index: llvm/lib/Target/SystemZ/SystemZISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/SystemZ/SystemZISelDAGToDAG.cpp
+++ llvm/lib/Target/SystemZ/SystemZISelDAGToDAG.cpp
@@ -304,7 +304,9 @@
                            uint64_t UpperVal, uint64_t LowerVal);
 
+  // Materialize the vector constant described by VCI and replace Node with
+  // the resulting value of type VT (which must be the result type of Node).
   void loadVectorConstant(const SystemZVectorConstantInfo &VCI,
-                          SDNode *Node);
+                          SDNode *Node, EVT VT);
 
   // Try to use gather instruction Opcode to implement vector insertion N.
   bool tryGather(SDNode *N, unsigned Opcode);
@@ -1147,13 +1149,14 @@
 }
 
 void SystemZDAGToDAGISel::loadVectorConstant(
-    const SystemZVectorConstantInfo &VCI, SDNode *Node) {
+    const SystemZVectorConstantInfo &VCI, SDNode *Node, EVT VT) {
   assert((VCI.Opcode == SystemZISD::BYTE_MASK ||
           VCI.Opcode == SystemZISD::REPLICATE ||
           VCI.Opcode == SystemZISD::ROTATE_MASK) &&
          "Bad opcode!");
   assert(VCI.VecVT.getSizeInBits() == 128 && "Expected a vector type");
-  EVT VT = Node->getValueType(0);
+  // ReplaceNode requires the replacement to have the type of Node.
+  assert(Node->getValueType(0) == VT && "Unexpected result type");
   SDLoc DL(Node);
   SmallVector<SDValue, 2> Ops;
   for (unsigned OpVal : VCI.OpVals)
@@ -1166,11 +1169,25 @@
     SDValue BitCast = CurDAG->getNode(ISD::BITCAST, DL, VT, Op);
     ReplaceNode(Node, BitCast.getNode());
     SelectCode(BitCast.getNode());
-  } else { // float or double
+  } else if (VT.isFloatingPoint()) {
     unsigned SubRegIdx =
         (VT.getSizeInBits() == 32 ? SystemZ::subreg_h32 : SystemZ::subreg_h64);
     ReplaceNode(
         Node, CurDAG->getTargetExtractSubreg(SubRegIdx, DL, VT, Op).getNode());
+  } else {
+    // Integer scalar: use element 0 of the vector constant. The bitcast and
+    // extract are not selected here; the caller must select them (the
+    // ISD::STORE case of Select() does this when selecting its stores).
+    unsigned NumBytes = VT.getStoreSize();
+    assert((NumBytes == 4 || NumBytes == 8) &&
+           "Unexpected vector element size");
+    EVT VecVT = EVT::getVectorVT(*CurDAG->getContext(), VT,
+                                 SystemZ::VectorBytes / NumBytes);
+    SDValue BitCast = CurDAG->getBitcast(VecVT, Op);
+    SDValue ValueToUse =
+        CurDAG->getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, BitCast,
+                        CurDAG->getVectorIdxConstant(0, DL));
+    ReplaceNode(Node, ValueToUse.getNode());
   }
   SelectCode(Op.getNode());
 }
@@ -1634,7 +1651,7 @@
     auto *BVN = cast<BuildVectorSDNode>(Node);
     SystemZVectorConstantInfo VCI(BVN);
     if (VCI.isVectorConstantLegal(*Subtarget)) {
-      loadVectorConstant(VCI, Node);
+      loadVectorConstant(VCI, Node, Node->getValueType(0));
       return;
     }
     break;
@@ -1647,7 +1664,7 @@
     SystemZVectorConstantInfo VCI(Imm);
     bool Success = VCI.isVectorConstantLegal(*Subtarget); (void)Success;
     assert(Success && "Expected legal FP immediate");
-    loadVectorConstant(VCI, Node);
+    loadVectorConstant(VCI, Node, Node->getValueType(0));
     return;
   }
 
@@ -1663,6 +1680,41 @@
       if (tryScatter(Store, SystemZ::VSCEG))
         return;
     }
+    // Store an integer constant that is not cheap to store as a scalar (see
+    // isSupportedByScalarStore) from element 0 of a vector constant instead.
+    SDValue StoredConst = Store->getValue();
+    auto *C = dyn_cast<ConstantSDNode>(StoredConst);
+    EVT MemVT = Store->getMemoryVT();
+    if (C && !Store->isTruncatingStore() &&
+        C->getAPIntValue().getBitWidth() <= 64 &&
+        !getInstrInfo()->isSupportedByScalarStore(C->getAPIntValue(), MemVT)) {
+      // The constant node is replaced by an element extract of type MemVT, so
+      // every use must be the stored value of a store with that memory type.
+      SmallVector<StoreSDNode *, 2> Stores;
+      for (SDNode::use_iterator UI = C->use_begin(), UE = C->use_end();
+           UI != UE; ++UI) {
+        auto *ST = dyn_cast<StoreSDNode>(*UI);
+        if (!ST || UI.getOperandNo() != 1 || ST->getMemoryVT() != MemVT)
+          break;
+        Stores.push_back(ST);
+      }
+      if (Stores.size() == C->use_size()) {
+        SystemZVectorConstantInfo VCI(C->getAPIntValue());
+        if (VCI.isVectorConstantLegal(*Subtarget)) {
+          loadVectorConstant(VCI, C, MemVT);
+          // Select all stores (ideally into VSTE) before the bitcast feeding
+          // the element extract is removed.
+          for (StoreSDNode *ST : Stores)
+            SelectCode(ST);
+          // NOTE(review): relies on SelectCode having morphed Node in place so
+          // that operand 0 is the stored vector; confirm for non-VSTE cases.
+          SDValue VecOp = Node->getOperand(0);
+          if (VecOp->getOpcode() == ISD::BITCAST)
+            SelectCode(VecOp.getNode());
+          return;
+        }
+      }
+    }
     break;
   }
   }
Index: llvm/lib/Target/SystemZ/SystemZInstrInfo.h
===================================================================
--- llvm/lib/Target/SystemZ/SystemZInstrInfo.h
+++ llvm/lib/Target/SystemZ/SystemZInstrInfo.h
@@ -15,11 +15,13 @@
 
 #include "SystemZ.h"
 #include "SystemZRegisterInfo.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/ValueTypes.h"
 #include <cstdint>
 
 #define GET_INSTRINFO_HEADER
@@ -348,6 +350,10 @@
                      MachineBasicBlock::iterator MBBI,
                      unsigned Reg, uint64_t Value) const;
 
+  // Return true if the integer constant Val should be stored to memory of type
+  // MemVT with scalar store instructions rather than from a vector register.
+  bool isSupportedByScalarStore(const APInt &Val, EVT MemVT) const;
+
   // Perform target specific instruction verification.
   bool verifyInstruction(const MachineInstr &MI,
                          StringRef &ErrInfo) const override;
Index: llvm/lib/Target/SystemZ/SystemZInstrInfo.cpp
===================================================================
--- llvm/lib/Target/SystemZ/SystemZInstrInfo.cpp
+++ llvm/lib/Target/SystemZ/SystemZInstrInfo.cpp
@@ -1975,6 +1975,37 @@
     .addReg(Reg1).addImm(Value & ((uint64_t(1) << 32) - 1));
 }
 
+// EXPERIMENTAL: when set, constants that fit a single GPR immediate load are
+// also reported as supported by scalar stores.
+static cl::opt<bool>
+    KeepImmLoad("keep-imm-load", cl::init(false), cl::Hidden,
+                cl::desc("SystemZ: prefer scalar stores for constants that "
+                         "can be loaded with a single immediate instruction"));
+
+bool SystemZInstrInfo::isSupportedByScalarStore(const APInt &Val,
+                                                EVT MemVT) const {
+  // Stores of at most 2 bytes are always left as scalar stores.
+  if (MemVT.getStoreSize() <= 2)
+    return true;
+  if (Val.getBitWidth() > 64)
+    return false;
+
+  uint64_t UVal = Val.getZExtValue();
+  int64_t SVal = Val.getSExtValue();
+
+  // Small signed immediates; this also covers all-ones (-1).
+  if (isInt<16>(SVal))
+    return true;
+
+  if (KeepImmLoad &&
+      (SystemZ::isImmLL(UVal) || SystemZ::isImmLH(UVal) ||
+       SystemZ::isImmHL(UVal) || SystemZ::isImmHH(UVal) ||
+       SystemZ::isImmLF(UVal) || SystemZ::isImmHF(UVal) || isInt<32>(SVal)))
+    return true;
+
+  return false;
+}
+
 bool SystemZInstrInfo::verifyInstruction(const MachineInstr &MI,
                                          StringRef &ErrInfo) const {
   const MCInstrDesc &MCID = MI.getDesc();