[mlir][tblgen] Distinguish operands bound to the same symbol at different DAG depths

SymbolInfo now records an opaque DagNode pointer together with the argument index, so findBoundSymbol can tell apart same-named operands of the same op that appear in different nested DAG nodes, e.g. `(OpN (OpN $_, $x), $x)`.

diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -184,6 +184,9 @@ void print(raw_ostream &os) const; private: + friend class SymbolInfoMap; + const void *getAsOpaquePointer() const { return node; } + const llvm::DagInit *node; // nullptr means null DagNode }; @@ -237,6 +240,10 @@ // Allow SymbolInfoMap to access private methods. friend class SymbolInfoMap; + // DagNode and DagLeaf are accessed by value, which means they can't be + // used as identifiers here. Use an opaque pointer type instead. + using DagAndIndex = std::pair<const void *, int>; + // What kind of entity this symbol represents: // * Attr: op attribute // * Operand: op operand @@ -244,19 +251,21 @@ // * Value: a value not attached to an op (e.g., from NativeCodeCall) enum class Kind : uint8_t { Attr, Operand, Result, Value }; - // Creates a SymbolInfo instance. `index` is only used for `Attr` and - // `Operand` so should be negative for `Result` and `Value` kind. - SymbolInfo(const Operator *op, Kind kind, Optional<int> index); + // Creates a SymbolInfo instance. `dagAndIndex` is only used for `Attr` and + // `Operand` so should be llvm::None for `Result` and `Value` kind. + SymbolInfo(const Operator *op, Kind kind, + Optional<DagAndIndex> dagAndIndex); // Static methods for creating SymbolInfo. 
static SymbolInfo getAttr(const Operator *op, int index) { - return SymbolInfo(op, Kind::Attr, index); + return SymbolInfo(op, Kind::Attr, DagAndIndex(nullptr, index)); } static SymbolInfo getAttr() { return SymbolInfo(nullptr, Kind::Attr, llvm::None); } - static SymbolInfo getOperand(const Operator *op, int index) { - return SymbolInfo(op, Kind::Operand, index); + static SymbolInfo getOperand(DagNode node, const Operator *op, int index) { + return SymbolInfo(op, Kind::Operand, + DagAndIndex(node.getAsOpaquePointer(), index)); } static SymbolInfo getResult(const Operator *op) { return SymbolInfo(op, Kind::Result, llvm::None); @@ -291,8 +300,11 @@ const Operator *op; // The op where the bound entity belongs Kind kind; // The kind of the bound entity - // The argument index (for `Attr` and `Operand` only) - Optional<int> argIndex; + // The pair of DagNode pointer and argument index (for `Attr` and `Operand` + // only). Note that operands may be bound to the same symbol; use the + // DagNode and index to distinguish them. For `Attr`, the Dag part will be + // nullptr. + Optional<DagAndIndex> dagAndIndex; // Alternative name for the symbol. It is used in case the name // is not unique. Applicable for `Operand` only. Optional<std::string> alternativeName; @@ -312,7 +324,8 @@ // Binds the given `symbol` to the `argIndex`-th argument to the given `op`. // Returns false if `symbol` is already bound and symbols are not operands. - bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex); + bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, + int argIndex); // Binds the given `symbol` to the results the given `op`. Returns false if // `symbol` is already bound. @@ -334,8 +347,8 @@ // Returns an iterator to the information of the given symbol named as `key`, // with index `argIndex` for operator `op`. 
- const_iterator findBoundSymbol(StringRef key, const Operator &op, - int argIndex) const; + const_iterator findBoundSymbol(StringRef key, DagNode node, + const Operator &op, int argIndex) const; // Returns the bounds of a range that includes all the elements which // bind to the `key`. diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -193,8 +193,8 @@ } SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind, - Optional<int> index) - : op(op), kind(kind), argIndex(index) {} + Optional<DagAndIndex> dagAndIndex) + : op(op), kind(kind), dagAndIndex(dagAndIndex) {} int SymbolInfoMap::SymbolInfo::getStaticValueCount() const { switch (kind) { @@ -217,8 +217,9 @@ switch (kind) { case Kind::Attr: { if (op) { - auto type = - op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType(); + auto type = op->getArg((*dagAndIndex).second) + .get<NamedAttribute *>() + ->attr.getStorageType(); return std::string(formatv("{0} {1};\n", type, name)); } // TODO(suderman): Use a more exact type when available. @@ -254,7 +255,8 @@ } case Kind::Operand: { assert(index < 0); - auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>(); + auto *operand = + op->getArg((*dagAndIndex).second).get<NamedTypeConstraint *>(); // If this operand is variadic, then return a range. Otherwise, return the // value itself. if (operand->isVariableLength()) { @@ -355,8 +357,8 @@ llvm_unreachable("unknown kind"); } -bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op, - int argIndex) { +bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, + const Operator &op, int argIndex) { StringRef name = getValuePackName(symbol); if (name != symbol) { auto error = formatv( @@ -366,7 +368,7 @@ auto symInfo = op.getArg(argIndex).is<NamedAttribute *>() ? 
SymbolInfo::getAttr(&op, argIndex) - : SymbolInfo::getOperand(&op, argIndex); + : SymbolInfo::getOperand(node, &op, argIndex); std::string key = symbol.str(); if (symbolInfoMap.count(key)) { @@ -414,13 +416,15 @@ } SymbolInfoMap::const_iterator -SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op, +SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex) const { std::string name = getValuePackName(key).str(); auto range = symbolInfoMap.equal_range(name); + const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex); + for (auto it = range.first; it != range.second; ++it) { - if (it->second.op == &op && it->second.argIndex == argIndex) { + if (it->second.dagAndIndex == symbolInfo.dagAndIndex) { return it; } } @@ -722,7 +726,8 @@ if (!treeArgName.empty() && treeArgName != "_") { LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " << treeArgName << '\n'); - verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName); + verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i), + treeArgName); } } } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -716,6 +716,13 @@ def TestNestedOpEqualArgsPattern : Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>; +// Test equality enforced on the same op and the same operand name at different +// depths. Previously only one $x was bound to the second operand of outer OpN +// and the other was left as the default value (the first operand of outer +// OpN). As a result, the wrong values were compared in some cases. +def TestNestedSameOpAndSameArgEqualityPattern : + Pat<(OpN (OpN $_, $x), $x), (replaceWithValue $x)>; + // Test multiple equal arguments check enforced. 
def TestMultipleEqualArgsPattern : Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>; diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -158,6 +158,17 @@ return } +// CHECK-LABEL: verifyNestedSameOpAndSameArgEquality +func @verifyNestedSameOpAndSameArgEquality(%arg0: i32, %arg1: i32) -> i32 { + // def TestNestedSameOpAndSameArgEqualityPattern: + // Pat<(OpN (OpN $_, $x), $x), (replaceWithValue $x)>; + + %0 = "test.op_n"(%arg1, %arg0) : (i32, i32) -> (i32) + %1 = "test.op_n"(%0, %arg0) : (i32, i32) -> (i32) + // CHECK: return %arg0 : i32 + return %1 : i32 +} + // CHECK-LABEL: verifyMultipleEqualArgs func @verifyMultipleEqualArgs( %arg0: i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) { diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -454,7 +454,7 @@ op.arg_begin(), op.arg_begin() + argIndex, [](const Argument &arg) { return arg.is<NamedAttribute *>(); }); - auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex); + auto res = symbolInfoMap.findBoundSymbol(name, tree, op, argIndex); os << formatv("{0} = {1}.getODSOperands({2});\n", res->second.getVarName(name), opName, argIndex - numPrevAttrs);