diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -21,6 +21,8 @@
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringSet.h"
 
+#include <unordered_map>
+
 namespace llvm {
 class DagInit;
 class Init;
@@ -228,6 +230,9 @@
     // value bound by this symbol.
     std::string getVarDecl(StringRef name) const;
 
+    // Returns a variable name for the symbol named as `name`.
+    std::string getVarName(StringRef name) const;
+
   private:
     // Allow SymbolInfoMap to access private methods.
     friend class SymbolInfoMap;
@@ -285,9 +290,12 @@
     Kind kind; // The kind of the bound entity
     // The argument index (for `Attr` and `Operand` only)
     Optional<int> argIndex;
+    // Alternative name for the symbol. It is used in case the name
+    // is not unique. Applicable for `Operand` only.
+    Optional<std::string> alternativeName;
   };
 
-  using BaseT = llvm::StringMap<SymbolInfo>;
+  using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
 
   // Iterators for accessing all symbols.
   using iterator = BaseT::iterator;
@@ -300,7 +308,7 @@
   const_iterator end() const { return symbolInfoMap.end(); }
 
   // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
-  // Returns false if `symbol` is already bound.
+  // Returns false if `symbol` is already bound and symbols are not operands.
   bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
 
   // Binds the given `symbol` to the results the given `op`. Returns false if
@@ -317,6 +325,18 @@
   // Returns an iterator to the information of the given symbol named as `key`.
   const_iterator find(StringRef key) const;
 
+  // Returns an iterator to the information of the given symbol named as `key`,
+  // with index `argIndex` for operator `op`.
+  const_iterator findBoundSymbol(StringRef key, const Operator &op,
+                                 int argIndex) const;
+
+  // Returns the bounds of a range that includes all the elements which
+  // bind to the `key`.
+  std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
+
+  // Returns number of times symbol named as `key` was used.
+  int count(StringRef key) const;
+
   // Returns the number of static values of the given `symbol` corresponds to.
   // A static value is an operand/result declared in ODS. Normally a symbol only
   // represents one static value, but symbols bound to op results can represent
@@ -338,6 +358,9 @@
   std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
                              const char *separator = ", ") const;
 
+  // Assign alternative unique names to Operands that have equal names.
+  void assignUniqueAlternativeNames();
+
   // Splits the given `symbol` into a value pack name and an index. Returns the
   // value pack name and writes the index to `index` on success. Returns
   // `symbol` itself if it does not contain an index.
@@ -347,7 +370,7 @@
   static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
 
 private:
-  llvm::StringMap<SymbolInfo> symbolInfoMap;
+  BaseT symbolInfoMap;
 
   // Pattern instantiation location. This is intended to be used as parameter
   // to PrintFatalError() to report errors.
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -208,6 +208,10 @@
   llvm_unreachable("unknown kind");
 }
 
+std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
+  return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
+}
+
 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
   switch (kind) {
@@ -219,8 +223,9 @@
   case Kind::Operand: {
     // Use operand range for captured operands (to support potential variadic
     // operands).
-    return std::string(formatv(
-        "::mlir::Operation::operand_range {0}(op0->getOperands());\n", name));
+    return std::string(
+        formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
+                getVarName(name)));
   }
   case Kind::Value: {
     return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name));
@@ -359,16 +364,41 @@
                      ? SymbolInfo::getAttr(&op, argIndex)
                      : SymbolInfo::getOperand(&op, argIndex);
 
-  return symbolInfoMap.insert({symbol, symInfo}).second;
+  std::string key = symbol.str();
+  if (symbolInfoMap.count(key)) {
+    // Only non unique name for the operand is supported.
+    if (symInfo.kind != SymbolInfo::Kind::Operand) {
+      return false;
+    }
+
+    // Cannot add new operand if there is already non operand with the same
+    // name.
+    if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
+      return false;
+    }
+  }
+
+  symbolInfoMap.emplace(key, symInfo);
+  return true;
 }
 
 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
-  StringRef name = getValuePackName(symbol);
-  return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
+  std::string name = getValuePackName(symbol).str();
+  // Inserting into a multimap always succeeds, so reject an already bound
+  // name up front instead of leaving a stale duplicate entry behind.
+  if (symbolInfoMap.count(name))
+    return false;
+  symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
+  return true;
 }
 
 bool SymbolInfoMap::bindValue(StringRef symbol) {
-  return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
+  std::string key = symbol.str();
+  // Values must have unique names; do not insert when the name is taken.
+  if (symbolInfoMap.count(key))
+    return false;
+  symbolInfoMap.emplace(key, SymbolInfo::getValue());
+  return true;
 }
 
 bool SymbolInfoMap::contains(StringRef symbol) const {
@@ -376,10 +406,38 @@
 }
 
 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
-  StringRef name = getValuePackName(key);
+  std::string name = getValuePackName(key).str();
+
   return symbolInfoMap.find(name);
 }
 
+SymbolInfoMap::const_iterator
+SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
+                               int argIndex) const {
+  std::string name = getValuePackName(key).str();
+  auto range = symbolInfoMap.equal_range(name);
+
+  for (auto it = range.first; it != range.second; ++it) {
+    if (it->second.op == &op && it->second.argIndex == argIndex) {
+      return it;
+    }
+  }
+
+  return symbolInfoMap.end();
+}
+
+std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
+SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
+  std::string name = getValuePackName(key).str();
+
+  return symbolInfoMap.equal_range(name);
+}
+
+int SymbolInfoMap::count(StringRef key) const {
+  std::string name = getValuePackName(key).str();
+  return symbolInfoMap.count(name);
+}
+
 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
   StringRef name = getValuePackName(symbol);
   if (name != symbol) {
@@ -388,7 +446,7 @@
     return 1;
   }
   // Otherwise, find how many it represents by querying the symbol's info.
-  return find(name)->getValue().getStaticValueCount();
+  return find(name)->second.getStaticValueCount();
 }
 
 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
@@ -397,13 +455,13 @@
   int index = -1;
   StringRef name = getValuePackName(symbol, &index);
 
-  auto it = symbolInfoMap.find(name);
+  auto it = symbolInfoMap.find(name.str());
   if (it == symbolInfoMap.end()) {
     auto error = formatv("referencing unbound symbol '{0}'", symbol);
     PrintFatalError(loc, error);
   }
 
-  return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
+  return it->second.getValueAndRangeUse(name, index, fmt, separator);
 }
 
 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
@@ -411,13 +469,44 @@
   int index = -1;
   StringRef name = getValuePackName(symbol, &index);
 
-  auto it = symbolInfoMap.find(name);
+  auto it = symbolInfoMap.find(name.str());
   if (it == symbolInfoMap.end()) {
     auto error = formatv("referencing unbound symbol '{0}'", symbol);
     PrintFatalError(loc, error);
   }
 
-  return it->getValue().getAllRangeUse(name, index, fmt, separator);
+  return it->second.getAllRangeUse(name, index, fmt, separator);
+}
+
+void SymbolInfoMap::assignUniqueAlternativeNames() {
+  llvm::StringSet<> usedNames;
+
+  for (auto symbolInfoIt = symbolInfoMap.begin();
+       symbolInfoIt != symbolInfoMap.end();) {
+    auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
+    auto startRange = range.first;
+    auto endRange = range.second;
+
+    auto operandName = symbolInfoIt->first;
+    int startSearchIndex = 0;
+    for (++startRange; startRange != endRange; ++startRange) {
+      // Current operand name is not unique, find a unique one
+      // and set the alternative name.
+      for (int i = startSearchIndex;; ++i) {
+        std::string alternativeName = operandName + std::to_string(i);
+        if (!usedNames.contains(alternativeName) &&
+            symbolInfoMap.count(alternativeName) == 0) {
+          usedNames.insert(alternativeName);
+          startRange->second.alternativeName = alternativeName;
+          startSearchIndex = i + 1;
+
+          break;
+        }
+      }
+    }
+
+    symbolInfoIt = endRange;
+  }
 }
 
 //===----------------------------------------------------------------------===//
@@ -445,6 +534,10 @@
   LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
   LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
+
+  LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
+  infoMap.assignUniqueAlternativeNames();
+  LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
 }
 
 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -619,6 +619,32 @@
   let results = (outs I32);
 }
 
+def OpN : TEST_Op<"op_n"> {
+  let arguments = (ins I32, I32);
+  let results = (outs I32);
+}
+
+def OpO : TEST_Op<"op_o"> {
+  let arguments = (ins I32);
+  let results = (outs I32);
+}
+
+def OpP : TEST_Op<"op_p"> {
+  let arguments = (ins I32, I32, I32, I32, I32, I32);
+  let results = (outs I32);
+}
+
+// Test same operand name enforces equality condition check.
+def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
+
+// Test when equality is enforced at different depth.
+def TestNestedOpEqualArgsPattern :
+  Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
+
+// Test multiple equal arguments check enforced.
+def TestMultipleEqualArgsPattern :
+  Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
+
 // Test for memrefs normalization of an op with normalizable memrefs.
 def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
   let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -111,6 +111,64 @@
   return
 }
 
+// CHECK-LABEL: verifyEqualArgs
+func @verifyEqualArgs(%arg0: i32, %arg1: i32) {
+  // def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
+
+  // CHECK: "test.op_o"(%arg0) : (i32) -> i32
+  "test.op_n"(%arg0, %arg0) : (i32, i32) -> (i32)
+
+  // CHECK: "test.op_n"(%arg0, %arg1) : (i32, i32) -> i32
+  "test.op_n"(%arg0, %arg1) : (i32, i32) -> (i32)
+
+  return
+}
+
+// CHECK-LABEL: verifyNestedOpEqualArgs
+func @verifyNestedOpEqualArgs(
+  %arg0: i32, %arg1: i32, %arg2 : i32, %arg3 : i32, %arg4 : i32, %arg5 : i32) {
+  // def TestNestedOpEqualArgsPattern :
+  // Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
+
+  // CHECK: %arg1
+  %0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
+    : (i32, i32, i32, i32, i32, i32) -> (i32)
+  %1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32)
+
+  // CHECK: test.op_p
+  // CHECK: test.op_n
+  %2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
+    : (i32, i32, i32, i32, i32, i32) -> (i32)
+  %3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32)
+
+  return
+}
+
+// CHECK-LABEL: verifyMultipleEqualArgs
+func @verifyMultipleEqualArgs(
+  %arg0: i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) {
+  // def TestMultipleEqualArgsPattern :
+  // Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
+
+  // CHECK: "test.op_n"(%arg2, %arg1) : (i32, i32) -> i32
+  "test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg1, %arg2) :
+    (i32, i32, i32, i32 , i32, i32) -> i32
+
+  // CHECK: test.op_p
+  "test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg0, %arg2) :
+    (i32, i32, i32, i32 , i32, i32) -> i32
+
+  // CHECK: test.op_p
+  "test.op_p"(%arg0, %arg1, %arg1, %arg0, %arg1, %arg2) :
+    (i32, i32, i32, i32 , i32, i32) -> i32
+
+  // CHECK: test.op_p
+  "test.op_p"(%arg0, %arg1, %arg2, %arg2, %arg3, %arg4) :
+    (i32, i32, i32, i32 , i32, i32) -> i32
+
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Test Symbol Binding
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -89,6 +89,11 @@
   void emitMatchCheck(int depth, const FmtObjectBase &matchFmt,
                       const llvm::formatv_object_base &failureFmt);
 
+  // Emits C++ for checking a match with a corresponding match failure
+  // diagnostics.
+  void emitMatchCheck(int depth, const std::string &matchStr,
+                      const std::string &failureStr);
+
   //===--------------------------------------------------------------------===//
   // Rewrite utilities
   //===--------------------------------------------------------------------===//
@@ -327,8 +332,9 @@
         op.arg_begin(), op.arg_begin() + argIndex,
         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
 
-    os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", name, depth,
-                  argIndex - numPrevAttrs);
+    auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
+    os << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
+                  res->second.getVarName(name), depth, argIndex - numPrevAttrs);
   }
 }
 
@@ -393,10 +399,15 @@
 void PatternEmitter::emitMatchCheck(
     int depth, const FmtObjectBase &matchFmt,
     const llvm::formatv_object_base &failureFmt) {
-  os << "if (!(" << matchFmt.str() << "))";
+  emitMatchCheck(depth, matchFmt.str(), failureFmt.str());
+}
+
+void PatternEmitter::emitMatchCheck(int depth, const std::string &matchStr,
+                                    const std::string &failureStr) {
+  os << "if (!(" << matchStr << "))";
   os.scope("{\n", "\n}\n").os
       << "return rewriter.notifyMatchFailure(op" << depth
-      << ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureFmt.str()
+      << ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureStr
       << ";\n});";
 }
 
@@ -445,6 +456,30 @@
                          constraint.getDescription()));
     }
   }
+
+  // Some of the operands could be bound to the same symbol name, we need
+  // to enforce equality constraint on those.
+  // TODO: we should be able to emit equality checks early
+  // and short circuit unnecessary work if vars are not equal.
+  for (auto symbolInfoIt = symbolInfoMap.begin();
+       symbolInfoIt != symbolInfoMap.end();) {
+    auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
+    auto startRange = range.first;
+    auto endRange = range.second;
+
+    auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
+    for (++startRange; startRange != endRange; ++startRange) {
+      auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
+      emitMatchCheck(
+          depth,
+          formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
+          formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
+                  secondOperand));
+    }
+
+    symbolInfoIt = endRange;
+  }
+
   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
 }
 
@@ -518,8 +553,9 @@
     // Create local variables for storing the arguments and results bound
     // to symbols.
     for (const auto &symbolInfoPair : symbolInfoMap) {
-      StringRef symbol = symbolInfoPair.getKey();
-      auto &info = symbolInfoPair.getValue();
+      const auto &symbol = symbolInfoPair.first;
+      const auto &info = symbolInfoPair.second;
+
       os << info.getVarDecl(symbol);
     }
     // TODO: capture ops with consistent numbering so that it can be
@@ -1093,7 +1129,7 @@
         os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
                       range);
       } else {
-        os << formatv("tblgen_values.push_back(", varName);
+        os << formatv("tblgen_values.push_back(");
         if (node.isNestedDagArg(argIndex)) {
           os << symbolInfoMap.getValueAndRangeUse(
               childNodeNames.lookup(argIndex));