Index: lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2681,7 +2681,7 @@
   }
   case ISD::ZERO_EXTEND_VECTOR_INREG: {
     EVT InVT = Op.getOperand(0).getValueType();
-    APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
+    APInt InDemandedElts = DemandedElts.zextOrSelf(InVT.getVectorNumElements());
     computeKnownBits(Op.getOperand(0), Known, InDemandedElts, Depth + 1);
     Known = Known.zext(BitWidth);
     Known.Zero.setBitsFrom(InVT.getScalarSizeInBits());
@@ -3264,7 +3264,7 @@
   case ISD::SIGN_EXTEND_VECTOR_INREG: {
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
-    APInt DemandedSrcElts = DemandedElts.zext(SrcVT.getVectorNumElements());
+    APInt DemandedSrcElts = DemandedElts.zextOrSelf(SrcVT.getVectorNumElements());
     Tmp = VTBits - SrcVT.getScalarSizeInBits();
     return ComputeNumSignBits(Src, DemandedSrcElts, Depth+1) + Tmp;
   }
Index: lib/CodeGen/SelectionDAG/TargetLowering.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1481,22 +1481,21 @@
     break;
   }
   case ISD::EXTRACT_SUBVECTOR: {
-    if (!isa<ConstantSDNode>(Op.getOperand(1)))
-      break;
     SDValue Src = Op.getOperand(0);
+    ConstantSDNode *SubIdx = dyn_cast<ConstantSDNode>(Op.getOperand(1));
     unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
-    const APInt& Idx = cast<ConstantSDNode>(Op.getOperand(1))->getAPIntValue();
-    if (Idx.uge(NumSrcElts - NumElts))
-      break;
-    // Offset the demanded elts by the subvector index.
-    uint64_t SubIdx = Idx.getZExtValue();
-    APInt SrcElts = DemandedElts.zext(NumSrcElts).shl(SubIdx);
-    APInt SrcUndef, SrcZero;
-    if (SimplifyDemandedVectorElts(Src, SrcElts, SrcUndef, SrcZero, TLO,
-                                   Depth + 1))
-      return true;
-    KnownUndef = SrcUndef.extractBits(NumElts, SubIdx);
-    KnownZero = SrcZero.extractBits(NumElts, SubIdx);
+    // Only a constant index whose NumElts-wide window lies inside Src is used.
+    if (SubIdx && SubIdx->getAPIntValue().ule(NumSrcElts - NumElts)) {
+      // Offset the demanded elts by the subvector index.
+      uint64_t Idx = SubIdx->getZExtValue();
+      APInt SrcElts = DemandedElts.zextOrSelf(NumSrcElts).shl(Idx);
+      APInt SrcUndef, SrcZero;
+      if (SimplifyDemandedVectorElts(Src, SrcElts, SrcUndef, SrcZero, TLO,
+                                     Depth + 1))
+        return true;
+      KnownUndef = SrcUndef.extractBits(NumElts, Idx);
+      KnownZero = SrcZero.extractBits(NumElts, Idx);
+    }
     break;
   }
   case ISD::INSERT_VECTOR_ELT: {
Index: unittests/CodeGen/CMakeLists.txt
===================================================================
--- unittests/CodeGen/CMakeLists.txt
+++ unittests/CodeGen/CMakeLists.txt
@@ -1,4 +1,6 @@
 set(LLVM_LINK_COMPONENTS
+  ${LLVM_TARGETS_TO_BUILD}
+  AsmParser
   AsmPrinter
   CodeGen
   Core
@@ -15,6 +17,7 @@
   MachineInstrTest.cpp
   MachineOperandTest.cpp
   ScalableVectorMVTsTest.cpp
+  SelectionDAGTest.cpp
   )
 
 add_subdirectory(GlobalISel)
Index: unittests/CodeGen/SelectionDAGTest.cpp
===================================================================
--- /dev/null
+++ unittests/CodeGen/SelectionDAGTest.cpp
@@ -0,0 +1,173 @@
+//===- llvm/unittest/CodeGen/SelectionDAGTest.cpp -------------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/SelectionDAG.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetRegistry.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+class SelectionDAGTest : public testing::Test {
+protected:
+  // Register the targets once per test case. This lives here rather than in
+  // a custom main() so the executable shared with the other CodeGen unit
+  // tests keeps its harness-provided main().
+  static void SetUpTestCase() {
+    InitializeAllTargets();
+    InitializeAllTargetMCs();
+  }
+
+  void SetUp() override {
+    StringRef Assembly = "define void @f() { ret void }";
+
+    // If AArch64 is not a configured target, TM stays null and every test
+    // returns early.
+    Triple TargetTriple("aarch64--");
+    std::string Error;
+    const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);
+    if (!T)
+      return;
+
+    TargetOptions Options;
+    TM = std::unique_ptr<LLVMTargetMachine>(static_cast<LLVMTargetMachine *>(
+        T->createTargetMachine("AArch64", "", "", Options, None, None,
+                               CodeGenOpt::Aggressive)));
+    if (!TM)
+      return;
+
+    SMDiagnostic SMError;
+    M = parseAssemblyString(Assembly, SMError, Context);
+    if (!M)
+      report_fatal_error(SMError.getMessage());
+    M->setDataLayout(TM->createDataLayout());
+
+    F = M->getFunction("f");
+    if (!F)
+      report_fatal_error("F?");
+
+    // MF keeps a reference to the MachineModuleInfo and the DAG keeps the
+    // remark emitter it is initialised with, so both must be fixture members:
+    // SetUp() locals would dangle as soon as SetUp() returns.
+    MMI = make_unique<MachineModuleInfo>(TM.get());
+
+    MF = make_unique<MachineFunction>(*F, *TM, *TM->getSubtargetImpl(*F), 0,
+                                      *MMI);
+
+    DAG = make_unique<SelectionDAG>(*TM, CodeGenOpt::None);
+    if (!DAG)
+      report_fatal_error("DAG?");
+    ORE = make_unique<OptimizationRemarkEmitter>(F);
+    DAG->init(*MF, *ORE, nullptr, nullptr, nullptr);
+  }
+
+  // Members are destroyed in reverse declaration order, so each object is
+  // torn down before the objects it refers to.
+  LLVMContext Context;
+  std::unique_ptr<LLVMTargetMachine> TM;
+  std::unique_ptr<Module> M;
+  Function *F = nullptr;
+  std::unique_ptr<MachineModuleInfo> MMI;
+  std::unique_ptr<MachineFunction> MF;
+  std::unique_ptr<OptimizationRemarkEmitter> ORE;
+  std::unique_ptr<SelectionDAG> DAG;
+};
+
+TEST_F(SelectionDAGTest, computeKnownBits_ZERO_EXTEND_VECTOR_INREG) {
+  if (!TM)
+    return;
+  SDLoc Loc;
+  auto Int8VT = EVT::getIntegerVT(Context, 8);
+  auto Int16VT = EVT::getIntegerVT(Context, 16);
+  auto InVecVT = EVT::getVectorVT(Context, Int8VT, 4);
+  auto OutVecVT = EVT::getVectorVT(Context, Int16VT, 2);
+  auto InVec = DAG->getConstant(0, Loc, InVecVT);
+  auto Op = DAG->getZeroExtendVectorInReg(InVec, Loc, OutVecVT);
+  // One demanded-element bit per element of Op's result type (v2i16).
+  auto DemandedElts = APInt(2, 3);
+  KnownBits Known;
+  DAG->computeKnownBits(Op, Known, DemandedElts);
+  EXPECT_TRUE(Known.isZero());
+}
+
+TEST_F(SelectionDAGTest, computeKnownBits_EXTRACT_SUBVECTOR) {
+  if (!TM)
+    return;
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 3);
+  auto IdxVT = EVT::getIntegerVT(Context, 64);
+  auto Vec = DAG->getConstant(0, Loc, VecVT);
+  auto ZeroIdx = DAG->getConstant(0, Loc, IdxVT);
+  auto Op = DAG->getNode(ISD::EXTRACT_SUBVECTOR, Loc, VecVT, Vec, ZeroIdx);
+  auto DemandedElts = APInt(3, 7);
+  KnownBits Known;
+  DAG->computeKnownBits(Op, Known, DemandedElts);
+  EXPECT_TRUE(Known.isZero());
+}
+
+TEST_F(SelectionDAGTest, ComputeNumSignBits_SIGN_EXTEND_VECTOR_INREG) {
+  if (!TM)
+    return;
+  SDLoc Loc;
+  auto Int8VT = EVT::getIntegerVT(Context, 8);
+  auto Int16VT = EVT::getIntegerVT(Context, 16);
+  auto InVecVT = EVT::getVectorVT(Context, Int8VT, 4);
+  auto OutVecVT = EVT::getVectorVT(Context, Int16VT, 2);
+  auto InVec = DAG->getConstant(1, Loc, InVecVT);
+  auto Op = DAG->getSignExtendVectorInReg(InVec, Loc, OutVecVT);
+  // One demanded-element bit per element of Op's result type (v2i16).
+  auto DemandedElts = APInt(2, 3);
+  EXPECT_EQ(DAG->ComputeNumSignBits(Op, DemandedElts), 15u);
+}
+
+TEST_F(SelectionDAGTest, ComputeNumSignBits_EXTRACT_SUBVECTOR) {
+  if (!TM)
+    return;
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 3);
+  auto IdxVT = EVT::getIntegerVT(Context, 64);
+  auto Vec = DAG->getConstant(1, Loc, VecVT);
+  auto ZeroIdx = DAG->getConstant(0, Loc, IdxVT);
+  auto Op = DAG->getNode(ISD::EXTRACT_SUBVECTOR, Loc, VecVT, Vec, ZeroIdx);
+  auto DemandedElts = APInt(3, 7);
+  EXPECT_EQ(DAG->ComputeNumSignBits(Op, DemandedElts), 7u);
+}
+
+TEST_F(SelectionDAGTest, SimplifyDemandedVectorElts_EXTRACT_SUBVECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 3);
+  auto IdxVT = EVT::getIntegerVT(Context, 64);
+  auto Vec = DAG->getConstant(1, Loc, VecVT);
+  auto ZeroIdx = DAG->getConstant(0, Loc, IdxVT);
+  auto Op = DAG->getNode(ISD::EXTRACT_SUBVECTOR, Loc, VecVT, Vec, ZeroIdx);
+  auto DemandedElts = APInt(3, 7);
+  auto KnownUndef = APInt(3, 0);
+  auto KnownZero = APInt(3, 0);
+  TargetLowering::TargetLoweringOpt TLO(*DAG, false, false);
+  EXPECT_FALSE(TL.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef,
+                                             KnownZero, TLO));
+}
+
+} // end anonymous namespace