diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1031,10 +1031,10 @@
   let extraClassDeclaration = [{
     /// Return the llvm.mlir.global operation that defined the value referenced
     /// here.
-    GlobalOp getGlobal();
+    GlobalOp getGlobal(SymbolTableCollection &symbolTable);
 
     /// Return the llvm.func operation that is referenced here.
-    LLVMFuncOp getFunction();
+    LLVMFuncOp getFunction(SymbolTableCollection &symbolTable);
   }];
 
   let assemblyFormat = "$global_name attr-dict `:` type($res)";
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -16,6 +16,7 @@
 
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Target/LLVMIR/Export.h"
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
@@ -264,6 +265,9 @@
     ModuleTranslation &moduleTranslation;
   };
 
+  /// Returns the cache of symbol tables used for symbol lookups.
+  SymbolTableCollection &symbolTable() { return symbolTableCollection; }
+
 private:
   ModuleTranslation(Operation *module,
                     std::unique_ptr<llvm::Module> llvmModule);
@@ -333,6 +337,9 @@
   /// Stack of user-specified state elements, useful when translating operations
   /// with regions.
   SmallVector<std::unique_ptr<StackFrame>> stack;
+
+  /// A cache for the symbol tables constructed during symbol lookups.
+  SymbolTableCollection symbolTableCollection;
 };
 
 namespace detail {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1737,14 +1737,14 @@
   return module;
 }
 
-GlobalOp AddressOfOp::getGlobal() {
+GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) {
   return dyn_cast_or_null<GlobalOp>(
-      SymbolTable::lookupSymbolIn(parentLLVMModule(*this), getGlobalName()));
+      symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
 }
 
-LLVMFuncOp AddressOfOp::getFunction() {
+LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) {
   return dyn_cast_or_null<LLVMFuncOp>(
-      SymbolTable::lookupSymbolIn(parentLLVMModule(*this), getGlobalName()));
+      symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
 }
 
 LogicalResult
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -466,8 +466,10 @@
   // operation and store it in the MLIR-to-LLVM value mapping. This does not
   // emit any LLVM instruction.
   if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
-    LLVM::GlobalOp global = addressOfOp.getGlobal();
-    LLVM::LLVMFuncOp function = addressOfOp.getFunction();
+    LLVM::GlobalOp global =
+        addressOfOp.getGlobal(moduleTranslation.symbolTable());
+    LLVM::LLVMFuncOp function =
+        addressOfOp.getFunction(moduleTranslation.symbolTable());
 
     // The verifier should not have allowed this.
     assert((global || function) &&
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1285,7 +1285,8 @@
     return opInst.emitError("Addressing symbol not found");
   LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
 
-  LLVM::GlobalOp global = addressOfOp.getGlobal();
+  LLVM::GlobalOp global =
+      addressOfOp.getGlobal(moduleTranslation.symbolTable());
   llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
   llvm::Value *data =
       builder.CreateBitCast(globalValue, builder.getInt8PtrTy());