StructurizeCFG: keep DivergenceAnalysis in sync with erased terminators and
only skip a region when its subregions were also treated as uniform.

* DivergenceAnalysis.h: document that isDivergent()/isUniform() describe a
  value at its definition only, and add removeValue() so a pass that erases
  an instruction can drop the stale entry.
* StructurizeCFG.cpp: add the hidden -structurizecfg-skip-uniform-regions
  option, which overrides the constructor argument when it is given on the
  command line; remove a terminator from DivergenceAnalysis before erasing it;
  in hasOnlyUniformBranches(), query DivergenceAnalysis on the conditional
  branches of direct child blocks and require "structurizecfg.uniform"
  metadata on every conditional branch inside subregions.
* Add test/Transforms/StructurizeCFG/AMDGPU/uniform-regions.ll.

Index: llvm/trunk/include/llvm/Analysis/DivergenceAnalysis.h
===================================================================
--- llvm/trunk/include/llvm/Analysis/DivergenceAnalysis.h
+++ llvm/trunk/include/llvm/Analysis/DivergenceAnalysis.h
@@ -37,12 +37,21 @@
   // Print all divergent branches in the function.
   void print(raw_ostream &OS, const Module *) const override;
 
-  // Returns true if V is divergent.
+  // Returns true if V is divergent at its definition.
+  //
+  // Even if this function returns false, V may still be divergent when used
+  // in a different basic block.
   bool isDivergent(const Value *V) const { return DivergentValues.count(V); }
 
   // Returns true if V is uniform/non-divergent.
+  //
+  // Even if this function returns true, V may still be divergent when used
+  // in a different basic block.
   bool isUniform(const Value *V) const { return !isDivergent(V); }
 
+  // Keep the analysis results up to date by removing an erased value.
+  void removeValue(const Value *V) { DivergentValues.erase(V); }
+
 private:
   // Stores all divergent values.
   DenseSet<const Value *> DivergentValues;
Index: llvm/trunk/lib/Transforms/Scalar/StructurizeCFG.cpp
===================================================================
--- llvm/trunk/lib/Transforms/Scalar/StructurizeCFG.cpp
+++ llvm/trunk/lib/Transforms/Scalar/StructurizeCFG.cpp
@@ -56,6 +56,12 @@
 
 namespace {
 
+static cl::opt<bool> ForceSkipUniformRegions(
+  "structurizecfg-skip-uniform-regions",
+  cl::Hidden,
+  cl::desc("Force whether the StructurizeCFG pass skips uniform regions"),
+  cl::init(false));
+
 // Definition of the complex types used in this pass.
 
 using BBValuePair = std::pair<BasicBlock *, Value *>;
@@ -177,6 +183,7 @@
   Function *Func;
   Region *ParentRegion;
 
+  DivergenceAnalysis *DA;
   DominatorTree *DT;
   LoopInfo *LI;
 
@@ -243,8 +250,11 @@
 public:
   static char ID;
 
-  explicit StructurizeCFG(bool SkipUniformRegions = false)
-      : RegionPass(ID), SkipUniformRegions(SkipUniformRegions) {
+  explicit StructurizeCFG(bool SkipUniformRegions_ = false)
+      : RegionPass(ID),
+        SkipUniformRegions(SkipUniformRegions_) {
+    if (ForceSkipUniformRegions.getNumOccurrences())
+      SkipUniformRegions = ForceSkipUniformRegions.getValue();
     initializeStructurizeCFGPass(*PassRegistry::getPassRegistry());
   }
 
@@ -612,6 +622,8 @@
        SI != SE; ++SI)
     delPhiValues(BB, *SI);
 
+  if (DA)
+    DA->removeValue(Term);
   Term->eraseFromParent();
 }
 
@@ -879,16 +891,38 @@
     }
 }
 
-static bool hasOnlyUniformBranches(const Region *R,
+static bool hasOnlyUniformBranches(Region *R,
+                                   unsigned UniformMDKindID,
                                    const DivergenceAnalysis &DA) {
-  for (const BasicBlock *BB : R->blocks()) {
-    const BranchInst *Br = dyn_cast<BranchInst>(BB->getTerminator());
-    if (!Br || !Br->isConditional())
-      continue;
+  for (auto E : R->elements()) {
+    if (!E->isSubRegion()) {
+      auto Br = dyn_cast<BranchInst>(E->getEntry()->getTerminator());
+      if (!Br || !Br->isConditional())
+        continue;
 
-    if (!DA.isUniform(Br->getCondition()))
-      return false;
-    DEBUG(dbgs() << "BB: " << BB->getName() << " has uniform terminator\n");
+      if (!DA.isUniform(Br))
+        return false;
+      DEBUG(dbgs() << "BB: " << Br->getParent()->getName()
+                   << " has uniform terminator\n");
+    } else {
+      // Explicitly refuse to treat regions as uniform if they have non-uniform
+      // subregions. We cannot rely on DivergenceAnalysis for branches in
+      // subregions because those branches may have been removed and re-created,
+      // so we look for our metadata instead.
+      //
+      // Warning: It would be nice to treat regions as uniform based only on
+      // their direct child basic blocks' terminators, regardless of whether
+      // subregions are uniform or not. However, this requires a very careful
+      // look at SIAnnotateControlFlow to make sure nothing breaks there.
+      for (auto BB : E->getNodeAs<Region>()->blocks()) {
+        auto Br = dyn_cast<BranchInst>(BB->getTerminator());
+        if (!Br || !Br->isConditional())
+          continue;
+
+        if (!Br->getMetadata(UniformMDKindID))
+          return false;
+      }
+    }
   }
   return true;
 }
@@ -898,10 +932,18 @@
   if (R->isTopLevelRegion())
     return false;
 
+  DA = nullptr;
+
   if (SkipUniformRegions) {
     // TODO: We could probably be smarter here with how we handle sub-regions.
-    auto &DA = getAnalysis<DivergenceAnalysis>();
-    if (hasOnlyUniformBranches(R, DA)) {
+    // We currently rely on the fact that metadata is set by earlier invocations
+    // of the pass on sub-regions, and that this metadata doesn't get lost --
+    // but we shouldn't rely on metadata for correctness!
+    unsigned UniformMDKindID =
+        R->getEntry()->getContext().getMDKindID("structurizecfg.uniform");
+    DA = &getAnalysis<DivergenceAnalysis>();
+
+    if (hasOnlyUniformBranches(R, UniformMDKindID, *DA)) {
       DEBUG(dbgs() << "Skipping region with uniform control flow: " << *R
                    << '\n');
       // Mark all direct child block terminators as having been treated as
@@ -914,7 +956,7 @@
           continue;
 
         if (Instruction *Term = E->getEntry()->getTerminator())
-          Term->setMetadata("structurizecfg.uniform", MD);
+          Term->setMetadata(UniformMDKindID, MD);
       }
 
       return false;
Index: llvm/trunk/test/Transforms/StructurizeCFG/AMDGPU/uniform-regions.ll
===================================================================
--- llvm/trunk/test/Transforms/StructurizeCFG/AMDGPU/uniform-regions.ll
+++ llvm/trunk/test/Transforms/StructurizeCFG/AMDGPU/uniform-regions.ll
@@ -0,0 +1,82 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -mtriple=amdgcn-- -S -o - -structurizecfg -structurizecfg-skip-uniform-regions < %s | FileCheck %s
+
+define amdgpu_cs void @uniform(i32 inreg %v) {
+; CHECK-LABEL: @uniform(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CC:%.*]] = icmp eq i32 [[V:%.*]], 0
+; CHECK-NEXT:    br i1 [[CC]], label [[IF:%.*]], label [[END:%.*]], !structurizecfg.uniform !0
+; CHECK:       if:
+; CHECK-NEXT:    br label [[END]], !structurizecfg.uniform !0
+; CHECK:       end:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %cc = icmp eq i32 %v, 0
+  br i1 %cc, label %if, label %end
+
+if:
+  br label %end
+
+end:
+  ret void
+}
+
+define amdgpu_cs void @nonuniform(i32 addrspace(2)* %ptr) {
+; CHECK-LABEL: @nonuniform(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.body:
+; CHECK-NEXT:    [[I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[FLOW:%.*]] ]
+; CHECK-NEXT:    [[CC:%.*]] = icmp ult i32 [[I]], 4
+; CHECK-NEXT:    br i1 [[CC]], label [[MID_LOOP:%.*]], label [[FLOW]]
+; CHECK:       mid.loop:
+; CHECK-NEXT:    [[V:%.*]] = call i32 @llvm.amdgcn.workitem.id.x()
+; CHECK-NEXT:    [[CC2:%.*]] = icmp eq i32 [[V]], 0
+; CHECK-NEXT:    br i1 [[CC2]], label [[END_LOOP:%.*]], label [[FLOW1:%.*]]
+; CHECK:       Flow:
+; CHECK-NEXT:    [[TMP0]] = phi i32 [ [[TMP2:%.*]], [[FLOW1]] ], [ undef, [[FOR_BODY]] ]
+; CHECK-NEXT:    [[TMP1:%.*]] = phi i1 [ [[TMP3:%.*]], [[FLOW1]] ], [ true, [[FOR_BODY]] ]
+; CHECK-NEXT:    br i1 [[TMP1]], label [[FOR_END:%.*]], label [[FOR_BODY]]
+; CHECK:       end.loop:
+; CHECK-NEXT:    [[I_INC:%.*]] = add i32 [[I]], 1
+; CHECK-NEXT:    br label [[FLOW1]]
+; CHECK:       Flow1:
+; CHECK-NEXT:    [[TMP2]] = phi i32 [ [[I_INC]], [[END_LOOP]] ], [ undef, [[MID_LOOP]] ]
+; CHECK-NEXT:    [[TMP3]] = phi i1 [ false, [[END_LOOP]] ], [ true, [[MID_LOOP]] ]
+; CHECK-NEXT:    br label [[FLOW]]
+; CHECK:       for.end:
+; CHECK-NEXT:    br i1 [[CC]], label [[IF:%.*]], label [[END:%.*]]
+; CHECK:       if:
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %for.body
+
+for.body:
+  %i = phi i32 [0, %entry], [%i.inc, %end.loop]
+  %cc = icmp ult i32 %i, 4
+  br i1 %cc, label %mid.loop, label %for.end
+
+mid.loop:
+  %v = call i32 @llvm.amdgcn.workitem.id.x()
+  %cc2 = icmp eq i32 %v, 0
+  br i1 %cc2, label %end.loop, label %for.end
+
+end.loop:
+  %i.inc = add i32 %i, 1
+  br label %for.body
+
+for.end:
+  br i1 %cc, label %if, label %end
+
+if:
+  br label %end
+
+end:
+  ret void
+}
+
+declare i32 @llvm.amdgcn.workitem.id.x()