diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -974,7 +974,8 @@
   let cppNamespace = "::mlir::LLVM";
 }
 
-def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof", [NoSideEffect]> {
+def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
+    [NoSideEffect, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let arguments = (ins FlatSymbolRefAttr:$global_name);
   let results = (outs LLVM_AnyPointer:$res);
 
@@ -1036,7 +1037,6 @@
   }];
 
   let assemblyFormat = "$global_name attr-dict `:` type($res)";
-  let hasVerifier = 1;
 }
 
 def LLVM_MetadataOp : LLVM_Op<"metadata", [
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1729,27 +1729,28 @@
 // Verifier for LLVM::AddressOfOp.
 //===----------------------------------------------------------------------===//
 
-static Operation *lookupSymbolInModule(Operation *parent, StringRef name) {
-  Operation *module = parent;
+static Operation *parentLLVMModule(Operation *op) {
+  Operation *module = op->getParentOp();
   while (module && !satisfiesLLVMModule(module))
     module = module->getParentOp();
   assert(module && "unexpected operation outside of a module");
-  return mlir::SymbolTable::lookupSymbolIn(module, name);
+  return module;
 }
 
 GlobalOp AddressOfOp::getGlobal() {
   return dyn_cast_or_null<GlobalOp>(
-      lookupSymbolInModule((*this)->getParentOp(), getGlobalName()));
+      SymbolTable::lookupSymbolIn(parentLLVMModule(*this), getGlobalName()));
 }
 
 LLVMFuncOp AddressOfOp::getFunction() {
   return dyn_cast_or_null<LLVMFuncOp>(
-      lookupSymbolInModule((*this)->getParentOp(), getGlobalName()));
+      SymbolTable::lookupSymbolIn(parentLLVMModule(*this), getGlobalName()));
 }
 
-LogicalResult AddressOfOp::verify() {
+LogicalResult
+AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   Operation *symbol =
-      lookupSymbolInModule((*this)->getParentOp(), getGlobalName());
+      symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr());
 
   auto global = dyn_cast_or_null<GlobalOp>(symbol);
   auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -155,6 +155,7 @@
   // them to trigger the attribute type mismatch error.
   // expected-error @+1 {{invalid kind of attribute specified}}
   llvm.mlir.addressof "foo" : i64 : !llvm.ptr>
+  llvm.return
 }
 
 // -----
@@ -162,6 +163,7 @@
 func.func @foo() {
   // expected-error @+1 {{must reference a global defined by 'llvm.mlir.global'}}
   llvm.mlir.addressof @foo : !llvm.ptr>
+  llvm.return
 }
 
 // -----
@@ -171,6 +173,7 @@
 func.func @bar() {
   // expected-error @+1 {{the type must be a pointer to the type of the referenced global}}
   llvm.mlir.addressof @foo : !llvm.ptr
+  llvm.return
 }
 
 // -----
@@ -180,6 +183,7 @@
 llvm.func @bar() {
   // expected-error @+1 {{the type must be a pointer to the type of the referenced function}}
   llvm.mlir.addressof @foo : !llvm.ptr
+  llvm.return
 }
 
 // -----
@@ -211,6 +215,7 @@
 func.func @mismatch_addr_space_implicit_global() {
   // expected-error @+1 {{pointer address space must match address space of the referenced global}}
   llvm.mlir.addressof @g : !llvm.ptr
+  llvm.return
 }
 
 // -----
@@ -219,6 +224,7 @@
 func.func @mismatch_addr_space() {
   // expected-error @+1 {{pointer address space must match address space of the referenced global}}
   llvm.mlir.addressof @g : !llvm.ptr
+  llvm.return
 }
 
 // -----
@@ -227,6 +233,7 @@
 func.func @mismatch_addr_space_opaque() {
   // expected-error @+1 {{pointer address space must match address space of the referenced global}}
   llvm.mlir.addressof @g : !llvm.ptr<4>
+  llvm.return
 }
 
 // -----