diff --git a/mlir/include/mlir/TableGen/Format.h b/mlir/include/mlir/TableGen/Format.h
--- a/mlir/include/mlir/TableGen/Format.h
+++ b/mlir/include/mlir/TableGen/Format.h
@@ -186,6 +186,25 @@
   }
 };
 
+/// Format object whose replacement parameters are a runtime-sized list of
+/// strings instead of a variadic parameter pack, so the number of positional
+/// placeholders is not fixed at compile time.
+class FmtStrVecObject : public FmtObjectBase {
+public:
+  using StrFormatAdapter =
+      decltype(llvm::detail::build_format_adapter(std::declval<std::string>()));
+
+  FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
+                  ArrayRef<std::string> params);
+  FmtStrVecObject(FmtStrVecObject const &that) = delete;
+  FmtStrVecObject(FmtStrVecObject &&that);
+
+private:
+  // Owns the parameter adapters; the base-class `adapters` table points into
+  // this storage and is rebuilt by the move constructor.
+  SmallVector<StrFormatAdapter, 16> parameters;
+};
+
 /// Formats text by substituting placeholders in format string with replacement
 /// parameters.
 ///
@@ -234,6 +253,13 @@
           llvm::detail::build_format_adapter(std::forward<Ts>(vals))...));
 }
 
+/// Overload of tgfmt taking a runtime-sized list of replacement strings: the
+/// positional placeholder `$N` in `fmt` is substituted with `params[N]`.
+inline FmtStrVecObject tgfmt(StringRef fmt, const FmtContext *ctx,
+                             ArrayRef<std::string> params) {
+  return FmtStrVecObject(fmt, ctx, params);
+}
+
 } // end namespace tblgen
 } // end namespace mlir
 
diff --git a/mlir/lib/TableGen/Format.cpp b/mlir/lib/TableGen/Format.cpp
--- a/mlir/lib/TableGen/Format.cpp
+++ b/mlir/lib/TableGen/Format.cpp
@@ -173,3 +173,29 @@
     adapters[repl.index]->format(s, /*Options=*/"");
   }
 }
+
+FmtStrVecObject::FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
+                                 ArrayRef<std::string> params)
+    : FmtObjectBase(fmt, ctx, params.size()) {
+  // Copy each string so every adapter owns its value instead of referring to
+  // the caller's `params`.
+  parameters.reserve(params.size());
+  for (std::string p : params)
+    parameters.push_back(llvm::detail::build_format_adapter(std::move(p)));
+
+  // `adapters` holds raw pointers into `parameters`; fill it only after
+  // `parameters` has reached its final size.
+  adapters.reserve(parameters.size());
+  for (auto &p : parameters)
+    adapters.push_back(&p);
+}
+
+FmtStrVecObject::FmtStrVecObject(FmtStrVecObject &&that)
+    : FmtObjectBase(std::move(that)), parameters(std::move(that.parameters)) {
+  // The adapters now live in this object's storage, so repoint the base-class
+  // table at them. Clear first so no pointer into `that` can survive.
+  adapters.clear();
+  adapters.reserve(parameters.size());
+  for (auto &p : parameters)
+    adapters.push_back(&p);
+}
diff --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td
--- a/mlir/test/mlir-tblgen/rewriter-indexing.td
+++ b/mlir/test/mlir-tblgen/rewriter-indexing.td
@@ -58,3 +58,30 @@
 def test3 :
   Pat<(BOp $attr, (AOp:$a $input)),
       (BOp $attr, (AOp $input), (location $a))>;
+def DOp : NS_Op<"d_op", []> {
+  let arguments = (ins
+    AnyInteger:$v1,
+    AnyInteger:$v2,
+    AnyInteger:$v3,
+    AnyInteger:$v4,
+    AnyInteger:$v5,
+    AnyInteger:$v6,
+    AnyInteger:$v7,
+    AnyInteger:$v8,
+    AnyInteger:$v9,
+    AnyInteger:$v10
+  );
+
+  let results = (outs AnyInteger);
+}
+
+def NativeBuilder :
+  NativeCodeCall<[{
+    nativeCall($_builder, $_loc, $0, $1, $2, $3, $4, $5, $6, $7, $8, $9)
+  }]>;
+
+// Check Pattern with large number of DAG arguments passed to NativeCodeCall
+// CHECK: struct test4 : public ::mlir::RewritePattern {
+// CHECK: nativeCall(rewriter, odsLoc, (*v1.begin()), (*v2.begin()), (*v3.begin()), (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
+def test4 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
+                (NativeBuilder $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -251,12 +251,10 @@
 
   // TODO(suderman): iterate through arguments, determine their types, output
   // names.
-  SmallVector<std::string, 8> capture(8);
-  if (tree.getNumArgs() > 8) {
-    PrintFatalError(loc,
-                    "unsupported NativeCodeCall matcher argument numbers: " +
-                        Twine(tree.getNumArgs()));
-  }
+  // capture[0] holds `opName` and capture[i + 1] holds the name generated for
+  // argument i; tgfmt substitutes them for $0, $1, ... in the template.
+  SmallVector<std::string, 8> capture;
+  capture.push_back(opName.str());
 
   raw_indented_ostream::DelimitedScope scope(os);
 
@@ -274,7 +272,7 @@
       }
     }
 
-    capture[i] = std::move(argName);
+    capture.push_back(std::move(argName));
   }
 
   bool hasLocationDirective;
@@ -282,21 +280,20 @@
   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
 
   auto fmt = tree.getNativeCodeTemplate();
-  auto nativeCodeCall = std::string(tgfmt(
-      fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1],
-      capture[2], capture[3], capture[4], capture[5], capture[6], capture[7]));
+  auto nativeCodeCall =
+      std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), capture));
 
   os << "if (failed(" << nativeCodeCall << ")) return ::mlir::failure();\n";
 
   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
     auto name = tree.getArgName(i);
     if (!name.empty() && name != "_") {
-      os << formatv("{0} = {1};\n", name, capture[i]);
+      os << formatv("{0} = {1};\n", name, capture[i + 1]);
     }
   }
 
   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
-    std::string argName = capture[i];
+    std::string argName = capture[i + 1];
 
     // Handle nested DAG construct first
     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
@@ -915,29 +912,26 @@
   LLVM_DEBUG(llvm::dbgs() << '\n');
 
   auto fmt = tree.getNativeCodeTemplate();
-  // TODO: replace formatv arguments with the exact specified args.
-  SmallVector<std::string, 8> attrs(8);
-  if (tree.getNumArgs() > 8) {
-    PrintFatalError(loc,
-                    "unsupported NativeCodeCall replace argument numbers: " +
-                        Twine(tree.getNumArgs()));
-  }
+
+  SmallVector<std::string, 8> attrs;
+
   bool hasLocationDirective;
   std::string locToUse;
   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
 
   for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
     if (tree.isNestedDagArg(i)) {
-      attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1);
+      attrs.push_back(
+          handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
     } else {
-      attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
+      attrs.push_back(
+          handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
     }
     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
                             << " replacement: " << attrs[i] << "\n");
   }
-  return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0],
-                           attrs[1], attrs[2], attrs[3], attrs[4], attrs[5],
-                           attrs[6], attrs[7]));
+
+  return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs));
 }
 
 int PatternEmitter::getNodeValueCount(DagNode node) {