diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -15,6 +15,7 @@ include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td" include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -733,7 +734,9 @@ // CallOp //===----------------------------------------------------------------------===// -def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable]> { +def CallOp : Std_Op<"call", + [CallOpInterface, MemRefsNormalizable, + DeclareOpInterfaceMethods<SymbolUserOpInterface>]> { let summary = "call operation"; let description = [{ The `call` operation represents a direct call to a function that is within @@ -788,6 +791,7 @@ let assemblyFormat = [{ $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) }]; + let verifier = ?; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -158,6 +158,27 @@ }]; } +//===----------------------------------------------------------------------===// +// SymbolUserOpInterface +//===----------------------------------------------------------------------===// + +def SymbolUserOpInterface : OpInterface<"SymbolUserOpInterface"> { + let description = [{ + This interface describes an operation that may use a `Symbol`. This + interface allows for users of symbols to hook into verification and other + symbol related utilities that are either costly or otherwise disallowed + within a traditional operation.
+ }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<"Verify the symbol uses held by this operation.", + "LogicalResult", "verifySymbolUses", + (ins "::mlir::SymbolTableCollection &":$symbolTable) + >, + ]; +} + //===----------------------------------------------------------------------===// // Symbol Traits //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -236,6 +236,21 @@ LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name, SmallVectorImpl<Operation *> &symbols); + /// Returns the operation registered with the given symbol name within the + /// closest parent operation of, or including, 'from' with the + /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was + /// found. + Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol); + Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol); + template <typename T> + T lookupNearestSymbolFrom(Operation *from, StringRef symbol) { + return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); + } + template <typename T> + T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) { + return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); + } + /// Lookup, or create, a symbol table for an operation. SymbolTable &getSymbolTable(Operation *op); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -740,34 +740,33 @@ // CallOp //===----------------------------------------------------------------------===// -static LogicalResult verify(CallOp op) { +LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // Check that the callee attribute was specified.
- auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee"); + auto fnAttr = getAttrOfType<FlatSymbolRefAttr>("callee"); if (!fnAttr) - return op.emitOpError("requires a 'callee' symbol reference attribute"); - auto fn = - op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue()); + return emitOpError("requires a 'callee' symbol reference attribute"); + FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr); if (!fn) - return op.emitOpError() << "'" << fnAttr.getValue() - << "' does not reference a valid function"; + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; // Verify that the operand and result types match the callee. auto fnType = fn.getType(); - if (fnType.getNumInputs() != op.getNumOperands()) - return op.emitOpError("incorrect number of operands for callee"); + if (fnType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for callee"); for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) - if (op.getOperand(i).getType() != fnType.getInput(i)) - return op.emitOpError("operand type mismatch: expected operand type ") + if (getOperand(i).getType() != fnType.getInput(i)) + return emitOpError("operand type mismatch: expected operand type ") << fnType.getInput(i) << ", but provided " - << op.getOperand(i).getType() << " for operand number " << i; + << getOperand(i).getType() << " for operand number " << i; - if (fnType.getNumResults() != op.getNumResults()) - return op.emitOpError("incorrect number of results for callee"); + if (fnType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for callee"); for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) - if (op.getResult(i).getType() != fnType.getResult(i)) - return op.emitOpError("result type mismatch"); + if (getResult(i).getType() != fnType.getResult(i)) + return emitOpError("result type mismatch"); return success(); } diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- 
a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -68,6 +68,30 @@ return success(); } +/// Walk all of the operations within the given set of regions, without +/// traversing into any nested symbol tables. Stops walking if the result of the +/// callback is anything other than `WalkResult::advance`. +static Optional<WalkResult> +walkSymbolTable(MutableArrayRef<Region> regions, + function_ref<Optional<WalkResult>(Operation *)> callback) { + SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions)); + while (!worklist.empty()) { + for (Operation &op : worklist.pop_back_val()->getOps()) { + Optional<WalkResult> result = callback(&op); + if (result != WalkResult::advance()) + return result; + + // If this op defines a new symbol table scope, we can't traverse. Any + // symbol references nested within 'op' are different semantically. + if (!op.hasTrait<OpTrait::SymbolTable>()) { + for (Region &region : op.getRegions()) + worklist.push_back(&region); + } + } + } + return WalkResult::advance(); +} + //===----------------------------------------------------------------------===// // SymbolTable //===----------------------------------------------------------------------===// @@ -347,7 +371,18 @@ .append("see existing symbol definition here"); } } - return success(); + + // Verify any nested symbol user operations.
+ SymbolTableCollection symbolTable; + auto verifySymbolUserFn = [&](Operation *op) -> Optional<WalkResult> { + if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op)) + return WalkResult(user.verifySymbolUses(symbolTable)); + return WalkResult::advance(); + }; + + Optional<WalkResult> result = + walkSymbolTable(op->getRegions(), verifySymbolUserFn); + return success(result && !result->wasInterrupted()); } LogicalResult detail::verifySymbol(Operation *op) { @@ -452,25 +487,13 @@ static Optional<WalkResult> walkSymbolUses( MutableArrayRef<Region> regions, function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) { - SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions)); - while (!worklist.empty()) { - for (Operation &op : worklist.pop_back_val()->getOps()) { - if (walkSymbolRefs(&op, callback).wasInterrupted()) - return WalkResult::interrupt(); - - // Check that this isn't a potentially unknown symbol table. - if (isPotentiallyUnknownSymbolTable(&op)) - return llvm::None; + return walkSymbolTable(regions, [&](Operation *op) -> Optional<WalkResult> { + // Check that this isn't a potentially unknown symbol table. + if (isPotentiallyUnknownSymbolTable(op)) + return llvm::None; - // If this op defines a new symbol table scope, we can't traverse. Any - // symbol references nested within 'op' are different semantically. - if (!op.hasTrait<OpTrait::SymbolTable>()) { - for (Region &region : op.getRegions()) - worklist.push_back(&region); - } - } - } - return WalkResult::advance(); + return walkSymbolRefs(op, callback); + }); } /// Walk all of the uses, for any symbol, that are nested within the given /// operation 'from', invoking the provided callback for each. This does not @@ -927,6 +950,22 @@ return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn); } +/// Returns the operation registered with the given symbol name within the +/// closest parent operation of, or including, 'from' with the +/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was +/// found.
+Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from, + StringRef symbol) { + Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from); + return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; +} +Operation * +SymbolTableCollection::lookupNearestSymbolFrom(Operation *from, + SymbolRefAttr symbol) { + Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from); + return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; +} + /// Lookup, or create, a symbol table for an operation. SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) { auto it = symbolTables.try_emplace(op, nullptr);