diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -77,6 +77,63 @@ /// module requirements. static Block &getModuleBody(Operation *m) { return m->getRegion(0).front(); } + /// Stores the mapping between a function name and its LLVM IR representation. + void mapFunction(StringRef name, llvm::Function *func) { + auto result = functionMapping.try_emplace(name, func); + (void)result; + assert(result.second && + "attempting to map a function that is already mapped"); + } + + /// Finds an LLVM IR function by its name. + llvm::Function *lookupFunction(StringRef name) const { + return functionMapping.lookup(name); + } + + /// Stores the mapping between an MLIR value and its LLVM IR counterpart. + void mapValue(Value mlir, llvm::Value *llvm) { mapValue(mlir) = llvm; } + + /// Provides write-once access to store the LLVM IR value corresponding to the + /// given MLIR value. + llvm::Value *&mapValue(Value value) { + llvm::Value *&llvm = valueMapping[value]; + assert(llvm == nullptr && + "attempting to map a value that is already mapped"); + return llvm; + } + + /// Finds an LLVM IR value corresponding to the given MLIR value. + llvm::Value *lookupValue(Value value) const { + return valueMapping.lookup(value); + } + + /// Stores the mapping between an MLIR block and LLVM IR basic block. + void mapBlock(Block *mlir, llvm::BasicBlock *llvm) { + auto result = blockMapping.try_emplace(mlir, llvm); + (void)result; + assert(result.second && "attempting to map a block that is already mapped"); + } + + /// Finds an LLVM IR basic block that corresponds to the given MLIR block. + llvm::BasicBlock *lookupBlock(Block *block) const { + return blockMapping.lookup(block); + } + + /// Stores the mapping between an MLIR operation with successors and a + /// corresponding LLVM IR instruction. + void mapBranch(Operation *mlir, llvm::Instruction *llvm) { + auto result = branchMapping.try_emplace(mlir, llvm); + (void)result; + assert(result.second && + "attempting to map a branch that is already mapped"); + } + + /// Finds an LLVM IR instruction that corresponds to the given MLIR operation + /// with successors. + llvm::Instruction *lookupBranch(Operation *op) const { + return branchMapping.lookup(op); + } + protected: /// Translate the given MLIR module expressed in MLIR LLVM IR dialect into an /// LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an @@ -94,8 +151,6 @@ virtual LogicalResult convertOmpMaster(Operation &op, llvm::IRBuilder<> &builder); void convertOmpOpRegions(Region ®ion, StringRef blockName, - DenseMap &valueMapping, - DenseMap &blockMapping, llvm::BasicBlock &sourceBlock, llvm::BasicBlock &continuationBlock, llvm::IRBuilder<> &builder, @@ -147,7 +202,7 @@ /// A stateful object used to translate types. TypeToLLVMIRTranslator typeTranslator; -protected: +private: /// Mappings between original and translated values, used for lookups. llvm::StringMap functionMapping; DenseMap valueMapping; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -94,7 +94,7 @@ } else { return type; } - } while (1); + } while (true); } /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. @@ -119,8 +119,8 @@ if (auto floatAttr = attr.dyn_cast()) return llvm::ConstantFP::get(llvmType, floatAttr.getValue()); if (auto funcAttr = attr.dyn_cast()) - return llvm::ConstantExpr::getBitCast( - functionMapping.lookup(funcAttr.getValue()), llvmType); + return llvm::ConstantExpr::getBitCast(lookupFunction(funcAttr.getValue()), + llvmType); if (auto splatAttr = attr.dyn_cast()) { llvm::Type *elementType; uint64_t numElements; @@ -337,7 +337,9 @@ return condBranchOp.getSuccessor(0) == current ? condBranchOp.trueDestOperands()[index] : condBranchOp.falseDestOperands()[index]; - } else if (auto switchOp = dyn_cast(terminator)) { + } + + if (auto switchOp = dyn_cast(terminator)) { // For switches, we take the operands from either the default case, or from // the case branch that was taken. if (switchOp.defaultDestination() == current) @@ -353,15 +355,12 @@ /// Connect the PHI nodes to the results of preceding blocks. template -static void connectPHINodes( - T &func, const DenseMap &valueMapping, - const DenseMap &blockMapping, - const DenseMap &branchMapping) { +static void connectPHINodes(T &func, const ModuleTranslation &state) { // Skip the first block, it cannot be branched to and its arguments correspond // to the arguments of the LLVM function. for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) { Block *bb = &*it; - llvm::BasicBlock *llvmBB = blockMapping.lookup(bb); + llvm::BasicBlock *llvmBB = state.lookupBlock(bb); auto phis = llvmBB->phis(); auto numArguments = bb->getNumArguments(); assert(numArguments == std::distance(phis.begin(), phis.end())); @@ -371,15 +370,15 @@ for (auto *pred : bb->getPredecessors()) { // Find the LLVM IR block that contains the converted terminator // instruction and use it in the PHI node. Note that this block is not - // necessarily the same as blockMapping.lookup(pred), some operations + // necessarily the same as state.lookupBlock(pred), some operations // (in particular, OpenMP operations using OpenMPIRBuilder) may have // split the blocks. llvm::Instruction *terminator = - branchMapping.lookup(pred->getTerminator()); + state.lookupBranch(pred->getTerminator()); assert(terminator && "missing the mapping for a terminator"); - phiNode.addIncoming(valueMapping.lookup(getPHISourceValue( - bb, pred, numArguments, index)), - terminator->getParent()); + phiNode.addIncoming( + state.lookupValue(getPHISourceValue(bb, pred, numArguments, index)), + terminator->getParent()); } } } @@ -415,9 +414,8 @@ llvm::BasicBlock &continuationBlock) { // ParallelOp has only one region associated with it. auto ®ion = cast(opInst).getRegion(); - convertOmpOpRegions(region, "omp.par.region", valueMapping, blockMapping, - *codeGenIP.getBlock(), continuationBlock, builder, - bodyGenStatus); + convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(), + continuationBlock, builder, bodyGenStatus); }; // TODO: Perform appropriate actions according to the data-sharing @@ -437,10 +435,10 @@ llvm::Value *ifCond = nullptr; if (auto ifExprVar = cast(opInst).if_expr_var()) - ifCond = valueMapping.lookup(ifExprVar); + ifCond = lookupValue(ifExprVar); llvm::Value *numThreads = nullptr; if (auto numThreadsVar = cast(opInst).num_threads_var()) - numThreads = valueMapping.lookup(numThreadsVar); + numThreads = lookupValue(numThreadsVar); llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default; if (auto bind = cast(opInst).proc_bind_val()) pbKind = llvm::omp::getProcBindKind(bind.getValue()); @@ -460,15 +458,13 @@ void ModuleTranslation::convertOmpOpRegions( Region ®ion, StringRef blockName, - DenseMap &valueMapping, - DenseMap &blockMapping, llvm::BasicBlock &sourceBlock, llvm::BasicBlock &continuationBlock, llvm::IRBuilder<> &builder, LogicalResult &bodyGenStatus) { llvm::LLVMContext &llvmContext = builder.getContext(); for (Block &bb : region) { llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create( llvmContext, blockName, builder.GetInsertBlock()->getParent()); - blockMapping[&bb] = llvmBB; + mapBlock(&bb, llvmBB); } llvm::Instruction *sourceTerminator = sourceBlock.getTerminator(); @@ -477,7 +473,7 @@ // defs are converted before uses. llvm::SetVector blocks = topologicalSort(region); for (Block *bb : blocks) { - llvm::BasicBlock *llvmBB = blockMapping[bb]; + llvm::BasicBlock *llvmBB = lookupBlock(bb); // Retarget the branch of the entry block to the entry block of the // converted region (regions are single-entry). if (bb->isEntryBlock()) { @@ -506,7 +502,7 @@ } // Finally, after all blocks have been traversed and values mapped, // connect the PHI nodes to the results of preceding blocks. - connectPHINodes(region, valueMapping, blockMapping, branchMapping); + connectPHINodes(region, *this); } LogicalResult ModuleTranslation::convertOmpMaster(Operation &opInst, @@ -520,9 +516,8 @@ llvm::BasicBlock &continuationBlock) { // MasterOp has only one region associated with it. auto ®ion = cast(opInst).getRegion(); - convertOmpOpRegions(region, "omp.master.region", valueMapping, blockMapping, - *codeGenIP.getBlock(), continuationBlock, builder, - bodyGenStatus); + convertOmpOpRegions(region, "omp.master.region", *codeGenIP.getBlock(), + continuationBlock, builder, bodyGenStatus); }; // TODO: Perform finalization actions for variables. This has to be @@ -551,12 +546,12 @@ "only static (default) loop schedule is currently supported"); // Find the loop configuration. - llvm::Value *lowerBound = valueMapping.lookup(loop.lowerBound()[0]); - llvm::Value *upperBound = valueMapping.lookup(loop.upperBound()[0]); - llvm::Value *step = valueMapping.lookup(loop.step()[0]); + llvm::Value *lowerBound = lookupValue(loop.lowerBound()[0]); + llvm::Value *upperBound = lookupValue(loop.upperBound()[0]); + llvm::Value *step = lookupValue(loop.step()[0]); llvm::Type *ivType = step->getType(); llvm::Value *chunk = loop.schedule_chunk_var() - ? valueMapping[loop.schedule_chunk_var()] + ? lookupValue(loop.schedule_chunk_var()) : llvm::ConstantInt::get(ivType, 1); // Set up the source location value for OpenMP runtime. @@ -576,16 +571,15 @@ llvm::IRBuilder<>::InsertPointGuard guard(builder); // Make sure further conversions know about the induction variable. - valueMapping[loop.getRegion().front().getArgument(0)] = iv; + mapValue(loop.getRegion().front().getArgument(0), iv); llvm::BasicBlock *entryBlock = ip.getBlock(); llvm::BasicBlock *exitBlock = entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit"); // Convert the body of the loop. - convertOmpOpRegions(loop.region(), "omp.wsloop.region", valueMapping, - blockMapping, *entryBlock, *exitBlock, builder, - bodyGenStatus); + convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock, + *exitBlock, builder, bodyGenStatus); }; // Delegate actual loop construction to the OpenMP IRBuilder. @@ -715,17 +709,14 @@ auto convertCall = [this, &builder](Operation &op) -> llvm::Value * { auto operands = lookupValues(op.getOperands()); ArrayRef operandsRef(operands); - if (auto attr = op.getAttrOfType("callee")) { - return builder.CreateCall(functionMapping.lookup(attr.getValue()), - operandsRef); - } else { - auto *calleePtrType = - cast(operandsRef.front()->getType()); - auto *calleeType = - cast(calleePtrType->getElementType()); - return builder.CreateCall(calleeType, operandsRef.front(), - operandsRef.drop_front()); - } + if (auto attr = op.getAttrOfType("callee")) + return builder.CreateCall(lookupFunction(attr.getValue()), operandsRef); + auto *calleePtrType = + cast(operandsRef.front()->getType()); + auto *calleeType = + cast(calleePtrType->getElementType()); + return builder.CreateCall(calleeType, operandsRef.front(), + operandsRef.drop_front()); }; // Emit calls. If the called function has a result, remap the corresponding @@ -733,7 +724,7 @@ if (isa(opInst)) { llvm::Value *result = convertCall(opInst); if (opInst.getNumResults() != 0) { - valueMapping[opInst.getResult(0)] = result; + mapValue(opInst.getResult(0), result); return success(); } // Check that LLVM call returns void for 0-result functions. @@ -770,7 +761,7 @@ llvm::Value *result = builder.CreateCall(inlineAsmInst, lookupValues(inlineAsmOp.operands())); if (opInst.getNumResults() != 0) - valueMapping[opInst.getResult(0)] = result; + mapValue(opInst.getResult(0), result); return success(); } @@ -778,17 +769,17 @@ auto operands = lookupValues(opInst.getOperands()); ArrayRef operandsRef(operands); if (auto attr = opInst.getAttrOfType("callee")) { - builder.CreateInvoke(functionMapping.lookup(attr.getValue()), - blockMapping[invOp.getSuccessor(0)], - blockMapping[invOp.getSuccessor(1)], operandsRef); + builder.CreateInvoke(lookupFunction(attr.getValue()), + lookupBlock(invOp.getSuccessor(0)), + lookupBlock(invOp.getSuccessor(1)), operandsRef); } else { auto *calleePtrType = cast(operandsRef.front()->getType()); auto *calleeType = cast(calleePtrType->getElementType()); builder.CreateInvoke( - calleeType, operandsRef.front(), blockMapping[invOp.getSuccessor(0)], - blockMapping[invOp.getSuccessor(1)], operandsRef.drop_front()); + calleeType, operandsRef.front(), lookupBlock(invOp.getSuccessor(0)), + lookupBlock(invOp.getSuccessor(1)), operandsRef.drop_front()); } return success(); } @@ -799,12 +790,12 @@ builder.CreateLandingPad(ty, lpOp.getNumOperands()); // Add clauses - for (auto operand : lookupValues(lpOp.getOperands())) { + for (llvm::Value *operand : lookupValues(lpOp.getOperands())) { // All operands should be constant - checked by verifier - if (auto constOperand = dyn_cast(operand)) + if (auto *constOperand = dyn_cast(operand)) lpi->addClause(constOperand); } - valueMapping[lpOp.getResult()] = lpi; + mapValue(lpOp.getResult(), lpi); return success(); } @@ -812,8 +803,8 @@ // arguments that were transformed into PHI nodes. if (auto brOp = dyn_cast(opInst)) { llvm::BranchInst *branch = - builder.CreateBr(blockMapping[brOp.getSuccessor()]); - branchMapping.try_emplace(&opInst, branch); + builder.CreateBr(lookupBlock(brOp.getSuccessor())); + mapBranch(&opInst, branch); return success(); } if (auto condbrOp = dyn_cast(opInst)) { @@ -831,10 +822,10 @@ static_cast(falseWeight)); } llvm::BranchInst *branch = builder.CreateCondBr( - valueMapping.lookup(condbrOp.getOperand(0)), - blockMapping[condbrOp.getSuccessor(0)], - blockMapping[condbrOp.getSuccessor(1)], branchWeights); - branchMapping.try_emplace(&opInst, branch); + lookupValue(condbrOp.getOperand(0)), + lookupBlock(condbrOp.getSuccessor(0)), + lookupBlock(condbrOp.getSuccessor(1)), branchWeights); + mapBranch(&opInst, branch); return success(); } if (auto switchOp = dyn_cast(opInst)) { @@ -849,8 +840,8 @@ } llvm::SwitchInst *switchInst = - builder.CreateSwitch(valueMapping[switchOp.value()], - blockMapping[switchOp.defaultDestination()], + builder.CreateSwitch(lookupValue(switchOp.value()), + lookupBlock(switchOp.defaultDestination()), switchOp.caseDestinations().size(), branchWeights); auto *ty = @@ -860,9 +851,9 @@ switchOp.caseDestinations())) switchInst->addCase( llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()), - blockMapping[std::get<1>(i)]); + lookupBlock(std::get<1>(i))); - branchMapping.try_emplace(&opInst, switchInst); + mapBranch(&opInst, switchInst); return success(); } @@ -877,9 +868,9 @@ assert((global || function) && "referencing an undefined global or function"); - valueMapping[addressOfOp.getResult()] = - global ? globalsMapping.lookup(global) - : functionMapping.lookup(function.getName()); + mapValue(addressOfOp.getResult(), global + ? globalsMapping.lookup(global) + : lookupFunction(function.getName())); return success(); } @@ -899,7 +890,7 @@ /// suitable for further insertion into the end of the block. LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilder<> &builder) { - builder.SetInsertPoint(blockMapping[&bb]); + builder.SetInsertPoint(lookupBlock(&bb)); auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram(); // Before traversing operations, make block arguments available through @@ -919,7 +910,7 @@ "block argument does not have an LLVM type"); llvm::Type *type = convertType(wrappedType); llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors); - valueMapping[arg] = phi; + mapValue(arg, phi); } } @@ -957,11 +948,11 @@ llvm::IRBuilder<> builder(llvmModule->getContext()); for (auto &op : initializer->without_terminator()) { if (failed(convertOperation(op, builder)) || - !isa(valueMapping.lookup(op.getResult(0)))) + !isa(lookupValue(op.getResult(0)))) return emitError(op.getLoc(), "unemittable constant value"); } ReturnOp ret = cast(initializer->getTerminator()); - cst = cast(valueMapping.lookup(ret.getOperand(0))); + cst = cast(lookupValue(ret.getOperand(0))); } auto linkage = convertLinkageToLLVM(op.linkage()); @@ -1064,7 +1055,7 @@ blockMapping.clear(); valueMapping.clear(); branchMapping.clear(); - llvm::Function *llvmFunc = functionMapping.lookup(func.getName()); + llvm::Function *llvmFunc = lookupFunction(func.getName()); // Translate the debug information for this function. debugTranslation->translate(func, *llvmFunc); @@ -1118,7 +1109,7 @@ llvmArg.getType()->getPointerElementType())); } - valueMapping[mlirArg] = &llvmArg; + mapValue(mlirArg, &llvmArg); argIdx++; } @@ -1135,7 +1126,7 @@ for (auto &bb : func) { auto *llvmBB = llvm::BasicBlock::Create(llvmContext); llvmBB->insertInto(llvmFunc); - blockMapping[&bb] = llvmBB; + mapBlock(&bb, llvmBB); } // Then, convert blocks one by one in topological order to ensure defs are @@ -1149,7 +1140,7 @@ // Finally, after all blocks have been traversed and values mapped, connect // the PHI nodes to the results of preceding blocks. - connectPHINodes(func, valueMapping, blockMapping, branchMapping); + connectPHINodes(func, *this); return success(); } @@ -1170,7 +1161,7 @@ cast(convertType(function.getType()))); llvm::Function *llvmFunc = cast(llvmFuncCst.getCallee()); llvmFunc->setLinkage(convertLinkageToLLVM(function.linkage())); - functionMapping[function.getName()] = llvmFunc; + mapFunction(function.getName(), llvmFunc); // Forward the pass-through attributes to LLVM. if (failed(forwardPassthroughAttributes(function.getLoc(), @@ -1204,10 +1195,8 @@ ModuleTranslation::lookupValues(ValueRange values) { SmallVector remapped; remapped.reserve(values.size()); - for (Value v : values) { - assert(valueMapping.count(v) && "referencing undefined value"); - remapped.push_back(valueMapping.lookup(v)); - } + for (Value v : values) + remapped.push_back(lookupValue(v)); return remapped; } diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -126,14 +126,13 @@ // Then, rewrite the name based on its kind. bool isVariadicOperand = isVariadicOperandName(op, name); if (isOperandName(op, name)) { - auto result = isVariadicOperand - ? formatv("lookupValues(op.{0}())", name) - : formatv("valueMapping.lookup(op.{0}())", name); + auto result = isVariadicOperand ? formatv("lookupValues(op.{0}())", name) + : formatv("lookupValue(op.{0}())", name); bs << result; } else if (isAttributeName(op, name)) { bs << formatv("op.{0}()", name); } else if (isResultName(op, name)) { - bs << formatv("valueMapping[op.{0}()]", name); + bs << formatv("mapValue(op.{0}())", name); } else if (name == "_resultType") { bs << "convertType(op.getResult().getType())"; } else if (name == "_hasResult") {