diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -151,6 +151,11 @@
   llvm::StringMap<llvm::Function *> functionMapping;
   DenseMap<Value, llvm::Value *> valueMapping;
   DenseMap<Block *, llvm::BasicBlock *> blockMapping;
+
+  /// A mapping between MLIR LLVM dialect terminators and LLVM IR terminators
+  /// they are converted to. This allows for connecting PHI nodes to the source
+  /// values after all operations are converted.
+  DenseMap<Operation *, llvm::Instruction *> branchMapping;
 };
 
 } // namespace LLVM
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -340,9 +340,10 @@
 
 /// Connect the PHI nodes to the results of preceding blocks.
 template <typename T>
-static void
-connectPHINodes(T &func, const DenseMap<Value, llvm::Value *> &valueMapping,
-                const DenseMap<Block *, llvm::BasicBlock *> &blockMapping) {
+static void connectPHINodes(
+    T &func, const DenseMap<Value, llvm::Value *> &valueMapping,
+    const DenseMap<Block *, llvm::BasicBlock *> &blockMapping,
+    const DenseMap<Operation *, llvm::Instruction *> &branchMapping) {
   // Skip the first block, it cannot be branched to and its arguments correspond
   // to the arguments of the LLVM function.
   for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
@@ -355,9 +356,22 @@
       auto &phiNode = numberedPhiNode.value();
       unsigned index = numberedPhiNode.index();
       for (auto *pred : bb->getPredecessors()) {
+        // Find the LLVM IR block that contains the converted terminator
+        // instruction and use it in the PHI node. Note that this block is not
+        // necessarily the same as blockMapping.lookup(pred), some operations
+        // (in particular, OpenMP operations using OpenMPIRBuilder) may have
+        // split the blocks. Terminators that are not recorded in branchMapping
+        // (currently only llvm.br and llvm.cond_br record themselves) use the
+        // block mapping instead, which is correct as long as their block was
+        // not split.
+        llvm::Instruction *terminator =
+            branchMapping.lookup(pred->getTerminator());
+        llvm::BasicBlock *sourceBlock =
+            terminator ? terminator->getParent() : blockMapping.lookup(pred);
+        assert(sourceBlock && "missing the LLVM IR block for a predecessor");
         phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
                                 bb, pred, numArguments, index)),
-                            blockMapping.lookup(pred));
+                            sourceBlock);
       }
     }
   }
@@ -476,7 +490,7 @@
   }
   // Finally, after all blocks have been traversed and values mapped,
   // connect the PHI nodes to the results of preceding blocks.
-  connectPHINodes(region, valueMapping, blockMapping);
+  connectPHINodes(region, valueMapping, blockMapping, branchMapping);
 }
 
 LogicalResult ModuleTranslation::convertOmpMaster(Operation &opInst,
@@ -682,7 +696,9 @@
   // Emit branches. We need to look up the remapped blocks and ignore the block
   // arguments that were transformed into PHI nodes.
   if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
-    builder.CreateBr(blockMapping[brOp.getSuccessor()]);
+    llvm::BranchInst *branch =
+        builder.CreateBr(blockMapping[brOp.getSuccessor()]);
+    branchMapping.try_emplace(&opInst, branch);
     return success();
   }
   if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
@@ -699,9 +715,11 @@
               .createBranchWeights(static_cast<uint32_t>(trueWeight),
                                    static_cast<uint32_t>(falseWeight));
     }
-    builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)),
-                         blockMapping[condbrOp.getSuccessor(0)],
-                         blockMapping[condbrOp.getSuccessor(1)], branchWeights);
+    llvm::BranchInst *branch = builder.CreateCondBr(
+        valueMapping.lookup(condbrOp.getOperand(0)),
+        blockMapping[condbrOp.getSuccessor(0)],
+        blockMapping[condbrOp.getSuccessor(1)], branchWeights);
+    branchMapping.try_emplace(&opInst, branch);
     return success();
   }
 
@@ -893,10 +911,11 @@
 }
 
 LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
-  // Clear the block and value mappings, they are only relevant within one
-  // function.
+  // Clear the block, branch and value mappings, they are only relevant within
+  // one function.
   blockMapping.clear();
   valueMapping.clear();
+  branchMapping.clear();
   llvm::Function *llvmFunc = functionMapping.lookup(func.getName());
 
   // Translate the debug information for this function.
@@ -964,7 +983,7 @@
 
   // Finally, after all blocks have been traversed and values mapped, connect
   // the PHI nodes to the results of preceding blocks.
-  connectPHINodes(func, valueMapping, blockMapping);
+  connectPHINodes(func, valueMapping, blockMapping, branchMapping);
   return success();
 }
 