diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -84,6 +84,13 @@
   /// Sets the visibility of the given symbol operation.
   static void setSymbolVisibility(Operation *symbol, Visibility vis);
 
+  /// Returns the nearest symbol table from a given operation `from`, which may
+  /// be `from` itself. `from` must be non-null. Returns nullptr if no valid
+  /// symbol table could be found, including when the search reaches an
+  /// operation without a registered dialect that has exactly one region, as
+  /// such an operation may be an unknown symbol table.
+  static Operation *getNearestSymbolTable(Operation *from);
+
   /// Returns the operation registered with the given symbol name with the
   /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
   /// with the 'OpTrait::SymbolTable' trait.
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -278,9 +278,11 @@
 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
                                            spirv::BuiltIn builtin,
                                            OpBuilder &builder) {
-  Operation *parent = op->getParentOp();
-  while (parent && !parent->hasTrait<OpTrait::SymbolTable>())
-    parent = parent->getParentOp();
+  // SymbolTable::getNearestSymbolTable asserts on a null operation, so only
+  // search when `op` actually has a parent; otherwise report the error below.
+  Operation *parent = op->getParentOp();
+  if (parent)
+    parent = SymbolTable::getNearestSymbolTable(parent);
   if (!parent) {
     op->emitError("expected operation to be within a module-like op");
     return nullptr;
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -20,23 +20,6 @@
   return !op->getDialect() && op->getNumRegions() == 1;
 }
 
-/// Returns the nearest symbol table from a given operation `from`. Returns
-/// nullptr if no valid parent symbol table could be found.
-static Operation *getNearestSymbolTable(Operation *from) {
-  assert(from && "expected valid operation");
-  if (isPotentiallyUnknownSymbolTable(from))
-    return nullptr;
-
-  while (!from->hasTrait<OpTrait::SymbolTable>()) {
-    from = from->getParentOp();
-
-    // Check that this is a valid op and isn't an unknown symbol table.
-    if (!from || isPotentiallyUnknownSymbolTable(from))
-      return nullptr;
-  }
-  return from;
-}
-
 /// Returns the string name of the given symbol, or None if this is not a
 /// symbol.
 static Optional<StringRef> getNameIfSymbol(Operation *symbol) {
@@ -212,6 +195,23 @@
   symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx));
 }
 
+/// Returns the nearest symbol table from a given operation `from`. Returns
+/// nullptr if no valid parent symbol table could be found.
+Operation *SymbolTable::getNearestSymbolTable(Operation *from) {
+  assert(from && "expected valid operation");
+  if (isPotentiallyUnknownSymbolTable(from))
+    return nullptr;
+
+  while (!from->hasTrait<OpTrait::SymbolTable>()) {
+    from = from->getParentOp();
+
+    // Check that this is a valid op and isn't an unknown symbol table.
+    if (!from || isPotentiallyUnknownSymbolTable(from))
+      return nullptr;
+  }
+  return from;
+}
+
 /// Returns the operation registered with the given symbol name with the
 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
 /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
@@ -466,7 +466,7 @@
   if (limitAncestor == symbol) {
     // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
     // doesn't support parent references.
-    if (getNearestSymbolTable(limit) != symbol->getParentOp())
+    if (SymbolTable::getNearestSymbolTable(limit) != symbol->getParentOp())
       return WalkResult::advance();
     return callback(SymbolRefAttr::get(symbolName, symbol->getContext()),
                     limit);