diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2351,8 +2351,6 @@ string expression = expr; } -def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($0->getResult(0), m_Constant(&$1)))">; - //===----------------------------------------------------------------------===// // Rewrite directives //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -252,9 +252,6 @@ static SymbolInfo getAttr(const Operator *op, int index) { return SymbolInfo(op, Kind::Attr, index); } - static SymbolInfo getAttr() { - return SymbolInfo(nullptr, Kind::Attr, llvm::None); - } static SymbolInfo getOperand(const Operator *op, int index) { return SymbolInfo(op, Kind::Operand, index); } @@ -322,10 +319,6 @@ // is already bound. bool bindValue(StringRef symbol); - // Registers the given `symbol` as bound to an attr. Returns false if `symbol` - // is already bound. - bool bindAttr(StringRef symbol); - // Returns true if the given `symbol` is bound. bool contains(StringRef symbol) const; @@ -428,9 +421,6 @@ std::vector getLocation() const; private: - // Helper function to verify variabld binding. - void verifyBind(bool result, StringRef symbolName); - // Recursively collects all bound symbols inside the DAG tree rooted // at `tree` and updates the given `infoMap`. void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -216,13 +216,9 @@ LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); switch (kind) { case Kind::Attr: { - if (op) { - auto type = - op->getArg(*argIndex).get()->attr.getStorageType(); - return std::string(formatv("{0} {1};\n", type, name)); - } - // TODO(suderman): Use a more exact type when available. - return std::string(formatv("Attribute {0};\n", name)); + auto type = + op->getArg(*argIndex).get()->attr.getStorageType(); + return std::string(formatv("{0} {1};\n", type, name)); } case Kind::Operand: { // Use operand range for captured operands (to support potential variadic @@ -398,11 +394,6 @@ return symbolInfoMap.count(inserted->first) == 1; } -bool SymbolInfoMap::bindAttr(StringRef symbol) { - auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getAttr()); - return symbolInfoMap.count(inserted->first) == 1; -} - bool SymbolInfoMap::contains(StringRef symbol) const { return find(symbol) != symbolInfoMap.end(); } @@ -567,15 +558,15 @@ for (auto it : *listInit) { auto *dagInit = dyn_cast(it); if (!dagInit) - PrintFatalError(&def, "all elements in Pattern multi-entity " - "constraints should be DAG nodes"); + PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity " + "constraints should be DAG nodes"); std::vector entities; entities.reserve(dagInit->arg_size()); for (auto *argName : dagInit->getArgNames()) { if (!argName) { PrintFatalError( - &def, + def.getLoc(), "operands to additional constraints can only be symbol references"); } entities.push_back(std::string(argName->getValue())); @@ -593,7 +584,7 @@ int initBenefit = getSourcePattern().getNumOps(); llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); if (delta->getNumArgs() != 1 || !isa(delta->getArg(0))) { - PrintFatalError(&def, + PrintFatalError(def.getLoc(), "The 'addBenefit' takes and only takes one integer value"); } return initBenefit + dyn_cast(delta->getArg(0))->getValue(); @@ -612,120 +603,64 @@ return result; } -void Pattern::verifyBind(bool result, StringRef symbolName) { - if (!result) { - auto err = formatv("symbol '{0}' bound more than once", symbolName); - PrintFatalError(&def, err); - } -} - void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern) { auto treeName = tree.getSymbol(); - auto numTreeArgs = tree.getNumArgs(); - - if (tree.isNativeCodeCall()) { + if (!tree.isOperation()) { if (!treeName.empty()) { PrintFatalError( - &def, - formatv( - "binding symbol '{0}' to native code call unsupported right now", - treeName)); + def.getLoc(), + formatv("binding symbol '{0}' to non-operation unsupported right now", + treeName)); } + return; + } - for (int i = 0; i != numTreeArgs; ++i) { - if (auto treeArg = tree.getArgAsNestedDag(i)) { - // This DAG node argument is a DAG node itself. Go inside recursively. - collectBoundSymbols(treeArg, infoMap, isSrcPattern); - continue; - } + auto &op = getDialectOp(tree); + auto numOpArgs = op.getNumArgs(); + auto numTreeArgs = tree.getNumArgs(); + + // The pattern might have the last argument specifying the location. + bool hasLocDirective = false; + if (numTreeArgs != 0) { + if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) + hasLocDirective = lastArg.isLocationDirective(); + } - if (!isSrcPattern) - continue; + if (numOpArgs != numTreeArgs - hasLocDirective) { + auto err = formatv("op '{0}' argument number mismatch: " + "{1} in pattern vs. {2} in definition", + op.getOperationName(), numTreeArgs, numOpArgs); + PrintFatalError(def.getLoc(), err); + } - // We can only bind symbols to arguments in source pattern. Those + // The name attached to the DAG node's operator is for representing the + // results generated from this op. It should be remembered as bound results. + if (!treeName.empty()) { + LLVM_DEBUG(llvm::dbgs() + << "found symbol bound to op result: " << treeName << '\n'); + if (!infoMap.bindOpResult(treeName, op)) + PrintFatalError(def.getLoc(), + formatv("symbol '{0}' bound more than once", treeName)); + } + + for (int i = 0; i != numTreeArgs; ++i) { + if (auto treeArg = tree.getArgAsNestedDag(i)) { + // This DAG node argument is a DAG node itself. Go inside recursively. + collectBoundSymbols(treeArg, infoMap, isSrcPattern); + } else if (isSrcPattern) { + // We can only bind symbols to op arguments in source pattern. Those // symbols are referenced in result patterns. auto treeArgName = tree.getArgName(i); - // `$_` is a special symbol meaning ignore the current argument. if (!treeArgName.empty() && treeArgName != "_") { - if (tree.isNestedDagArg(i)) { - auto err = formatv("cannot bind '{0}' for nested native call arg", - treeArgName); - PrintFatalError(&def, err); + LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " + << treeArgName << '\n'); + if (!infoMap.bindOpArgument(treeArgName, op, i)) { + auto err = formatv("symbol '{0}' bound more than once", treeArgName); + PrintFatalError(def.getLoc(), err); } - - DagLeaf leaf = tree.getArgAsLeaf(i); - auto constraint = leaf.getAsConstraint(); - bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || - leaf.isConstantAttr() || - constraint.getKind() == Constraint::Kind::CK_Attr; - - if (isAttr) { - verifyBind(infoMap.bindAttr(treeArgName), treeArgName); - continue; - } - - verifyBind(infoMap.bindValue(treeArgName), treeArgName); } } - - return; - } - - if (tree.isOperation()) { - auto &op = getDialectOp(tree); - auto numOpArgs = op.getNumArgs(); - - // The pattern might have the last argument specifying the location. - bool hasLocDirective = false; - if (numTreeArgs != 0) { - if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) - hasLocDirective = lastArg.isLocationDirective(); - } - - if (numOpArgs != numTreeArgs - hasLocDirective) { - auto err = formatv("op '{0}' argument number mismatch: " - "{1} in pattern vs. {2} in definition", - op.getOperationName(), numTreeArgs, numOpArgs); - PrintFatalError(&def, err); - } - - // The name attached to the DAG node's operator is for representing the - // results generated from this op. It should be remembered as bound results. - if (!treeName.empty()) { - LLVM_DEBUG(llvm::dbgs() - << "found symbol bound to op result: " << treeName << '\n'); - verifyBind(infoMap.bindOpResult(treeName, op), treeName); - } - - for (int i = 0; i != numTreeArgs; ++i) { - if (auto treeArg = tree.getArgAsNestedDag(i)) { - // This DAG node argument is a DAG node itself. Go inside recursively. - collectBoundSymbols(treeArg, infoMap, isSrcPattern); - continue; - } - - if (isSrcPattern) { - // We can only bind symbols to op arguments in source pattern. Those - // symbols are referenced in result patterns. - auto treeArgName = tree.getArgName(i); - // `$_` is a special symbol meaning ignore the current argument. - if (!treeArgName.empty() && treeArgName != "_") { - LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " - << treeArgName << '\n'); - verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName); - } - } - } - return; - } - - if (!treeName.empty()) { - PrintFatalError( - &def, formatv("binding symbol '{0}' to non-operation/native code call " - "unsupported right now", - treeName)); } - return; } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -615,10 +615,6 @@ return operand(); } -OpFoldResult TestOpConstant::fold(ArrayRef operands) { - return getValue(); -} - LogicalResult TestOpWithVariadicResultsAndFolder::fold( ArrayRef operands, SmallVectorImpl &results) { for (Value input : this->operands()) { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -799,22 +799,6 @@ let hasCanonicalizer = 1; } -def TestOpConstant : TEST_Op<"constant", [ConstantLike, NoSideEffect]> { - let arguments = (ins AnyAttr:$value); - let results = (outs AnyType); - let extraClassDeclaration = [{ - Attribute getValue() { return getAttr("value"); } - }]; - - let hasFolder = 1; -} - -def OpR : TEST_Op<"op_r">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>; -def OpS : TEST_Op<"op_s">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>; - -def : Pat<(OpR $input1, (ConstantLikeMatcher I32Attr:$input2)), - (OpS:$unused $input1, $input2)>; - // Op for testing trivial removal via folding of op with inner ops and no uses. def TestOpWithRegionFoldNoSideEffect : TEST_Op< "op_with_region_fold_no_side_effect", [NoSideEffect]> { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -9,7 +9,6 @@ #include "TestDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" -#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -248,58 +248,6 @@ return %0, %1 : i32, i32 } -//===----------------------------------------------------------------------===// -// Test Constant Matching -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: testConstOp -func @testConstOp() -> (i32) { - // CHECK-NEXT: [[C0:%.+]] = constant 1 - %0 = "test.constant"() {value = 1 : i32} : () -> i32 - - // CHECK-NEXT: return [[C0]] - return %0 : i32 -} - -// CHECK-LABEL: testConstOpUsed -func @testConstOpUsed() -> (i32) { - // CHECK-NEXT: [[C0:%.+]] = constant 1 - %0 = "test.constant"() {value = 1 : i32} : () -> i32 - - // CHECK-NEXT: [[V0:%.+]] = "test.op_s"([[C0]]) - %1 = "test.op_s"(%0) {value = 1 : i32} : (i32) -> i32 - - // CHECK-NEXT: return [[V0]] - return %1 : i32 -} - -// CHECK-LABEL: testConstOpReplaced -func @testConstOpReplaced() -> (i32) { - // CHECK-NEXT: [[C0:%.+]] = constant 1 - %0 = "test.constant"() {value = 1 : i32} : () -> i32 - %1 = "test.constant"() {value = 2 : i32} : () -> i32 - - // CHECK: [[V0:%.+]] = "test.op_s"([[C0]]) {value = 2 : i32} - %2 = "test.op_r"(%0, %1) : (i32, i32) -> i32 - - // CHECK: [[V0]] - return %2 : i32 -} -// CHECK-LABEL: testConstOpMatchFailure -func @testConstOpMatchFailure() -> (i64) { - // CHECK-DAG: [[C0:%.+]] = constant 1 - %0 = "test.constant"() {value = 1 : i64} : () -> i64 - - // CHECK-DAG: [[C1:%.+]] = constant 2 - %1 = "test.constant"() {value = 2 : i64} : () -> i64 - - // CHECK: [[V0:%.+]] = "test.op_r"([[C0]], [[C1]]) - %2 = "test.op_r"(%0, %1) : (i64, i64) -> i64 - - // CHECK: [[V0]] - return %2 : i64 -} - //===----------------------------------------------------------------------===// // Test Enum Attributes //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/rewriter-errors.td b/mlir/test/mlir-tblgen/rewriter-errors.td deleted file mode 100644 --- a/mlir/test/mlir-tblgen/rewriter-errors.td +++ /dev/null @@ -1,29 +0,0 @@ -// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s -// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s - -include "mlir/IR/OpBase.td" - -// Check using the dialect name as the namespace -def A_Dialect : Dialect { - let name = "a"; -} - -class A_Op traits = []> : - Op; - -def OpA : A_Op<"op_a">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>; -def OpB : A_Op<"op_b">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>; - -#ifdef ERROR1 -def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">; -// ERROR1: [[@LINE+1]]:1: error: binding symbol 'error' to native code call unsupported right now -def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg), - (OpB $val, $arg)>; -#endif - -#ifdef ERROR2 -def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">; -// ERROR2: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for -def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg), - (OpB $val, $arg)>; -#endif diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -63,7 +63,7 @@ private: // Emits the code for matching ops. - void emitMatchLogic(DagNode tree, StringRef opName); + void emitMatchLogic(DagNode tree); // Emits the code for rewriting ops. void emitRewriteLogic(); @@ -72,34 +72,26 @@ // Match utilities //===--------------------------------------------------------------------===// - // Emits C++ statements for matching the DAG structure. - void emitMatch(DagNode tree, StringRef name, int depth); - - // Emits C++ statements for matching using a native code call. - void emitNativeCodeMatch(DagNode tree, StringRef name, int depth); - // Emits C++ statements for matching the op constrained by the given DAG - // `tree` returning the op's variable name. - void emitOpMatch(DagNode tree, StringRef opName, int depth); + // `tree`. + void emitOpMatch(DagNode tree, int depth); // Emits C++ statements for matching the `argIndex`-th argument of the given // DAG `tree` as an operand. - void emitOperandMatch(DagNode tree, StringRef opName, int argIndex, - int depth); + void emitOperandMatch(DagNode tree, int argIndex, int depth); // Emits C++ statements for matching the `argIndex`-th argument of the given // DAG `tree` as an attribute. - void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, - int depth); + void emitAttributeMatch(DagNode tree, int argIndex, int depth); // Emits C++ for checking a match with a corresponding match failure // diagnostic. - void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt, + void emitMatchCheck(int depth, const FmtObjectBase &matchFmt, const llvm::formatv_object_base &failureFmt); // Emits C++ for checking a match with a corresponding match failure // diagnostics. - void emitMatchCheck(StringRef opName, const std::string &matchStr, + void emitMatchCheck(int depth, const std::string &matchStr, const std::string &failureStr); //===--------------------------------------------------------------------===// @@ -121,7 +113,7 @@ // Emits the C++ statement to replace the matched DAG with a value built via // calling native C++ code. - std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth); + std::string handleReplaceWithNativeCodeCall(DagNode resultTree); // Returns the symbol of the old value serving as the replacement. StringRef handleReplaceWithValue(DagNode tree); @@ -148,13 +140,12 @@ // Emits the concrete arguments used to call an op's builder. void supplyValuesForOpArgs(DagNode node, - const ChildNodeIndexNameMap &childNodeNames, - int depth); + const ChildNodeIndexNameMap &childNodeNames); // Emits the local variables for holding all values as a whole and all named // attributes as a whole to be used for creating an op. void createAggregateLocalVarsForOpArgs( - DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth); + DagNode node, const ChildNodeIndexNameMap &childNodeNames); // Returns the C++ expression to construct a constant attribute of the given // `value` for the given attribute kind `attr`. @@ -227,114 +218,21 @@ } // Helper function to match patterns. -void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) { - if (tree.isNativeCodeCall()) { - emitNativeCodeMatch(tree, name, depth); - return; - } - - if (tree.isOperation()) { - emitOpMatch(tree, name, depth); - return; - } - - PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match."); -} - -// Helper function to match patterns. -void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, - int depth) { - LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: "); - LLVM_DEBUG(tree.print(llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << '\n'); - - // TODO(suderman): iterate through arguments, determine their types, output - // names. - SmallVector capture(8); - if (tree.getNumArgs() > 8) { - PrintFatalError(loc, - "unsupported NativeCodeCall matcher argument numbers: " + - Twine(tree.getNumArgs())); - } - - raw_indented_ostream::DelimitedScope scope(os); - - for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { - std::string argName = formatv("arg{0}_{1}", depth, i); - if (DagNode argTree = tree.getArgAsNestedDag(i)) { - os << "Value " << argName << ";\n"; - } else { - auto leaf = tree.getArgAsLeaf(i); - if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { - os << "Attribute " << argName << ";\n"; - } else if (leaf.isOperandMatcher()) { - os << "Operation " << argName << ";\n"; - } - } - - capture[i] = std::move(argName); - } - - bool hasLocationDirective; - std::string locToUse; - std::tie(hasLocationDirective, locToUse) = getLocation(tree); - - auto fmt = tree.getNativeCodeTemplate(); - auto nativeCodeCall = std::string(tgfmt( - fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1], - capture[2], capture[3], capture[4], capture[5], capture[6], capture[7])); - - os << "if (failed(" << nativeCodeCall << ")) return failure();\n"; - - for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { - auto name = tree.getArgName(i); - if (!name.empty() && name != "_") { - os << formatv("{0} = {1};\n", name, capture[i]); - } - } - - for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { - std::string argName = capture[i]; - - // Handle nested DAG construct first - if (DagNode argTree = tree.getArgAsNestedDag(i)) { - PrintFatalError( - loc, formatv("Matching nested tree in NativeCodecall not support for " - "{0} as arg {1}", - argName, i)); - } - - DagLeaf leaf = tree.getArgAsLeaf(i); - auto constraint = leaf.getAsConstraint(); - - auto self = formatv("{0}", argName); - emitMatchCheck( - opName, - tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), - formatv("\"operand {0} of native code call '{1}' failed to satisfy " - "constraint: " - "'{2}'\"", - i, tree.getNativeCodeTemplate(), constraint.getDescription())); - } - - LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n"); -} - -// Helper function to match patterns. -void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { +void PatternEmitter::emitOpMatch(DagNode tree, int depth) { Operator &op = tree.getDialectOp(opMap); LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" << op.getOperationName() << "' at depth " << depth << '\n'); - std::string castedName = formatv("castedOp{0}", depth); - os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); " - "(void){0};\n", - castedName, opName, op.getQualCppClassName()); + int indent = 4 + 2 * depth; + os.indent(indent) << formatv( + "auto castedOp{0} = ::llvm::dyn_cast_or_null<{1}>(op{0}); " + "(void)castedOp{0};\n", + depth, op.getQualCppClassName()); // Skip the operand matching at depth 0 as the pattern rewriter already does. if (depth != 0) { // Skip if there is no defining operation (e.g., arguments to function). - os << formatv("if (!{0}) return failure();\n", castedName); + os << formatv("if (!castedOp{0})\n return failure();\n", depth); } if (tree.getNumArgs() != op.getNumArgs()) { PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " @@ -346,11 +244,10 @@ // If the operand's name is set, set to that variable. auto name = tree.getSymbol(); if (!name.empty()) - os << formatv("{0} = {1};\n", name, castedName); + os << formatv("{0} = castedOp{1};\n", name, depth); for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { auto opArg = op.getArg(i); - std::string argName = formatv("op{0}", depth + 1); // Handle nested DAG construct first if (DagNode argTree = tree.getArgAsNestedDag(i)) { @@ -365,20 +262,20 @@ os << "{\n"; os.indent() << formatv( - "auto *{0} = " - "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", - argName, castedName, i); - emitMatch(argTree, argName, depth + 1); - os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName); + "auto *op{0} = " + "(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n", + depth + 1, depth, i); + emitOpMatch(argTree, depth + 1); + os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1); os.unindent() << "}\n"; continue; } // Next handle DAG leaf: operand or attribute if (opArg.is()) { - emitOperandMatch(tree, castedName, i, depth); + emitOperandMatch(tree, i, depth); } else if (opArg.is()) { - emitAttributeMatch(tree, opName, i, depth); + emitAttributeMatch(tree, i, depth); } else { PrintFatalError(loc, "unhandled case when matching op"); } @@ -388,8 +285,7 @@ << '\n'); } -void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, - int argIndex, int depth) { +void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) { Operator &op = tree.getDialectOp(opMap); auto *operand = op.getArg(argIndex).get(); auto matcher = tree.getArgAsLeaf(argIndex); @@ -413,10 +309,11 @@ op.getOperationName(), argIndex); PrintFatalError(loc, error); } - auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()", - opName, argIndex); + auto self = + formatv("(*castedOp{0}.getODSOperands({1}).begin()).getType()", depth, + argIndex); emitMatchCheck( - opName, + depth, tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), formatv("\"operand {0} of op '{1}' failed to satisfy constraint: " "'{2}'\"", @@ -436,22 +333,21 @@ [](const Argument &arg) { return arg.is(); }); auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex); - os << formatv("{0} = {1}.getODSOperands({2});\n", - res->second.getVarName(name), opName, - argIndex - numPrevAttrs); + os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", + res->second.getVarName(name), depth, argIndex - numPrevAttrs); } } -void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, - int argIndex, int depth) { +void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) { Operator &op = tree.getDialectOp(opMap); auto *namedAttr = op.getArg(argIndex).get(); const auto &attr = namedAttr->attr; os << "{\n"; - os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");" - "(void)tblgen_attr;\n", - opName, attr.getStorageType(), namedAttr->name); + os.indent() << formatv( + "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); " + "(void)tblgen_attr;\n", + depth, attr.getStorageType(), namedAttr->name); // TODO: This should use getter method to avoid duplication. if (attr.hasDefaultValue()) { @@ -464,7 +360,7 @@ // should just capture a mlir::Attribute() to signal the missing state. // That is precisely what getAttr() returns on missing attributes. } else { - emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx), + emitMatchCheck(depth, tgfmt("tblgen_attr", &fmtCtx), formatv("\"expected op '{0}' to have attribute '{1}' " "of type '{2}'\"", op.getOperationName(), namedAttr->name, @@ -482,7 +378,7 @@ // If a constraint is specified, we need to generate C++ statements to // check the constraint. emitMatchCheck( - opName, + depth, tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")), formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " "{2}\"", @@ -501,25 +397,24 @@ } void PatternEmitter::emitMatchCheck( - StringRef opName, const FmtObjectBase &matchFmt, + int depth, const FmtObjectBase &matchFmt, const llvm::formatv_object_base &failureFmt) { - emitMatchCheck(opName, matchFmt.str(), failureFmt.str()); + emitMatchCheck(depth, matchFmt.str(), failureFmt.str()); } -void PatternEmitter::emitMatchCheck(StringRef opName, - const std::string &matchStr, +void PatternEmitter::emitMatchCheck(int depth, const std::string &matchStr, const std::string &failureStr) { - os << "if (!(" << matchStr << "))"; - os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName - << ", [&](::mlir::Diagnostic &diag) {\n diag << " - << failureStr << ";\n});"; + os.scope("{\n", "\n}\n").os + << "return rewriter.notifyMatchFailure(op" << depth + << ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureStr + << ";\n});"; } -void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { +void PatternEmitter::emitMatchLogic(DagNode tree) { LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); int depth = 0; - emitMatch(tree, opName, depth); + emitOpMatch(tree, depth); for (auto &appliedConstraint : pattern.getConstraints()) { auto &constraint = appliedConstraint.constraint; @@ -530,7 +425,7 @@ auto self = formatv("({0}.getType())", symbolInfoMap.getValueAndRangeUse(entities.front())); emitMatchCheck( - opName, tgfmt(condition, &fmtCtx.withSelf(self.str())), + depth, tgfmt(condition, &fmtCtx.withSelf(self.str())), formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"", entities.front(), constraint.getDescription())); @@ -552,7 +447,7 @@ self = symbolInfoMap.getValueAndRangeUse(self); for (; i < 4; ++i) names.push_back(""); - emitMatchCheck(opName, + emitMatchCheck(depth, tgfmt(condition, &fmtCtx.withSelf(self), names[0], names[1], names[2], names[3]), formatv("\"entities '{0}' failed to satisfy constraint: " @@ -576,7 +471,7 @@ for (++startRange; startRange != endRange; ++startRange) { auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); emitMatchCheck( - opName, + depth, formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, secondOperand)); @@ -672,7 +567,7 @@ os << "// Match\n"; os << "tblgen_ops[0] = op0;\n"; - emitMatchLogic(sourceTree, "op0"); + emitMatchLogic(sourceTree); os << "\n// Rewrite\n"; emitRewriteLogic(); @@ -786,7 +681,7 @@ } if (resultTree.isNativeCodeCall()) { - auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth); + auto symbol = handleReplaceWithNativeCodeCall(resultTree); symbolInfoMap.bindValue(symbol); return symbol; } @@ -903,8 +798,7 @@ PrintFatalError(loc, "unhandled case when rewriting op"); } -std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree, - int depth) { +std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); LLVM_DEBUG(tree.print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << '\n'); @@ -913,20 +807,15 @@ // TODO: replace formatv arguments with the exact specified args. SmallVector attrs(8); if (tree.getNumArgs() > 8) { - PrintFatalError(loc, - "unsupported NativeCodeCall replace argument numbers: " + - Twine(tree.getNumArgs())); + PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + + Twine(tree.getNumArgs())); } bool hasLocationDirective; std::string locToUse; std::tie(hasLocationDirective, locToUse) = getLocation(tree); for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { - if (tree.isNestedDagArg(i)) { - attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1); - } else { - attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); - } + attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i << " replacement: " << attrs[i] << "\n"); } @@ -1035,7 +924,7 @@ // create the ops. // First prepare local variables for op arguments used in builder call. - createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); + createAggregateLocalVarsForOpArgs(tree, childNodeNames); // Then create the op. os.scope("", "\n}\n").os << formatv( @@ -1059,7 +948,7 @@ os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, resultOp.getQualCppClassName(), locToUse); - supplyValuesForOpArgs(tree, childNodeNames, depth); + supplyValuesForOpArgs(tree, childNodeNames); os << "\n );\n}\n"; return resultValue; } @@ -1070,7 +959,7 @@ // here. // First prepare local variables for op arguments used in builder call. - createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); + createAggregateLocalVarsForOpArgs(tree, childNodeNames); // Then prepare the result types. We need to specify the types for all // results. @@ -1148,7 +1037,7 @@ } void PatternEmitter::supplyValuesForOpArgs( - DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { + DagNode node, const ChildNodeIndexNameMap &childNodeNames) { Operator &resultOp = node.getDialectOp(opMap); for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); argIndex != numOpArgs; ++argIndex) { @@ -1171,7 +1060,7 @@ PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " "for creating attribute"); os << formatv("/*{0}=*/{1}", opArgName, - handleReplaceWithNativeCodeCall(subTree, depth)); + handleReplaceWithNativeCodeCall(subTree)); } else { auto leaf = node.getArgAsLeaf(argIndex); // The argument in the result DAG pattern. @@ -1191,7 +1080,7 @@ } void PatternEmitter::createAggregateLocalVarsForOpArgs( - DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { + DagNode node, const ChildNodeIndexNameMap &childNodeNames) { Operator &resultOp = node.getDialectOp(opMap); auto scope = os.scope(); @@ -1213,7 +1102,7 @@ PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " "for creating attribute"); os << formatv(addAttrCmd, opArgName, - handleReplaceWithNativeCodeCall(subTree, depth + 1)); + handleReplaceWithNativeCodeCall(subTree)); } else { auto leaf = node.getArgAsLeaf(argIndex); // The argument in the result DAG pattern.