diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -46,6 +46,12 @@
 STATISTIC(NumTailCalls, "Number of tail calls");
 
+static cl::opt<unsigned> ExtensionMaxWebSize(
+    DEBUG_TYPE "-ext-max-web-size", cl::Hidden,
+    cl::desc("Give the maximum size (in number of nodes) of the web of "
+             "instructions that we will consider for VW expansion"),
+    cl::init(18));
+
 static cl::opt<bool>
     AllowSplatInVW_W(DEBUG_TYPE "-form-vw-w-with-splat", cl::Hidden,
                      cl::desc("Allow the formation of VW_W operations (e.g., "
@@ -8547,9 +8553,9 @@
   /// Root of the combine.
   SDNode *Root;
   /// LHS of the TargetOpcode.
-  const NodeExtensionHelper &LHS;
+  NodeExtensionHelper LHS;
   /// RHS of the TargetOpcode.
-  const NodeExtensionHelper &RHS;
+  NodeExtensionHelper RHS;
 
   CombineResult(unsigned TargetOpcode, SDNode *Root,
                 const NodeExtensionHelper &LHS, Optional<bool> SExtLHS,
@@ -8728,31 +8734,83 @@
   assert(NodeExtensionHelper::isSupportedRoot(N) &&
          "Shouldn't have called this method");
+  SmallVector<SDNode *> Worklist;
+  SmallSet<SDNode *, 8> Inserted;
+  Worklist.push_back(N);
+  Inserted.insert(N);
+  SmallVector<CombineResult> CombinesToApply;
+
+  while (!Worklist.empty()) {
+    SDNode *Root = Worklist.pop_back_val();
+    if (!NodeExtensionHelper::isSupportedRoot(Root))
+      return SDValue();
 
-  NodeExtensionHelper LHS(N, 0, DAG);
-  NodeExtensionHelper RHS(N, 1, DAG);
-
-  if (LHS.needToPromoteOtherUsers() && !LHS.OrigOperand.hasOneUse())
-    return SDValue();
-
-  if (RHS.needToPromoteOtherUsers() && !RHS.OrigOperand.hasOneUse())
-    return SDValue();
+    NodeExtensionHelper LHS(N, 0, DAG);
+    NodeExtensionHelper RHS(N, 1, DAG);
+    auto AppendUsersIfNeeded = [&Worklist,
+                                &Inserted](const NodeExtensionHelper &Op) {
+      if (Op.needToPromoteOtherUsers()) {
+        for (SDNode *TheUse : Op.OrigOperand->uses()) {
+          if (Inserted.insert(TheUse).second)
+            Worklist.push_back(TheUse);
+        }
+      }
+    };
+    AppendUsersIfNeeded(LHS);
+    AppendUsersIfNeeded(RHS);
 
-  SmallVector<NodeExtensionHelper::CombineToTry> FoldingStrategies =
-      NodeExtensionHelper::getSupportedFoldings(N);
+    // Control the compile time by limiting the number of nodes we look at in
+    // total.
+    if (Inserted.size() > ExtensionMaxWebSize)
+      return SDValue();
 
-  assert(!FoldingStrategies.empty() && "Nothing to be folded");
-  for (int Attempt = 0; Attempt != 1 + NodeExtensionHelper::isCommutative(N);
-       ++Attempt) {
-    for (NodeExtensionHelper::CombineToTry FoldingStrategy :
-         FoldingStrategies) {
-      Optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
-      if (Res)
-        return Res->materialize(DAG);
+    SmallVector<NodeExtensionHelper::CombineToTry> FoldingStrategies =
+        NodeExtensionHelper::getSupportedFoldings(N);
+
+    assert(!FoldingStrategies.empty() && "Nothing to be folded");
+    bool Matched = false;
+    for (int Attempt = 0;
+         (Attempt != 1 + NodeExtensionHelper::isCommutative(N)) && !Matched;
+         ++Attempt) {
+
+      for (NodeExtensionHelper::CombineToTry FoldingStrategy :
+           FoldingStrategies) {
+        Optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
+        if (Res) {
+          Matched = true;
+          CombinesToApply.push_back(*Res);
+          break;
+        }
+      }
+      std::swap(LHS, RHS);
     }
-    std::swap(LHS, RHS);
+    // Right now we take an all-or-nothing approach.
+    if (!Matched)
+      return SDValue();
   }
-  return SDValue();
+  // Store the value for the replacement of the input node separately.
+  SDValue InputRootReplacement;
+  // We do the RAUW after we materialize all the combines, because some replaced
+  // nodes may be feeding some of the yet-to-be-replaced nodes. Put differently,
+  // some of these nodes may appear in the NodeExtensionHelpers of some of the
+  // yet-to-be-visited CombinesToApply roots.
+  SmallVector<std::pair<SDValue, SDValue>> ValuesToReplace;
+  ValuesToReplace.reserve(CombinesToApply.size());
+  for (CombineResult Res : CombinesToApply) {
+    SDValue NewValue = Res.materialize(DAG);
+    if (!InputRootReplacement) {
+      assert(Res.Root == N &&
+             "First element is expected to be the current node");
+      InputRootReplacement = NewValue;
+    } else {
+      ValuesToReplace.emplace_back(SDValue(Res.Root, 0), NewValue);
+    }
+  }
+  for (std::pair<SDValue, SDValue> OldNewValues : ValuesToReplace) {
+    DAG.ReplaceAllUsesOfValueWith(OldNewValues.first, OldNewValues.second);
+    DCI.AddToWorklist(OldNewValues.second.getNode());
+  }
+  return InputRootReplacement;
 }
 
 // Fold
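Editor's aside, for readers skimming the patch: the hunk above replaces the old single-node fold with a worklist that walks the entire "web" of users before committing to anything. Below is a minimal standalone sketch of that pattern, not LLVM code. Node, Matches, and canRewriteWeb are hypothetical stand-ins for SDNode, the folding strategies, and the combine above; the node counting is also simplified relative to the patch.

// Toy model: gather the transitive web of users from a root, give up past a
// size cap, and fold only if every node in the web matches.
#include <cstdio>
#include <set>
#include <vector>

struct Node {
  bool Matches;              // Stand-in for "some folding strategy applies".
  std::vector<Node *> Users; // Stand-in for SDNode::uses().
};

// Mirrors the all-or-nothing policy: true only when the whole web rooted at
// Root can be rewritten. MaxWebSize plays the role of ExtensionMaxWebSize.
static bool canRewriteWeb(Node *Root, unsigned MaxWebSize) {
  std::vector<Node *> Worklist{Root};
  std::set<Node *> Inserted{Root};
  while (!Worklist.empty()) {
    Node *N = Worklist.back();
    Worklist.pop_back();
    // Control compile time by bounding the number of nodes considered.
    if (Inserted.size() > MaxWebSize)
      return false;
    // All or nothing: one non-matching node poisons the entire web.
    if (!N->Matches)
      return false;
    for (Node *User : N->Users)
      if (Inserted.insert(User).second)
        Worklist.push_back(User);
  }
  return true;
}

int main() {
  // One extension feeding three users (mul/add/sub), like %c in the test
  // below; with this simplified counting the web holds four nodes.
  Node Mul{true, {}}, Add{true, {}}, Sub{true, {}};
  Node Ext{true, {&Mul, &Add, &Sub}};
  std::printf("cap 3: %d\n", canRewriteWeb(&Ext, 3)); // 0: web too large
  std::printf("cap 4: %d\n", canRewriteWeb(&Ext, 4)); // 1: whole web folds
}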
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll
@@ -0,0 +1,60 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
+; Check that the default value enables the web folding and
+; that it is bigger than 3.
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+
+
+; Check that the add/sub/mul operations are all promoted into their
+; vw counterparts when the web size limit is increased to 3.
+; We need the web size to be at least 3 for the folding to happen, because
+; %c has 3 uses.
+define <2 x i16> @vwmul_v2i16_multiple_users(<2 x i8>* %x, <2 x i8>* %y, <2 x i8> *%z) {
+; NO_FOLDING-LABEL: vwmul_v2i16_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetivli zero, 2, e16, mf4, ta, mu
+; NO_FOLDING-NEXT:    vle8.v v8, (a0)
+; NO_FOLDING-NEXT:    vle8.v v9, (a1)
+; NO_FOLDING-NEXT:    vle8.v v10, (a2)
+; NO_FOLDING-NEXT:    vsext.vf2 v11, v8
+; NO_FOLDING-NEXT:    vsext.vf2 v8, v9
+; NO_FOLDING-NEXT:    vsext.vf2 v9, v10
+; NO_FOLDING-NEXT:    vmul.vv v8, v11, v8
+; NO_FOLDING-NEXT:    vadd.vv v10, v11, v9
+; NO_FOLDING-NEXT:    vsub.vv v9, v11, v9
+; NO_FOLDING-NEXT:    vor.vv v8, v8, v10
+; NO_FOLDING-NEXT:    vor.vv v8, v8, v9
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vwmul_v2i16_multiple_users:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetivli zero, 2, e8, mf8, ta, mu
+; FOLDING-NEXT:    vle8.v v8, (a0)
+; FOLDING-NEXT:    vle8.v v9, (a1)
+; FOLDING-NEXT:    vle8.v v10, (a2)
+; FOLDING-NEXT:    vwmul.vv v11, v8, v9
+; FOLDING-NEXT:    vwadd.vv v9, v8, v10
+; FOLDING-NEXT:    vwsub.vv v12, v8, v10
+; FOLDING-NEXT:    vsetvli zero, zero, e16, mf4, ta, mu
+; FOLDING-NEXT:    vor.vv v8, v11, v9
+; FOLDING-NEXT:    vor.vv v8, v8, v12
+; FOLDING-NEXT:    ret
+  %a = load <2 x i8>, <2 x i8>* %x
+  %b = load <2 x i8>, <2 x i8>* %y
+  %b2 = load <2 x i8>, <2 x i8>* %z
+  %c = sext <2 x i8> %a to <2 x i16>
+  %d = sext <2 x i8> %b to <2 x i16>
+  %d2 = sext <2 x i8> %b2 to <2 x i16>
+  %e = mul <2 x i16> %c, %d
+  %f = add <2 x i16> %c, %d2
+  %g = sub <2 x i16> %c, %d2
+  %h = or <2 x i16> %e, %f
+  %i = or <2 x i16> %h, %g
+  ret <2 x i16> %i
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
@@ -21,16 +21,14 @@
 define <2 x i16> @vwmul_v2i16_multiple_users(<2 x i8>* %x, <2 x i8>* %y, <2 x i8> *%z) {
 ; CHECK-LABEL: vwmul_v2i16_multiple_users:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, mu
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, mu
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
 ; CHECK-NEXT:    vle8.v v10, (a2)
-; CHECK-NEXT:    vsext.vf2 v11, v8
-; CHECK-NEXT:    vsext.vf2 v8, v9
-; CHECK-NEXT:    vsext.vf2 v9, v10
-; CHECK-NEXT:    vmul.vv v8, v11, v8
-; CHECK-NEXT:    vmul.vv v9, v11, v9
-; CHECK-NEXT:    vor.vv v8, v8, v9
+; CHECK-NEXT:    vwmul.vv v11, v8, v9
+; CHECK-NEXT:    vwmul.vv v9, v8, v10
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf4, ta, mu
+; CHECK-NEXT:    vor.vv v8, v11, v9
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, <2 x i8>* %x
   %b = load <2 x i8>, <2 x i8>* %y
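A closing editor's note on the replacement ordering: the comment in the C++ hunk explains that RAUW is deferred until every combine has been materialized, because an earlier replacement could invalidate nodes that a later combine's cached helpers still point at. The following is a toy sketch of that two-pass scheme, not LLVM code: Value and Combine are hypothetical stand-ins for SDValue and CombineResult, and the printf stands in for DAG.ReplaceAllUsesOfValueWith.

// Toy model: materialize first, replace second.
#include <cstdio>
#include <utility>
#include <vector>

struct Value { const char *Name; };

struct Combine {
  Value *OldValue;  // Root that the combine replaces (Res.Root in the patch).
  Value NewValue;   // The widened operation.
  Value *materialize() { return &NewValue; } // Stand-in for materialize(DAG).
};

int main() {
  Value Mul{"mul"}, Add{"add"};
  std::vector<Combine> CombinesToApply{{&Mul, {"vwmul"}}, {&Add, {"vwadd"}}};

  // Pass 1: materialize every combine while the original nodes are intact;
  // a later combine may still read nodes an earlier one will replace, so no
  // replacement is performed yet.
  std::vector<std::pair<Value *, Value *>> ValuesToReplace;
  ValuesToReplace.reserve(CombinesToApply.size());
  for (Combine &C : CombinesToApply)
    ValuesToReplace.emplace_back(C.OldValue, C.materialize());

  // Pass 2: nothing reads the old nodes anymore, so replace them all.
  for (auto &[OldV, NewV] : ValuesToReplace)
    std::printf("replace %s with %s\n", OldV->Name, NewV->Name);
}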