diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -93,11 +93,11 @@ llvm::IRBuilder<> &builder); virtual LogicalResult convertOmpMaster(Operation &op, llvm::IRBuilder<> &builder); - void convertOmpOpRegions(Region &region, + void convertOmpOpRegions(Region &region, StringRef blockName, DenseMap &valueMapping, DenseMap &blockMapping, - llvm::Instruction *codeGenIPBBTI, - llvm::BasicBlock &continuationIP, + llvm::BasicBlock &sourceBlock, + llvm::BasicBlock &continuationBlock, llvm::IRBuilder<> &builder, LogicalResult &bodyGenStatus); virtual LogicalResult convertOmpWsLoop(Operation &opInst, @@ -121,7 +121,8 @@ LogicalResult convertFunctions(); LogicalResult convertGlobals(); LogicalResult convertOneFunction(LLVMFuncOp func); - LogicalResult convertBlock(Block &bb, bool ignoreArguments); + LogicalResult convertBlock(Block &bb, bool ignoreArguments, + llvm::IRBuilder<> &builder); llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc); @@ -134,14 +135,11 @@ /// Builder for LLVM IR generation of OpenMP constructs. std::unique_ptr ompBuilder; + /// Precomputed pointer to OpenMP dialect. Note this can be nullptr if the /// OpenMP dialect hasn't been loaded (it is always loaded if there are OpenMP /// operations in the module though). const Dialect *ompDialect; - /// Stack which stores the target block to which a branch a must be added when - /// a terminator is seen. A stack is required to handle nested OpenMP parallel - /// regions. - SmallVector ompContinuationIPStack; /// Mappings between llvm.mlir.global definitions and corresponding globals. 
DenseMap globalsMapping; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -413,24 +413,11 @@ auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP, llvm::BasicBlock &continuationIP) { - llvm::LLVMContext &llvmContext = llvmModule->getContext(); - - llvm::BasicBlock *codeGenIPBB = codeGenIP.getBlock(); - llvm::Instruction *codeGenIPBBTI = codeGenIPBB->getTerminator(); - ompContinuationIPStack.push_back(&continuationIP); - - // ParallelOp has only `1` region associated with it. + // ParallelOp has only one region associated with it. auto &region = cast(opInst).getRegion(); - for (auto &bb : region) { - auto *llvmBB = llvm::BasicBlock::Create( - llvmContext, "omp.par.region", codeGenIP.getBlock()->getParent()); - blockMapping[&bb] = llvmBB; - } - - convertOmpOpRegions(region, valueMapping, blockMapping, codeGenIPBBTI, - continuationIP, builder, bodyGenStatus); - ompContinuationIPStack.pop_back(); - + convertOmpOpRegions(region, "omp.par.region", valueMapping, blockMapping, + *codeGenIP.getBlock(), continuationIP, builder, + bodyGenStatus); }; // TODO: Perform appropriate actions according to the data-sharing @@ -472,29 +459,50 @@ } void ModuleTranslation::convertOmpOpRegions( - Region &region, DenseMap &valueMapping, + Region &region, StringRef blockName, + DenseMap &valueMapping, DenseMap &blockMapping, - llvm::Instruction *codeGenIPBBTI, llvm::BasicBlock &continuationIP, + llvm::BasicBlock &sourceBlock, llvm::BasicBlock &continuationBlock, llvm::IRBuilder<> &builder, LogicalResult &bodyGenStatus) { + llvm::LLVMContext &llvmContext = builder.getContext(); + for (Block &bb : region) { + llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create( + llvmContext, blockName, builder.GetInsertBlock()->getParent()); + blockMapping[&bb] = llvmBB; + } + + llvm::Instruction *sourceTerminator = 
sourceBlock.getTerminator(); + // Convert blocks one by one in topological order to ensure // defs are converted before uses. llvm::SetVector blocks = topologicalSort(region); - for (auto indexedBB : llvm::enumerate(blocks)) { - Block *bb = indexedBB.value(); - llvm::BasicBlock *curLLVMBB = blockMapping[bb]; + for (Block *bb : blocks) { + llvm::BasicBlock *llvmBB = blockMapping[bb]; + // Retarget the branch of the entry block to the entry block of the + // converted region (regions are single-entry). if (bb->isEntryBlock()) { - assert(codeGenIPBBTI->getNumSuccessors() == 1 && - "OpenMPIRBuilder provided entry block has multiple successors"); - assert(codeGenIPBBTI->getSuccessor(0) == &continuationIP && - "ContinuationIP is not the successor of OpenMPIRBuilder " - "provided entry block"); - codeGenIPBBTI->setSuccessor(0, curLLVMBB); + assert(sourceTerminator->getNumSuccessors() == 1 && + "provided entry block has multiple successors"); + assert(sourceTerminator->getSuccessor(0) == &continuationBlock && + "ContinuationBlock is not the successor of the entry block"); + sourceTerminator->setSuccessor(0, llvmBB); } - if (failed(convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0))) { + llvm::IRBuilder<>::InsertPointGuard guard(builder); + if (failed(convertBlock(*bb, bb->isEntryBlock(), builder))) { bodyGenStatus = failure(); return; } + + // Special handling for `omp.yield` and `omp.terminator` (we may have more + // than one): they return the control to the parent OpenMP dialect operation + // so replace them with the branch to the continuation block. We handle this + // here to avoid relying on inter-function communication through the + // ModuleTranslation class to set up the correct insertion point. This is + // also consistent with MLIR's idiom of handling special region terminators + // in the same code that handles the region-owning operation. 
+ if (isa(bb->getTerminator())) + builder.CreateBr(&continuationBlock); } // Finally, after all blocks have been traversed and values mapped, // connect the PHI nodes to the results of preceding blocks. @@ -510,22 +518,11 @@ auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP, llvm::BasicBlock &continuationIP) { - llvm::LLVMContext &llvmContext = llvmModule->getContext(); - - llvm::BasicBlock *codeGenIPBB = codeGenIP.getBlock(); - llvm::Instruction *codeGenIPBBTI = codeGenIPBB->getTerminator(); - ompContinuationIPStack.push_back(&continuationIP); - - // MasterOp has only `1` region associated with it. + // MasterOp has only one region associated with it. auto &region = cast(opInst).getRegion(); - for (auto &bb : region) { - auto *llvmBB = llvm::BasicBlock::Create( - llvmContext, "omp.master.region", codeGenIP.getBlock()->getParent()); - blockMapping[&bb] = llvmBB; - } - convertOmpOpRegions(region, valueMapping, blockMapping, codeGenIPBBTI, - continuationIP, builder, bodyGenStatus); - ompContinuationIPStack.pop_back(); + convertOmpOpRegions(region, "omp.master.region", valueMapping, blockMapping, + *codeGenIP.getBlock(), continuationIP, builder, + bodyGenStatus); }; // TODO: Perform finalization actions for variables. This has to be @@ -553,9 +550,6 @@ return opInst.emitOpError( "only static (default) loop schedule is currently supported"); - llvm::Function *func = builder.GetInsertBlock()->getParent(); - llvm::LLVMContext &llvmContext = llvmModule->getContext(); - // Find the loop configuration. llvm::Value *lowerBound = valueMapping.lookup(loop.lowerBound()[0]); llvm::Value *upperBound = valueMapping.lookup(loop.upperBound()[0]); @@ -589,44 +583,9 @@ entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit"); // Convert the body of the loop. 
- Region &region = loop.region(); - for (Block &bb : region) { - llvm::BasicBlock *llvmBB = - llvm::BasicBlock::Create(llvmContext, "omp.wsloop.region", func); - blockMapping[&bb] = llvmBB; - - // Retarget the branch of the entry block to the entry block of the - // converted region (regions are single-entry). - if (bb.isEntryBlock()) { - auto *branch = cast(entryBlock->getTerminator()); - branch->setSuccessor(0, llvmBB); - } - } - - // Block conversion creates a new IRBuilder every time so need not bother - // about maintaining the insertion point. - llvm::SetVector blocks = topologicalSort(region); - for (Block *bb : blocks) { - if (failed(convertBlock(*bb, bb->isEntryBlock()))) { - bodyGenStatus = failure(); - return; - } - - // Special handling for `omp.yield` terminators (we may have more than - // one): they return the control to the parent WsLoop operation so replace - // them with the branch to the exit block. We handle this here to avoid - // relying inter-function communication through the ModuleTranslation - // class to set up the correct insertion point. This is also consistent - // with MLIR's idiom of handling special region terminators in the same - // code that handles the region-owning operation. - if (isa(bb->getTerminator())) { - llvm::BasicBlock *llvmBB = blockMapping[bb]; - builder.SetInsertPoint(llvmBB, llvmBB->end()); - builder.CreateBr(exitBlock); - } - } - - connectPHINodes(region, valueMapping, blockMapping, branchMapping); + convertOmpOpRegions(loop.region(), "omp.wsloop.region", valueMapping, + blockMapping, *entryBlock, *exitBlock, builder, + bodyGenStatus); }; // Delegate actual loop construction to the OpenMP IRBuilder. 
@@ -690,18 +649,15 @@ ompBuilder->createFlush(builder.saveIP()); return success(); }) - .Case([&](omp::TerminatorOp) { - builder.CreateBr(ompContinuationIPStack.back()); - return success(); - }) .Case( [&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); }) .Case([&](omp::MasterOp) { return convertOmpMaster(opInst, builder); }) .Case([&](omp::WsLoopOp) { return convertOmpWsLoop(opInst, builder); }) - .Case([&](omp::YieldOp op) { - // Yields are loop terminators that can be just omitted. The loop - // structure was created in the function that handles WsLoopOp. - assert(op.getNumOperands() == 0 && "unexpected yield with operands"); + .Case([](auto op) { + // `yield` and `terminator` can be just omitted. The block structure was + // created in the function that handles their parent operation. + assert(op->getNumOperands() == 0 && + "unexpected OpenMP terminator with operands"); return success(); }) .Default([&](Operation *inst) { @@ -911,9 +867,14 @@ /// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes /// to define values corresponding to the MLIR block arguments. These nodes -/// are not connected to the source basic blocks, which may not exist yet. -LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { - llvm::IRBuilder<> builder(blockMapping[&bb]); +/// are not connected to the source basic blocks, which may not exist yet. Uses +/// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have +/// been created for `bb` and included in the block mapping. Inserts new +/// instructions at the end of the block and leaves `builder` in a state +/// suitable for further insertion into the end of the block. 
+LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments, + llvm::IRBuilder<> &builder) { + builder.SetInsertPoint(blockMapping[&bb]); auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram(); // Before traversing operations, make block arguments available through @@ -1137,9 +1098,9 @@ // Then, convert blocks one by one in topological order to ensure defs are // converted before uses. auto blocks = topologicalSort(func); - for (auto indexedBB : llvm::enumerate(blocks)) { - auto *bb = indexedBB.value(); - if (failed(convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0))) + for (Block *bb : blocks) { + llvm::IRBuilder<> builder(llvmContext); + if (failed(convertBlock(*bb, bb->isEntryBlock(), builder))) return failure(); }