diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -1064,6 +1064,8 @@ const LocationDescription &Loc, InsertPointTy AllocaIP, ArrayRef SectionCBs, PrivatizeCallbackTy PrivCB, FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) { + assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required"); + if (!updateToLocation(Loc)) return Loc.IP; diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -4077,7 +4077,12 @@ OMPBuilder.initialize(); F->setName("func"); IRBuilder<> Builder(BB); + + BasicBlock *EnterBB = BasicBlock::Create(Ctx, "sections.enter", F); + Builder.CreateBr(EnterBB); + Builder.SetInsertPoint(EnterBB); OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + llvm::SmallVector SectionCBVector; llvm::SmallVector CaseBBs; @@ -4232,7 +4237,11 @@ F->setName("func"); IRBuilder<> Builder(BB); + BasicBlock *EnterBB = BasicBlock::Create(Ctx, "sections.enter", F); + Builder.CreateBr(EnterBB); + Builder.SetInsertPoint(EnterBB); OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(), F->getEntryBlock().getFirstInsertionPt()); llvm::SmallVector SectionCBVector; diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -72,6 +72,21 @@ return allocaInsertPoint; // Otherwise, insert to the entry block of the surrounding function. 
+ // If the builder is positioned in the function's entry block, that block + // cannot also receive the allocas without confusing their insertion order. + // Move the builder into a new block and keep the entry block for allocas. + // Callers must therefore build their LocationDescription after this call. + if (builder.GetInsertBlock() == + &builder.GetInsertBlock()->getParent()->getEntryBlock()) { + assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() && + "Assuming end of basic block"); + llvm::BasicBlock *entryBB = llvm::BasicBlock::Create( + builder.getContext(), "entry", builder.GetInsertBlock()->getParent(), + builder.GetInsertBlock()->getNextNode()); + builder.CreateBr(entryBB); + builder.SetInsertPoint(entryBB); + } + llvm::BasicBlock &funcEntryBlock = builder.GetInsertBlock()->getParent()->getEntryBlock(); return llvm::OpenMPIRBuilder::InsertPointTy( @@ -255,23 +270,12 @@ // TODO: Is the Parallel construct cancellable? bool isCancellable = false; - // Ensure that the BasicBlock for the the parallel region is sparate from the - // function entry which we may need to insert allocas.
- if (builder.GetInsertBlock() == - &builder.GetInsertBlock()->getParent()->getEntryBlock()) { - assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() && - "Assuming end of basic block"); - llvm::BasicBlock *entryBB = - llvm::BasicBlock::Create(builder.getContext(), "parallel.entry", - builder.GetInsertBlock()->getParent(), - builder.GetInsertBlock()->getNextNode()); - builder.CreateBr(entryBB); - builder.SetInsertPoint(entryBB); - } + llvm::OpenMPIRBuilder::InsertPointTy allocaIP = + findAllocaInsertPoint(builder, moduleTranslation); llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel( - ompLoc, findAllocaInsertPoint(builder, moduleTranslation), bodyGenCB, - privCB, finiCB, ifCond, numThreads, pbKind, isCancellable)); + ompLoc, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads, pbKind, + isCancellable)); return bodyGenStatus; } @@ -522,7 +526,6 @@ SmallVector vecValues = moduleTranslation.lookupValues(orderedOp.depend_vec_vars()); - llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); size_t indexVecValues = 0; while (indexVecValues < vecValues.size()) { SmallVector storeValues; @@ -531,9 +534,11 @@ storeValues.push_back(vecValues[indexVecValues]); indexVecValues++; } + llvm::OpenMPIRBuilder::InsertPointTy allocaIP = + findAllocaInsertPoint(builder, moduleTranslation); + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend( - ompLoc, findAllocaInsertPoint(builder, moduleTranslation), numLoops, - storeValues, ".cnt.addr", isDependSource)); + ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource)); } return success(); } @@ -634,10 +639,12 @@ // called for variables which have destructors/finalizers. 
auto finiCB = [&](InsertPointTy codeGenIP) {}; + llvm::OpenMPIRBuilder::InsertPointTy allocaIP = + findAllocaInsertPoint(builder, moduleTranslation); llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createSections( - ompLoc, findAllocaInsertPoint(builder, moduleTranslation), sectionCBs, - privCB, finiCB, false, sectionsOp.nowait())); + ompLoc, allocaIP, sectionCBs, privCB, finiCB, false, + sectionsOp.nowait())); return bodyGenStatus; } @@ -1104,7 +1111,6 @@ llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); // Convert values and types. auto &innerOpList = opInst.region().front().getOperations(); @@ -1164,17 +1170,10 @@ // Handle ambiguous alloca, if any. auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation); - if (allocaIP.getPoint() == ompLoc.IP.getPoint()) { - // Same point => split basic block and make them unambigous. 
- llvm::UnreachableInst *unreachableInst = builder.CreateUnreachable(); - builder.SetInsertPoint(builder.GetInsertBlock()->splitBasicBlock( - unreachableInst, "alloca_split")); - ompLoc.IP = builder.saveIP(); - unreachableInst->eraseFromParent(); - } + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); builder.restoreIP(ompBuilder->createAtomicUpdate( - ompLoc, findAllocaInsertPoint(builder, moduleTranslation), llvmAtomicX, - llvmExpr, atomicOrdering, binop, updateFn, isXBinopExpr)); + ompLoc, allocaIP, llvmAtomicX, llvmExpr, atomicOrdering, binop, updateFn, + isXBinopExpr)); return updateGenStatus; } @@ -1183,7 +1182,6 @@ llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); mlir::Value mlirExpr; bool isXBinopExpr = false, isPostfixUpdate = false; llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP; @@ -1262,20 +1260,13 @@ "argument"); return moduleTranslation.lookupValue(yieldop.results()[0]); }; + // Handle ambiguous alloca, if any. auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation); - if (allocaIP.getPoint() == ompLoc.IP.getPoint()) { - // Same point => split basic block and make them unambigous. 
- llvm::UnreachableInst *unreachableInst = builder.CreateUnreachable(); - builder.SetInsertPoint(builder.GetInsertBlock()->splitBasicBlock( - unreachableInst, "alloca_split")); - ompLoc.IP = builder.saveIP(); - unreachableInst->eraseFromParent(); - } + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); builder.restoreIP(ompBuilder->createAtomicCapture( - ompLoc, findAllocaInsertPoint(builder, moduleTranslation), llvmAtomicX, - llvmAtomicV, llvmExpr, atomicOrdering, binop, updateFn, atomicUpdateOp, - isPostfixUpdate, isXBinopExpr)); + ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering, + binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr)); return updateGenStatus; } diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -1854,6 +1854,9 @@ // CHECK-LABEL: @omp_sections_trivial llvm.func @omp_sections_trivial() -> () { + // CHECK: br label %[[ENTRY:[a-zA-Z_.]+]] + + // CHECK: [[ENTRY]]: // CHECK: br label %[[PREHEADER:.*]] // CHECK: [[PREHEADER]]: