diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -271,6 +271,17 @@
 
 } // namespace matchers
 
+/// Matches `op` if it has a single result that `m_Constant` recognizes as a
+/// constant, binding that constant to `attribute`. `op` may be null (e.g. the
+/// matched operand is a block argument), in which case matching fails.
+inline LogicalResult ConstantMatcher(Operation *op, Attribute &attribute) {
+  if (!op || op->getNumResults() != 1)
+    return failure();
+  if (!matchPattern(op->getResult(0), m_Constant(&attribute)))
+    return failure();
+  return success();
+}
+
 } // end namespace mlir
 
 #endif // MLIR_MATCHERS_H
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2352,6 +2352,10 @@
   string expression = expr;
 }
 
+// Source-pattern matcher for a constant-like op: binds the op's constant value
+// (via ConstantMatcher in mlir/IR/Matchers.h) to the single argument.
+def ConstantLikeOp : NativeCodeCall<"ConstantMatcher($0, $1)">;
+
 //===----------------------------------------------------------------------===//
 // Rewrite directives
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -247,6 +247,11 @@
     static SymbolInfo getAttr(const Operator *op, int index) {
       return SymbolInfo(op, Kind::Attr, index);
     }
+    // Creates an attribute symbol that is not tied to an op argument, such as
+    // one bound by a NativeCodeCall matcher in a source pattern.
+    static SymbolInfo getAttr() {
+      return SymbolInfo(nullptr, Kind::Attr, llvm::None);
+    }
     static SymbolInfo getOperand(const Operator *op, int index) {
       return SymbolInfo(op, Kind::Operand, index);
     }
@@ -311,6 +316,10 @@
   // is already bound.
   bool bindValue(StringRef symbol);
 
+  // Registers the given `symbol` as bound to an attr. Returns false if `symbol`
+  // is already bound.
+  bool bindAttr(StringRef symbol);
+
   // Returns true if the given `symbol` is bound.
   bool contains(StringRef symbol) const;
 
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -212,9 +212,14 @@
   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
   switch (kind) {
   case Kind::Attr: {
-    auto type =
-        op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
-    return std::string(formatv("{0} {1};\n", type, name));
+    if (op) {
+      auto type =
+          op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
+      return std::string(formatv("{0} {1};\n", type, name));
+    }
+    // Attributes bound through a NativeCodeCall matcher are not tied to an op
+    // argument (see SymbolInfo::getAttr()), so use the generic Attribute type.
+    return std::string(formatv("Attribute {0};\n", name));
   }
   case Kind::Operand: {
     // Use operand range for captured operands (to support potential variadic
@@ -371,6 +376,10 @@
   return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
 }
 
+bool SymbolInfoMap::bindAttr(StringRef symbol) {
+  return symbolInfoMap.insert({symbol, SymbolInfo::getAttr()}).second;
+}
+
 bool SymbolInfoMap::contains(StringRef symbol) const {
   return find(symbol) != symbolInfoMap.end();
 }
@@ -520,19 +529,65 @@
 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
                                   bool isSrcPattern) {
   auto treeName = tree.getSymbol();
-  if (!tree.isOperation()) {
+  auto numTreeArgs = tree.getNumArgs();
+
+  if (!tree.isOperation() && !tree.isNativeCodeCall()) {
     if (!treeName.empty()) {
       PrintFatalError(
           def.getLoc(),
-          formatv("binding symbol '{0}' to non-operation unsupported right now",
+          formatv("binding symbol '{0}' to non-operation/native code call "
+                  "unsupported right now",
                   treeName));
     }
     return;
   }
 
+  if (tree.isNativeCodeCall()) {
+    if (!treeName.empty()) {
+      PrintFatalError(
+          def.getLoc(),
+          formatv(
+              "binding symbol '{0}' to native code call unsupported right now",
+              treeName));
+    }
+
+    for (int i = 0; i != numTreeArgs; ++i) {
+      if (auto treeArg = tree.getArgAsNestedDag(i)) {
+        // This DAG node argument is a DAG node itself. Go inside recursively.
+        collectBoundSymbols(treeArg, infoMap, isSrcPattern);
+        continue;
+      }
+
+      // We can only bind symbols to arguments in source pattern. Those
+      // symbols are referenced in result patterns.
+      if (!isSrcPattern)
+        continue;
+
+      // `$_` is a special symbol meaning ignore the current argument.
+      auto treeArgName = tree.getArgName(i);
+      if (treeArgName.empty() || treeArgName == "_")
+        continue;
+
+      // Attribute-like leaves are bound as attributes and everything else as
+      // values. PatternEmitter::emitNativeCodeMatch declares its capture
+      // variables with the same split, so the generated assignment from a
+      // capture variable to the bound symbol type-checks.
+      DagLeaf leaf = tree.getArgAsLeaf(i);
+      bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
+                    leaf.isConstantAttr();
+      bool newlyBound = isAttr ? infoMap.bindAttr(treeArgName)
+                               : infoMap.bindValue(treeArgName);
+      if (!newlyBound) {
+        auto err = formatv("symbol '{0}' bound more than once", treeArgName);
+        PrintFatalError(def.getLoc(), err);
+      }
+    }
+
+    return;
+  }
+
   auto &op = getDialectOp(tree);
   auto numOpArgs = op.getNumArgs();
-  auto numTreeArgs = tree.getNumArgs();
 
   // The pattern might have the last argument specifying the location.
   bool hasLocDirective = false;
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -606,6 +606,11 @@
   return operand();
 }
 
+// Folds to the `value` attribute held by the op.
+OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
+  return getValue();
+}
+
 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
   for (Value input : this->operands()) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -773,6 +773,28 @@
   let hasCanonicalizer = 1;
 }
 
+def TestOpConstant : TEST_Op<"constant", [ConstantLike, NoSideEffect]> {
+  let arguments = (ins AnyAttr:$value);
+  let results = (outs AnyType);
+  let extraClassDeclaration = [{
+    Attribute getValue() { return getAttr("value"); }
+  }];
+
+  let hasFolder = 1;
+}
+
+def OpN : TEST_Op<"op_n">, Arguments<(ins AnyInteger, AnyInteger)>,
+          Results<(outs AnyInteger)>;
+def OpO : TEST_Op<"op_o">, Arguments<(ins AnyInteger, AnyAttr:$value)>,
+          Results<(outs AnyInteger)>;
+
+// Test a NativeCodeCall matcher in a source pattern: when OpN's second operand
+// is defined by a constant-like op holding an I32Attr value, rewrite OpN to
+// OpO carrying that value as its `value` attribute.
+def : Pat<(OpN $input1, (ConstantLikeOp I32Attr:$input2)),
+          (OpO $input1, $input2),
+          []>;
+
 // Op for testing trivial removal via folding of op with inner ops and no uses.
 def TestOpWithRegionFoldNoSideEffect : TEST_Op<
     "op_with_region_fold_no_side_effect", [NoSideEffect]> {
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -9,6 +9,7 @@
 #include "TestDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -190,6 +190,59 @@
   return %0, %1 : i32, i32
 }
 
+//===----------------------------------------------------------------------===//
+// Test Constant Matching
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: testConstOp
+func @testConstOp() -> (i32) {
+  // CHECK-NEXT: [[C0:%.+]] = constant 1
+  %0 = "test.constant"() {value = 1 : i32} : () -> i32
+
+  // CHECK-NEXT: return [[C0]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: testConstOpUsed
+func @testConstOpUsed() -> (i32) {
+  // CHECK-NEXT: [[C0:%.+]] = constant 1
+  %0 = "test.constant"() {value = 1 : i32} : () -> i32
+
+  // CHECK-NEXT: [[V0:%.+]] = "test.op_o"([[C0]])
+  %1 = "test.op_o"(%0) {value = 1 : i32} : (i32) -> i32
+
+  // CHECK-NEXT: return [[V0]]
+  return %1 : i32
+}
+
+// CHECK-LABEL: testConstOpReplaced
+func @testConstOpReplaced() -> (i32) {
+  // CHECK-NEXT: [[C0:%.+]] = constant 1
+  %0 = "test.constant"() {value = 1 : i32} : () -> i32
+  %1 = "test.constant"() {value = 2 : i32} : () -> i32
+
+  // CHECK: [[V0:%.+]] = "test.op_o"([[C0]]) {value = 2 : i32}
+  %2 = "test.op_n"(%0, %1) : (i32, i32) -> i32
+
+  // CHECK: return [[V0]]
+  return %2 : i32
+}
+
+// CHECK-LABEL: testConstOpMatchFailure
+func @testConstOpMatchFailure() -> (i64) {
+  // CHECK-DAG: [[C0:%.+]] = constant 1
+  %0 = "test.constant"() {value = 1 : i64} : () -> i64
+
+  // CHECK-DAG: [[C1:%.+]] = constant 2
+  %1 = "test.constant"() {value = 2 : i64} : () -> i64
+
+  // CHECK: [[V0:%.+]] = "test.op_n"([[C0]], [[C1]])
+  %2 = "test.op_n"(%0, %1) : (i64, i64) -> i64
+
+  // CHECK: return [[V0]]
+  return %2 : i64
+}
+
 //===----------------------------------------------------------------------===//
 // Test Enum Attributes
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -63,7 +63,7 @@
 
 private:
   // Emits the code for matching ops.
-  void emitMatchLogic(DagNode tree);
+  void emitMatchLogic(DagNode tree, StringRef opName);
 
   // Emits the code for rewriting ops.
   void emitRewriteLogic();
@@ -72,21 +72,32 @@
   // Match utilities
   //===--------------------------------------------------------------------===//
 
+  // Emits C++ statements for matching the DAG `tree`, whose root operation is
+  // held in the C++ variable `name`. Dispatches to emitOpMatch or
+  // emitNativeCodeMatch depending on the kind of `tree`.
+  void emitMatch(DagNode tree, StringRef name, int depth);
+
+  // Emits C++ statements for matching the NativeCodeCall `tree` against the
+  // operation held in the C++ variable `name`.
+  void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
+
   // Emits C++ statements for matching the op constrained by the given DAG
-  // `tree`.
-  void emitOpMatch(DagNode tree, int depth);
+  // `tree`. `opName` is the C++ variable holding the operation to match.
+  void emitOpMatch(DagNode tree, StringRef opName, int depth);
 
   // Emits C++ statements for matching the `argIndex`-th argument of the given
   // DAG `tree` as an operand.
-  void emitOperandMatch(DagNode tree, int argIndex, int depth);
+  void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
+                        int depth);
 
   // Emits C++ statements for matching the `argIndex`-th argument of the given
   // DAG `tree` as an attribute.
-  void emitAttributeMatch(DagNode tree, int argIndex, int depth);
+  void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
+                          int depth);
 
   // Emits C++ for checking a match with a corresponding match failure
   // diagnostic.
-  void emitMatchCheck(int depth, const FmtObjectBase &matchFmt,
+  void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
                       const llvm::formatv_object_base &failureFmt);
 
   //===--------------------------------------------------------------------===//
@@ -108,7 +119,7 @@
 
   // Emits the C++ statement to replace the matched DAG with a value built via
   // calling native C++ code.
-  std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
+  std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
 
   // Returns the symbol of the old value serving as the replacement.
   StringRef handleReplaceWithValue(DagNode tree);
@@ -135,12 +146,13 @@
 
   // Emits the concrete arguments used to call an op's builder.
   void supplyValuesForOpArgs(DagNode node,
-                             const ChildNodeIndexNameMap &childNodeNames);
+                             const ChildNodeIndexNameMap &childNodeNames,
+                             int depth);
 
   // Emits the local variables for holding all values as a whole and all named
   // attributes as a whole to be used for creating an op.
   void createAggregateLocalVarsForOpArgs(
-      DagNode node, const ChildNodeIndexNameMap &childNodeNames);
+      DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
 
   // Returns the C++ expression to construct a constant attribute of the given
   // `value` for the given attribute kind `attr`.
@@ -213,20 +225,128 @@
 }
 
 // Helper function to match patterns.
-void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
+void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
+  if (tree.isNativeCodeCall()) {
+    emitNativeCodeMatch(tree, name, depth);
+    return;
+  }
+
+  if (tree.isOperation()) {
+    emitOpMatch(tree, name, depth);
+    return;
+  }
+
+  PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
+}
+
+// Helper function to match a NativeCodeCall node of a source pattern. The
+// native call receives the op held in `opName` as `$0` and one capture
+// variable per argument as `$1`..`$8`; its result is tested with
+// `failed(...)`.
+void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
+                                         int depth) {
+  int indent = 4 + 2 * depth;
+  LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
+  LLVM_DEBUG(tree.print(llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << '\n');
+
+  // `$0` is taken by the matched op, leaving `$1`..`$8` for the arguments.
+  if (tree.getNumArgs() > 8) {
+    PrintFatalError(loc,
+                    "unsupported NativeCodeCall matcher argument numbers: " +
+                        Twine(tree.getNumArgs()));
+  }
+
+  // Nested DAG arguments are not supported yet. Reject them before emitting
+  // any code for this node.
+  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+    if (tree.getArgAsNestedDag(i))
+      PrintFatalError(
+          loc, formatv("matching a nested tree in a NativeCodeCall is not "
+                       "supported (argument {0})",
+                       i));
+  }
+
+  os.indent(indent) << "{\n";
+
+  // Declare one capture variable per argument: `Attribute` for
+  // attribute-like leaves and `Value` otherwise. This mirrors how
+  // Pattern::collectBoundSymbols binds the corresponding symbols, so the
+  // symbol assignments emitted below type-check.
+  // TODO: derive the capture types from the native call itself.
+  SmallVector<std::string, 8> capture(8);
+  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+    std::string argName = formatv("arg{0}_{1}", depth, i);
+    DagLeaf leaf = tree.getArgAsLeaf(i);
+    bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
+                  leaf.isConstantAttr();
+    os.indent(indent + 2) << (isAttr ? "Attribute " : "Value ") << argName
+                          << ";\n";
+    capture[i] = std::move(argName);
+  }
+
+  bool hasLocationDirective;
+  std::string locToUse;
+  std::tie(hasLocationDirective, locToUse) = getLocation(tree);
+
+  auto fmt = tree.getNativeCodeTemplate();
+  auto nativeCodeCall = std::string(tgfmt(
+      fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1],
+      capture[2], capture[3], capture[4], capture[5], capture[6], capture[7]));
+
+  os.indent(indent + 2) << "if (failed(" << nativeCodeCall
+                        << ")) return failure();\n";
+
+  // Copy the captured values into the symbols bound in the source pattern.
+  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+    auto name = tree.getArgName(i);
+    if (!name.empty() && name != "_") {
+      os.indent(indent + 2) << formatv("{0} = {1};\n", name, capture[i]);
+    }
+  }
+
+  // Check the argument constraints. Only operand and attribute matcher leaves
+  // carry a constraint; DagLeaf::getAsConstraint() must not be called on
+  // other leaves.
+  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+    DagLeaf leaf = tree.getArgAsLeaf(i);
+    if (!leaf.isOperandMatcher() && !leaf.isAttrMatcher())
+      continue;
+
+    // Type constraints are checked on the captured value's type, as in
+    // emitOperandMatch; attribute constraints on the attribute itself.
+    std::string self = leaf.isOperandMatcher()
+                           ? formatv("{0}.getType()", capture[i]).str()
+                           : capture[i];
+    auto constraint = leaf.getAsConstraint();
+    emitMatchCheck(
+        opName,
+        tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
+        formatv("\"operand {0} of native code call '{1}' failed to satisfy "
+                "constraint: '{2}'\"",
+                i, tree.getNativeCodeTemplate(), constraint.getDescription()));
+  }
+
+  os.indent(indent) << "}\n";
+  LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
+}
+
+// Helper function to match patterns.
+void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
   Operator &op = tree.getDialectOp(opMap);
   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
                           << op.getOperationName() << "' at depth " << depth
                           << '\n');
 
   int indent = 4 + 2 * depth;
+  std::string castedName = formatv("castedOp{0}", depth);
   os.indent(indent) << formatv(
-      "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
-      depth, op.getQualCppClassName());
+      "auto {0} = dyn_cast_or_null<{2}>({1}); (void){0};\n", castedName, opName,
+      op.getQualCppClassName());
   // Skip the operand matching at depth 0 as the pattern rewriter already does.
   if (depth != 0) {
     // Skip if there is no defining operation (e.g., arguments to function).
-    os << formatv("if (!castedOp{0})\n  return failure();\n", depth);
+    os.indent(indent) << formatv("if (!{0}) return failure();\n", castedName);
   }
   if (tree.getNumArgs() != op.getNumArgs()) {
     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
@@ -238,10 +358,11 @@
   // If the operand's name is set, set to that variable.
   auto name = tree.getSymbol();
   if (!name.empty())
-    os << formatv("{0} = castedOp{1};\n", name, depth);
+    os.indent(indent) << formatv("{0} = {1};\n", name, castedName);
 
   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
     auto opArg = op.getArg(i);
+    std::string argName = formatv("op{0}", depth + 1);
 
     // Handle nested DAG construct first
     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
@@ -255,21 +376,22 @@
       }
 
-      os << "{\n";
-      os.indent() << formatv(
-          "auto *op{0} = "
-          "(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
-          depth + 1, depth, i);
-      emitOpMatch(argTree, depth + 1);
-      os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
-      os.unindent() << "}\n";
+      os.indent(indent) << "{\n";
+      os.indent(indent + 2)
+          << formatv("auto *{0} = "
+                     "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
+                     argName, castedName, i);
+      emitMatch(argTree, argName, depth + 1);
+      os.indent(indent + 2)
+          << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
+      os.indent(indent) << "}\n";
       continue;
     }
 
     // Next handle DAG leaf: operand or attribute
     if (opArg.is<NamedTypeConstraint *>()) {
-      emitOperandMatch(tree, i, depth);
+      emitOperandMatch(tree, castedName, i, depth);
     } else if (opArg.is<NamedAttribute *>()) {
-      emitAttributeMatch(tree, i, depth);
+      emitAttributeMatch(tree, opName, i, depth);
     } else {
       PrintFatalError(loc, "unhandled case when matching op");
     }
@@ -279,10 +401,12 @@
                           << '\n');
 }
 
-void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
+void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
+                                      int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);
   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
   auto matcher = tree.getArgAsLeaf(argIndex);
+  int indent = 4 + 2 * depth;
 
   // If a constraint is specified, we need to generate C++ statements to
   // check the constraint.
@@ -303,11 +427,10 @@
                            op.getOperationName(), argIndex);
       PrintFatalError(loc, error);
     }
-    auto self =
-        formatv("(*castedOp{0}.getODSOperands({1}).begin()).getType()", depth,
-                argIndex);
+    auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
+                        opName, argIndex);
     emitMatchCheck(
-        depth,
+        opName,
         tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
         formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
                 "'{2}'\"",
@@ -326,21 +449,24 @@
         op.arg_begin(), op.arg_begin() + argIndex,
         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
 
-    os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", name, depth,
-                  argIndex - numPrevAttrs);
+    os.indent(indent) << formatv("{0} = {1}.getODSOperands({2});\n", name,
+                                 opName, argIndex - numPrevAttrs);
   }
 }
 
-void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
+void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
+                                        int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);
   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
   const auto &attr = namedAttr->attr;
+  int indent = 4 + 2 * depth;
 
-  os << "{\n";
-  os.indent() << formatv(
-      "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); "
+  os.indent(indent) << "{\n";
+  indent += 2;
+  os.indent(indent) << formatv(
+      "auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\"); "
       "(void)tblgen_attr;\n",
-      depth, attr.getStorageType(), namedAttr->name);
+      opName, attr.getStorageType(), namedAttr->name);
 
   // TODO: This should use getter method to avoid duplication.
   if (attr.hasDefaultValue()) {
@@ -353,7 +479,7 @@
     // should just capture a mlir::Attribute() to signal the missing state.
     // That is precisely what getAttr() returns on missing attributes.
   } else {
-    emitMatchCheck(depth, tgfmt("tblgen_attr", &fmtCtx),
+    emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
                    formatv("\"expected op '{0}' to have attribute '{1}' "
                            "of type '{2}'\"",
                            op.getOperationName(), namedAttr->name,
@@ -371,7 +497,7 @@
     // If a constraint is specified, we need to generate C++ statements to
     // check the constraint.
     emitMatchCheck(
-        depth,
+        opName,
         tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
                 "{2}\"",
@@ -390,19 +516,24 @@
 }
 
 void PatternEmitter::emitMatchCheck(
-    int depth, const FmtObjectBase &matchFmt,
+    StringRef opName, const FmtObjectBase &matchFmt,
     const llvm::formatv_object_base &failureFmt) {
-  os << "if (!(" << matchFmt.str() << "))";
-  os.scope("{\n", "\n}\n").os
-      << "return rewriter.notifyMatchFailure(op" << depth
-      << ", [&](::mlir::Diagnostic &diag) {\n  diag << " << failureFmt.str()
-      << ";\n});";
+  // Opening braces are doubled so that formatv emits them literally instead
+  // of trying to parse them as replacement fields.
+  const char *matchStr = R"(
+  if (!({1})) {{
+    return rewriter.notifyMatchFailure({0}, [&](::mlir::Diagnostic &diag) {{
+      diag << {2};
+    });
+  })";
+  os << llvm::formatv(matchStr, opName, matchFmt.str(), failureFmt.str())
+     << "\n";
 }
 
-void PatternEmitter::emitMatchLogic(DagNode tree) {
+void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
   int depth = 0;
-  emitOpMatch(tree, depth);
+  emitMatch(tree, opName, depth);
 
   for (auto &appliedConstraint : pattern.getConstraints()) {
     auto &constraint = appliedConstraint.constraint;
@@ -413,7 +544,7 @@
     auto self = formatv("({0}.getType())",
                         symbolInfoMap.getValueAndRangeUse(entities.front()));
     emitMatchCheck(
-        depth, tgfmt(condition, &fmtCtx.withSelf(self.str())),
+        opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
         formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
                 entities.front(), constraint.getDescription()));
 
@@ -435,7 +566,7 @@
       self = symbolInfoMap.getValueAndRangeUse(self);
     for (; i < 4; ++i)
       names.push_back("");
-    emitMatchCheck(depth,
+    emitMatchCheck(opName,
                    tgfmt(condition, &fmtCtx.withSelf(self), names[0],
                          names[1], names[2], names[3]),
                    formatv("\"entities '{0}' failed to satisfy constraint: "
@@ -530,7 +661,7 @@
 
     os << "// Match\n";
     os << "tblgen_ops[0] = op0;\n";
-    emitMatchLogic(sourceTree);
+    emitMatchLogic(sourceTree, "op0");
 
     os << "\n// Rewrite\n";
     emitRewriteLogic();
@@ -644,7 +775,7 @@
   }
 
   if (resultTree.isNativeCodeCall()) {
-    auto symbol = handleReplaceWithNativeCodeCall(resultTree);
+    auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth);
     symbolInfoMap.bindValue(symbol);
     return symbol;
   }
@@ -761,7 +892,8 @@
   PrintFatalError(loc, "unhandled case when rewriting op");
 }
 
-std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
+std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
+                                                            int depth) {
   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
   LLVM_DEBUG(tree.print(llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << '\n');
@@ -770,15 +902,20 @@
   // TODO: replace formatv arguments with the exact specified args.
   SmallVector<std::string, 8> attrs(8);
   if (tree.getNumArgs() > 8) {
-    PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
-                             Twine(tree.getNumArgs()));
+    PrintFatalError(loc,
+                    "unsupported NativeCodeCall replace argument numbers: " +
+                        Twine(tree.getNumArgs()));
   }
   bool hasLocationDirective;
   std::string locToUse;
   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
 
   for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
-    attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
+    if (tree.isNestedDagArg(i)) {
+      attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1);
+    } else {
+      attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
+    }
     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
                             << " replacement: " << attrs[i] << "\n");
   }
@@ -887,7 +1024,7 @@
     // create the ops.
 
     // First prepare local variables for op arguments used in builder call.
-    createAggregateLocalVarsForOpArgs(tree, childNodeNames);
+    createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
 
     // Then create the op.
     os.scope("", "\n}\n").os << formatv(
@@ -911,7 +1048,7 @@
     os.scope().os << formatv("{0} = rewriter.create<{1}>({2}",
                              valuePackName, resultOp.getQualCppClassName(),
                              locToUse);
-    supplyValuesForOpArgs(tree, childNodeNames);
+    supplyValuesForOpArgs(tree, childNodeNames, depth);
     os << "\n  );\n}\n";
     return resultValue;
   }
@@ -922,7 +1059,7 @@
   // here.
 
   // First prepare local variables for op arguments used in builder call.
-  createAggregateLocalVarsForOpArgs(tree, childNodeNames);
+  createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
 
   // Then prepare the result types. We need to specify the types for all
   // results.
@@ -1000,7 +1137,7 @@
 }
 
 void PatternEmitter::supplyValuesForOpArgs(
-    DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
+    DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
   Operator &resultOp = node.getDialectOp(opMap);
   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
        argIndex != numOpArgs; ++argIndex) {
@@ -1023,7 +1160,7 @@
         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                              "for creating attribute");
       os << formatv("/*{0}=*/{1}", opArgName,
-                    handleReplaceWithNativeCodeCall(subTree));
+                    handleReplaceWithNativeCodeCall(subTree, depth + 1));
     } else {
       auto leaf = node.getArgAsLeaf(argIndex);
       // The argument in the result DAG pattern.
@@ -1043,7 +1180,7 @@
 }
 
 void PatternEmitter::createAggregateLocalVarsForOpArgs(
-    DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
+    DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
   Operator &resultOp = node.getDialectOp(opMap);
 
   auto scope = os.scope();
@@ -1065,7 +1202,7 @@
         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                              "for creating attribute");
       os << formatv(addAttrCmd, opArgName,
-                    handleReplaceWithNativeCodeCall(subTree));
+                    handleReplaceWithNativeCodeCall(subTree, depth + 1));
     } else {
       auto leaf = node.getArgAsLeaf(argIndex);
       // The argument in the result DAG pattern.