diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -879,6 +879,24 @@
       getSplatVector(VT, DL, Op) : getSplatBuildVector(VT, DL, Op);
   }
 
+  /// Returns an exported splat source from another block. This helps the
+  /// WebAssembly target lowering for vector shift operation where i32 is used
+  /// as shift amount value type.
+  const Value *getExportedSplatSource(const SDNode *N, Register &Reg) const {
+    auto I = ExportedSplatValueMap.find(N);
+    if (I != ExportedSplatValueMap.end()) {
+      Reg = I->second.second;
+      return I->second.first;
+    }
+    return nullptr;
+  }
+
+  /// Set exported splat source mapping to the splat node.
+  void addExportedSplatSource(const SDNode *N, const Value *V,
+                              const unsigned Reg) {
+    ExportedSplatValueMap[N] = {V, Reg};
+  }
+
   /// Returns a vector of type ResVT whose elements contain the linear sequence
   ///   <0, Step, Step * 2, Step * 3, ...>
   SDValue getStepVector(const SDLoc &DL, EVT ResVT, APInt StepVal);
@@ -2438,6 +2456,8 @@
   std::map<std::pair<std::string, unsigned char>, SDNode *>
       TargetExternalSymbols;
   DenseMap<MCSymbol *, SDNode *> MCSymbols;
+  DenseMap<const SDNode *, std::pair<const Value *, Register>>
+      ExportedSplatValueMap;
 
   FlagInserter *Inserter = nullptr;
 };
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1120,6 +1120,13 @@
     return false;
   }
 
+  virtual bool isShiftAmountScalar() const { return false; }
+
+  virtual bool hasSplatValueUseForVectorOp(const Instruction *I = nullptr,
+                                           const Value *Splat = nullptr) const {
+    return false;
+  }
+
   /// Targets can use this to indicate that they only support *some*
   /// VECTOR_SHUFFLE operations, those with specific masks.  By default, if a
   /// target supports the VECTOR_SHUFFLE node, all mask values are assumed to be
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1402,6 +1402,7 @@
   ExternalSymbols.clear();
   TargetExternalSymbols.clear();
   MCSymbols.clear();
+  ExportedSplatValueMap.clear();
   SDEI.clear();
   std::fill(CondCodeNodes.begin(), CondCodeNodes.end(),
             static_cast<CondCodeSDNode *>(nullptr));
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
@@ -414,6 +414,7 @@
   bool ShouldEmitAsBranches(const std::vector<CaseBlock> &Cases);
   bool isExportableFromCurrentBlock(const Value *V, const BasicBlock *FromBB);
   void CopyToExportRegsIfNeeded(const Value *V);
+  Register GetExportReg(const Value *V);
   void ExportFromCurrentBlock(const Value *V);
   void LowerCallTo(const CallBase &CB, SDValue Callee, bool IsTailCall,
                    bool IsMustTailCall, const BasicBlock *EHPadBB = nullptr);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -826,7 +826,6 @@
                            const DataLayout &DL, unsigned Reg, Type *Ty,
                            std::optional<CallingConv::ID> CC) {
   ComputeValueVTs(TLI, DL, Ty, ValueVTs);
-
   CallConv = CC;
 
   for (EVT ValueVT : ValueVTs) {
@@ -2170,16 +2169,20 @@
 /// created for it, emit nodes to copy the value into the virtual
 /// registers.
 void SelectionDAGBuilder::CopyToExportRegsIfNeeded(const Value *V) {
+  if (Register reg = GetExportReg(V)) {
+    assert((!V->use_empty() || isa<CallBrInst>(V)) &&
+           "Unused value assigned virtual registers!");
+    CopyValueToVirtualRegister(V, reg);
+  }
+}
+
+Register SelectionDAGBuilder::GetExportReg(const Value *V) {
   // Skip empty types
   if (V->getType()->isEmptyTy())
-    return;
+    return Register();
 
   DenseMap<const Value *, Register>::iterator VMI = FuncInfo.ValueMap.find(V);
-  if (VMI != FuncInfo.ValueMap.end()) {
-    assert((!V->use_empty() || isa<CallBrInst>(V)) &&
-           "Unused value assigned virtual registers!");
-    CopyValueToVirtualRegister(V, VMI->second);
-  }
+  return (VMI != FuncInfo.ValueMap.end()) ? VMI->second : Register();
 }
 
 /// ExportFromCurrentBlock - If this condition isn't known to be exported from
@@ -3272,8 +3275,23 @@
   SDValue Op1 = getValue(I.getOperand(0));
   SDValue Op2 = getValue(I.getOperand(1));
-  EVT ShiftTy = DAG.getTargetLoweringInfo().getShiftAmountTy(
-      Op1.getValueType(), DAG.getDataLayout());
+  auto &TLI = DAG.getTargetLoweringInfo();
+  Value *V2 = I.getOperand(1);
+  if (I.getType()->isVectorTy() && TLI.isShiftAmountScalar() &&
+      !(NodeMap[V2].getNode())) {
+    const Value *splat = getSplatValue(V2);
+    if (splat && !isa<Constant>(splat)) {
+      assert(FuncInfo.isExportedInst(splat) && "Splat value is not exported");
+      // TODO: It's possible to be mapped to multiple splat vectors.
+      DenseMap<const Value *, Register>::iterator It =
+          FuncInfo.ValueMap.find(splat);
+      if (It != FuncInfo.ValueMap.end()) {
+        DAG.addExportedSplatSource(Op2.getNode(), splat, It->second);
+      }
+    }
+  }
+
+  EVT ShiftTy = TLI.getShiftAmountTy(Op1.getValueType(), DAG.getDataLayout());
 
   // Coerce the shift amount to the right type if we can. This exposes the
   // truncate or zext to optimization early.
@@ -3703,6 +3721,17 @@
   unsigned MaskNumElts = Mask.size();
 
   if (SrcNumElts == MaskNumElts) {
+    if (TLI.hasSplatValueUseForVectorOp() && GetExportReg(&I)) {
+      if (const Value *splat = getSplatValue(&I)) {
+        for (const Use &U : I.uses()) {
+          Instruction *UserI = cast<Instruction>(U.getUser());
+          if (UserI->getType()->isVectorTy() &&
+              TLI.hasSplatValueUseForVectorOp(UserI, &I)) {
+            ExportFromCurrentBlock(splat);
+          }
+        }
+      }
+    }
     setValue(&I, DAG.getVectorShuffle(VT, DL, Src1, Src2, Mask));
     return;
   }
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
@@ -76,6 +76,9 @@
   bool isIntDivCheap(EVT VT, AttributeList Attr) const override;
   bool isVectorLoadExtDesirable(SDValue ExtVal) const override;
   bool isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const override;
+  bool isShiftAmountScalar() const override;
+  bool hasSplatValueUseForVectorOp(const Instruction *I = nullptr,
+                                   const Value *Splat = nullptr) const override;
   EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Context,
                          EVT VT) const override;
   bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I,
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -833,6 +833,17 @@
   return isa<Function>(GV) ? false : TargetLowering::isOffsetFoldingLegal(GA);
 }
 
+bool WebAssemblyTargetLowering::isShiftAmountScalar() const { return true; }
+
+bool WebAssemblyTargetLowering::hasSplatValueUseForVectorOp(
+    const Instruction *I, const Value *Splat) const {
+  if (!I) {
+    return isShiftAmountScalar();
+  }
+
+  return I->isShift() && isShiftAmountScalar() && I->getOperand(1) == Splat;
+}
+
 EVT WebAssemblyTargetLowering::getSetCCResultType(const DataLayout &DL,
                                                   LLVMContext &C,
                                                   EVT VT) const {
@@ -2383,9 +2394,25 @@
   // Skip vector and operation
   ShiftVal = SkipImpliedMask(ShiftVal, LaneBits - 1);
-  ShiftVal = DAG.getSplatValue(ShiftVal);
-  if (!ShiftVal)
-    return unrollVectorShift(Op, DAG);
+  if (ShiftVal.getValueType().isVector()) {
+    auto SavedShiftVal = ShiftVal;
+    ShiftVal = DAG.getSplatValue(ShiftVal);
+    if (!ShiftVal) {
+      Register InReg;
+      if (auto splat =
+              DAG.getExportedSplatSource(SavedShiftVal.getNode(), InReg)) {
+        EVT RegisterVT = getRegisterType(
+            splat->getContext(),
+            getValueType(DAG.getDataLayout(), splat->getType()));
+        ShiftVal =
+            DAG.getCopyFromReg(DAG.getEntryNode(), DL, InReg, RegisterVT);
+      }
+
+      if (!ShiftVal) {
+        return unrollVectorShift(Op, DAG);
+      }
+    }
+  }
 
   // Skip scalar and operation
   ShiftVal = SkipImpliedMask(ShiftVal, LaneBits - 1);
diff --git a/llvm/test/CodeGen/WebAssembly/simd-shift-in-loop.ll b/llvm/test/CodeGen/WebAssembly/simd-shift-in-loop.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-shift-in-loop.ll
@@ -0,0 +1,30 @@
+; RUN: llc < %s -asm-verbose=false -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s
+
+; Test that SIMD shifts can be lowered correctly even when shift
+; values are exported from outside blocks.
+
+target triple = "wasm32-unknown-unknown"
+
+; CHECK-LABEL: shl_loop:
+; CHECK-NEXT:  .functype shl_loop (i32, i32, i32) -> ()
+; CHECK-NOT: i8x16.splat
+; CHECK-NOT: i8x16.extract_lane_u
+; CHECK: i8x16.shl
+define void @shl_loop(ptr %a, i8 %shift, i32 %count) {
+entry:
+  %t1 = insertelement <16 x i8> undef, i8 %shift, i32 0
+  %vshift = shufflevector <16 x i8> %t1, <16 x i8> undef, <16 x i32> zeroinitializer
+  br label %body
+body:
+  %out = phi ptr [%a, %entry], [%b, %body]
+  %i = phi i32 [0, %entry], [%next, %body]
+  %v = load <16 x i8>, ptr %out, align 1
+  %r = shl <16 x i8> %v, %vshift
+  %b = getelementptr inbounds i8, ptr %out, i32 16
+  store <16 x i8> %r, ptr %b
+  %next = add i32 %i, 1
+  %i.cmp = icmp eq i32 %next, %count
+  br i1 %i.cmp, label %body, label %exit
+exit:
+  ret void
+}