diff --git a/mlir/include/mlir/Support/IndentedOstream.h b/mlir/include/mlir/Support/IndentedOstream.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Support/IndentedOstream.h
@@ -0,0 +1,88 @@
+//===- IndentedOstream.h - raw ostream wrapper to indent --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// raw_ostream subclass that keeps track of indentation for textual output
+// where indentation helps readability.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_INDENTEDOSTREAM_H_
+#define MLIR_SUPPORT_INDENTEDOSTREAM_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+
+/// raw_ostream subclass that simplifies indenting a sequence of code.
+class raw_indented_ostream : public raw_ostream {
+public:
+  explicit raw_indented_ostream(llvm::raw_ostream &os) : os(os) {
+    SetUnbuffered();
+  }
+
+  /// Simple RAII struct to use for indentation around entering/exiting region.
+  struct Region {
+    Region(raw_indented_ostream &ros) : os(ros) { os.increaseIndent(); }
+    ~Region() { os.decreaseIndent(); }
+    raw_indented_ostream &os;
+  };
+
+  /// Returns a Region that keeps this stream indented while it is alive.
+  Region region() { return Region(*this); }
+
+  /// Re-indent by removing the smallest leading whitespace of the non-empty
+  /// lines from every line of the string, skipping over empty lines at start.
+  raw_indented_ostream &reindent(StringRef str);
+
+  /// Increases the indent and returns this raw_indented_ostream.
+  raw_indented_ostream &increaseIndent() {
+    currentIndent += indentSize;
+    return *this;
+  }
+
+  /// Decreases the indent and returns this raw_indented_ostream.
+  raw_indented_ostream &decreaseIndent() {
+    currentIndent -= indentSize;
+    return *this;
+  }
+
+  /// Emits whitespace and sets the indentation for the stream.
+  raw_indented_ostream &indent(int with) {
+    os.indent(with);
+    atStartOfLine = false;
+    currentIndent = with;
+    return *this;
+  }
+
+private:
+  /// See raw_ostream::write_impl.
+  void write_impl(const char *ptr, size_t size) override;
+
+  /// Return the current position within the stream, not counting the bytes
+  /// currently in the buffer.
+  uint64_t current_pos() const override { return os.tell(); }
+
+  /// Constant indent added/removed.
+  static constexpr int indentSize = 2;
+
+  // Tracker for current indentation.
+  int currentIndent = 0;
+
+  // The leading whitespace of the string being printed, if reindent is used.
+  int leadingWs = 0;
+
+  // Tracks whether at start of line and so indent is required or not.
+  bool atStartOfLine = true;
+
+  // The underlying raw_ostream.
+  raw_ostream &os;
+};
+
+} // namespace mlir
+#endif // MLIR_SUPPORT_INDENTEDOSTREAM_H_
diff --git a/mlir/lib/Support/CMakeLists.txt b/mlir/lib/Support/CMakeLists.txt
--- a/mlir/lib/Support/CMakeLists.txt
+++ b/mlir/lib/Support/CMakeLists.txt
@@ -1,5 +1,6 @@
 set(LLVM_OPTIONAL_SOURCES
   FileUtilities.cpp
+  IndentedOstream.cpp
   MlirOptMain.cpp
   StorageUniquer.cpp
   ToolUtilities.cpp
@@ -27,3 +28,10 @@
   MLIRParser
   MLIRSupport
   )
+
+# This doesn't use add_mlir_library as it is used in mlir-tblgen, and
+# otherwise mlir-tblgen would end up depending on mlir-generic-headers.
+add_llvm_library(MLIRSupportIndentedOstream
+  IndentedOstream.cpp
+
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Support)
diff --git a/mlir/lib/Support/IndentedOstream.cpp b/mlir/lib/Support/IndentedOstream.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Support/IndentedOstream.cpp
@@ -0,0 +1,75 @@
+//===- IndentedOstream.cpp - raw ostream wrapper to indent ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// raw_ostream subclass that keeps track of indentation for textual output
+// where indentation helps readability.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/IndentedOstream.h"
+
+using namespace mlir;
+
+raw_indented_ostream &mlir::raw_indented_ostream::reindent(StringRef str) {
+  StringRef output = str;
+  // Skip over the empty (whitespace-only) lines at the start.
+  while (!output.empty()) {
+    auto split = output.split('\n');
+    if (split.first.find_first_not_of(" \t") != StringRef::npos)
+      break;
+    output = split.second;
+  }
+
+  // Strip the smallest leading whitespace of all non-empty lines so that a
+  // line indented less than the first one does not lose any characters.
+  size_t minIndent = StringRef::npos;
+  for (StringRef remaining = output; !remaining.empty();) {
+    auto split = remaining.split('\n');
+    size_t indent = split.first.find_first_not_of(" \t");
+    if (indent != StringRef::npos && indent < minIndent)
+      minIndent = indent;
+    remaining = split.second;
+  }
+  leadingWs = minIndent == StringRef::npos ? 0 : static_cast<int>(minIndent);
+
+  // Print, skipping the empty lines.
+  *this << output;
+  leadingWs = 0;
+  return *this;
+}
+
+void mlir::raw_indented_ostream::write_impl(const char *ptr, size_t size) {
+  StringRef str(ptr, size);
+  // Print out indented.
+  auto print = [this](StringRef str) {
+    if (atStartOfLine)
+      os.indent(currentIndent) << str.substr(leadingWs);
+    else
+      os << str.substr(leadingWs);
+  };
+
+  while (!str.empty()) {
+    size_t idx = str.find('\n');
+    if (idx == StringRef::npos) {
+      if (!str.substr(leadingWs).empty()) {
+        print(str);
+        atStartOfLine = false;
+      }
+      break;
+    }
+
+    auto split =
+        std::make_pair(str.slice(0, idx), str.slice(idx + 1, StringRef::npos));
+    // Print empty new line without spaces if line only has spaces.
+    if (!split.first.ltrim().empty())
+      print(split.first);
+    os << '\n';
+    atStartOfLine = true;
+    str = split.second;
+  }
+}
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -24,6 +24,7 @@
 set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning")
 target_link_libraries(mlir-tblgen
   PRIVATE
+  MLIRSupportIndentedOstream
   MLIRTableGen)
 
 mlir_check_all_link_libraries(mlir-tblgen)
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "DocGenUtilities.h"
+#include "mlir/Support/IndentedOstream.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/DenseMap.h"
@@ -35,39 +36,8 @@
 // in a way the user wanted but has some additional indenting due to being
 // nested in the op definition.
 void mlir::tblgen::emitDescription(StringRef description, raw_ostream &os) {
-  // Determine the minimum number of spaces in a line.
-  size_t min_indent = -1;
-  StringRef remaining = description;
-  while (!remaining.empty()) {
-    auto split = remaining.split('\n');
-    size_t indent = split.first.find_first_not_of(" \t");
-    if (indent != StringRef::npos)
-      min_indent = std::min(indent, min_indent);
-    remaining = split.second;
-  }
-
-  // Print out the description indented.
-  os << "\n";
-  remaining = description;
-  bool printed = false;
-  while (!remaining.empty()) {
-    auto split = remaining.split('\n');
-    if (split.second.empty()) {
-      // Skip last line with just spaces.
-      if (split.first.ltrim().empty())
-        break;
-    }
-    // Print empty new line without spaces if line only has spaces, unless no
-    // text has been emitted before.
-    if (split.first.ltrim().empty()) {
-      if (printed)
-        os << "\n";
-    } else {
-      os << split.first.substr(min_indent) << "\n";
-      printed = true;
-    }
-    remaining = split.second;
-  }
+  raw_indented_ostream ros(os);
+  ros.reindent(description.rtrim(" \t"));
 }
 
 // Emits `str` with trailing newline if not empty.
@@ -116,7 +86,7 @@
 
   // Emit the summary, syntax, and description if present.
   if (op.hasSummary())
-    os << "\n" << op.getSummary() << "\n";
+    os << "\n" << op.getSummary() << "\n\n";
   if (op.hasAssemblyFormat())
     emitAssemblyFormat(op.getOperationName(), op.getAssemblyFormat().trim(),
                        os);
@@ -228,7 +198,7 @@
   }
 
   os << "\n";
-  for (auto dialectWithOps : dialectOps)
+  for (const auto &dialectWithOps : dialectOps)
     emitDialectDoc(dialectWithOps.first, dialectWithOps.second,
                    dialectTypes[dialectWithOps.first], os);
 }
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Support/IndentedOstream.h"
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
@@ -77,11 +78,11 @@
 
   // Emits C++ statements for matching the `argIndex`-th argument of the given
   // DAG `tree` as an operand.
-  void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent);
+  void emitOperandMatch(DagNode tree, int argIndex, int depth);
 
   // Emits C++ statements for matching the `argIndex`-th argument of the given
   // DAG `tree` as an attribute.
-  void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);
+  void emitAttributeMatch(DagNode tree, int argIndex, int depth);
 
   // Emits C++ for checking a match with a corresponding match failure
   // diagnostic.
@@ -181,7 +182,7 @@
   // The next unused ID for newly created values.
   unsigned nextValueId;
 
-  raw_ostream &os;
+  raw_indented_ostream os;
 
   // Format contexts containing placeholder substitutions.
   FmtContext fmtCtx;
@@ -222,8 +223,7 @@
   // Skip the operand matching at depth 0 as the pattern rewriter already does.
   if (depth != 0) {
     // Skip if there is no defining operation (e.g., arguments to function).
-    os.indent(indent) << formatv("if (!castedOp{0}) return failure();\n",
-                                 depth);
+    os << formatv("if (!castedOp{0})\n  return failure();\n", depth);
   }
   if (tree.getNumArgs() != op.getNumArgs()) {
     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
@@ -235,7 +235,7 @@
   // If the operand's name is set, set to that variable.
   auto name = tree.getSymbol();
   if (!name.empty())
-    os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
+    os << formatv("{0} = castedOp{1};\n", name, depth);
 
   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
     auto opArg = op.getArg(i);
@@ -250,24 +250,23 @@
           PrintFatalError(loc, error);
         }
       }
-      os.indent(indent) << "{\n";
+      os << "{\n";
 
-      os.indent(indent + 2) << formatv(
+      os.increaseIndent() << formatv(
           "auto *op{0} = "
           "(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
           depth + 1, depth, i);
       emitOpMatch(argTree, depth + 1);
-      os.indent(indent + 2)
-          << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
-      os.indent(indent) << "}\n";
+      os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
+      os.decreaseIndent() << "}\n";
       continue;
     }
 
     // Next handle DAG leaf: operand or attribute
     if (opArg.is<NamedTypeConstraint *>()) {
-      emitOperandMatch(tree, i, depth, indent);
+      emitOperandMatch(tree, i, depth);
     } else if (opArg.is<NamedAttribute *>()) {
-      emitAttributeMatch(tree, i, depth, indent);
+      emitAttributeMatch(tree, i, depth);
     } else {
       PrintFatalError(loc, "unhandled case when matching op");
     }
@@ -277,8 +276,7 @@
                           << '\n');
 }
 
-void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
-                                      int indent) {
+void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);
   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
   auto matcher = tree.getArgAsLeaf(argIndex);
@@ -325,30 +323,28 @@
         op.arg_begin(), op.arg_begin() + argIndex,
         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
 
-    os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
-                                 name, depth, argIndex - numPrevAttrs);
+    os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", name, depth,
+                  argIndex - numPrevAttrs);
   }
 }
 
-void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
-                                        int indent) {
+void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);
   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
   const auto &attr = namedAttr->attr;
 
-  os.indent(indent) << "{\n";
-  indent += 2;
-  os.indent(indent) << formatv(
-      "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");"
+  os << "{\n";
+  os.increaseIndent() << formatv(
+      "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); "
       "(void)tblgen_attr;\n",
       depth, attr.getStorageType(), namedAttr->name);
 
   // TODO: This should use getter method to avoid duplication.
   if (attr.hasDefaultValue()) {
-    os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
-                      << std::string(tgfmt(attr.getConstBuilderTemplate(),
-                                           &fmtCtx, attr.getDefaultValue()))
-                      << ";\n";
+    os << "if (!tblgen_attr) tblgen_attr = "
+       << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
+                            attr.getDefaultValue()))
+       << ";\n";
   } else if (attr.isOptional()) {
     // For a missing attribute that is optional according to definition, we
     // should just capture a mlir::Attribute() to signal the missing state.
@@ -384,27 +380,23 @@
   auto name = tree.getArgName(argIndex);
   // `$_` is a special symbol to ignore op argument matching.
   if (!name.empty() && name != "_") {
-    os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
+    os << formatv("{0} = tblgen_attr;\n", name);
   }
 
-  indent -= 2;
-  os.indent(indent) << "}\n";
+  os.decreaseIndent() << "}\n";
 }
 
 void PatternEmitter::emitMatchCheck(
     int depth, const FmtObjectBase &matchFmt,
     const llvm::formatv_object_base &failureFmt) {
-  // {0} The match depth (used to get the operation that failed to match).
-  // {1} The format for the match string.
-  // {2} The format for the failure string.
-  const char *matchStr = R"(
-    if (!({1})) {
-      return rewriter.notifyMatchFailure(op{0}, [&](::mlir::Diagnostic &diag) {
-        diag << {2};
-      });
-    })";
-  os << llvm::formatv(matchStr, depth, matchFmt.str(), failureFmt.str())
-     << "\n";
+  os << "if (!(" << matchFmt.str() << ")) {\n";
+  {
+    auto ifRegion = os.region();
+    os << "return rewriter.notifyMatchFailure(op" << depth
+       << ", [&](::mlir::Diagnostic &diag) {\n  diag << " << failureFmt.str()
+       << ";\n});";
+  }
+  os << "\n}\n";
 }
 
 void PatternEmitter::emitMatchLogic(DagNode tree) {
@@ -488,7 +480,7 @@
 
   // Emit RewritePattern for Pattern.
   auto locs = pattern.getLocation();
-  os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
+  os << formatv("/* Generated from:\n    {0:$[ instantiating\n    ]}\n*/\n",
                 make_range(locs.rbegin(), locs.rend()));
   os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
   {0}(::mlir::MLIRContext *context)
@@ -506,44 +498,48 @@
   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
 
   // Emit matchAndRewrite() function.
-  os << R"(
-  ::mlir::LogicalResult
-  matchAndRewrite(::mlir::Operation *op0,
-                  ::mlir::PatternRewriter &rewriter) const override {
-)";
-
-  // Register all symbols bound in the source pattern.
-  pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
-
-  LLVM_DEBUG(
-      llvm::dbgs() << "start creating local variables for capturing matches\n");
-  os.indent(4) << "// Variables for capturing values and attributes used for "
-                  "creating ops\n";
-  // Create local variables for storing the arguments and results bound
-  // to symbols.
-  for (const auto &symbolInfoPair : symbolInfoMap) {
-    StringRef symbol = symbolInfoPair.getKey();
-    auto &info = symbolInfoPair.getValue();
-    os.indent(4) << info.getVarDecl(symbol);
+  {
+    auto classRegion = os.region();
+    os.reindent(R"(
+      ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
+          ::mlir::PatternRewriter &rewriter) const override {)")
+        << '\n';
+    {
+      auto functionRegion = os.region();
+
+      // Register all symbols bound in the source pattern.
+      pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
+
+      LLVM_DEBUG(llvm::dbgs()
+                 << "start creating local variables for capturing matches\n");
+      os << "// Variables for capturing values and attributes used while "
+            "creating ops\n";
+      // Create local variables for storing the arguments and results bound
+      // to symbols.
+      for (const auto &symbolInfoPair : symbolInfoMap) {
+        StringRef symbol = symbolInfoPair.getKey();
+        auto &info = symbolInfoPair.getValue();
+        os << info.getVarDecl(symbol);
+      }
+      // TODO: capture ops with consistent numbering so that it can be
+      // reused for fused loc.
+      os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
+                    pattern.getSourcePattern().getNumOps());
+      LLVM_DEBUG(llvm::dbgs()
+                 << "done creating local variables for capturing matches\n");
+
+      os << "// Match\n";
+      os << "tblgen_ops[0] = op0;\n";
+      emitMatchLogic(sourceTree);
+
+      os << "\n// Rewrite\n";
+      emitRewriteLogic();
+
+      os << "return success();\n";
+    }
+    os << "};\n";
   }
-  // TODO: capture ops with consistent numbering so that it can be
-  // reused for fused loc.
-  os.indent(4) << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
-                          pattern.getSourcePattern().getNumOps());
-  LLVM_DEBUG(
-      llvm::dbgs() << "done creating local variables for capturing matches\n");
-
-  os.indent(4) << "// Match\n";
-  os.indent(4) << "tblgen_ops[0] = op0;\n";
-  emitMatchLogic(sourceTree);
-  os << "\n";
-
-  os.indent(4) << "// Rewrite\n";
-  emitRewriteLogic();
-
-  os.indent(4) << "return success();\n";
-  os << "  };\n";
-  os << "};\n";
+  os << "};\n\n";
 }
 
 void PatternEmitter::emitRewriteLogic() {
@@ -583,7 +579,7 @@
     PrintFatalError(loc, error);
   }
 
-  os.indent(4) << "auto odsLoc = rewriter.getFusedLoc({";
+  os << "auto odsLoc = rewriter.getFusedLoc({";
   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
   }
@@ -598,22 +594,21 @@
     // we are handling auxiliary patterns so we want the side effect even if
     // NativeCodeCall is not replacing matched root op's results.
     if (resultTree.isNativeCodeCall())
-      os.indent(4) << val << ";\n";
+      os << val << ";\n";
   }
 
   if (numExpectedResults == 0) {
     assert(replStartIndex >= numResultPatterns &&
            "invalid auxiliary vs. replacement pattern division!");
     // No result to replace. Just erase the op.
-    os.indent(4) << "rewriter.eraseOp(op0);\n";
+    os << "rewriter.eraseOp(op0);\n";
   } else {
     // Process replacement result patterns.
-    os.indent(4)
-        << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
+    os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
     for (int i = replStartIndex; i < numResultPatterns; ++i) {
       DagNode resultTree = pattern.getResultPattern(i);
       auto val = handleResultPattern(resultTree, offsets[i], 0);
-      os.indent(4) << "\n";
+      os << "\n";
       // Resolve each symbol for all range use so that we can loop over them.
       // We need an explicit cast to `SmallVector` to capture the cases where
       // `{0}` resolves to an `Operation::result_range` as well as cases that
@@ -622,12 +617,11 @@
       // TODO: Revisit the need for materializing a vector.
       os << symbolInfoMap.getAllRangeUse(
           val,
-          "    for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{ "
-          "tblgen_repl_values.push_back(v); }",
+          "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
+          "  tblgen_repl_values.push_back(v);\n}\n",
           "\n");
     }
-    os.indent(4) << "\n";
-    os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n";
+    os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
   }
 
   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
@@ -863,9 +857,8 @@
   }
 
   // Create the local variable for this op.
-  os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
-                          valuePackName);
-  os.indent(4) << "{\n";
+  os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
+                valuePackName);
 
   // Right now ODS don't have general type inference support. Except a few
   // special cases listed below, DRR needs to supply types for all results
@@ -883,10 +876,10 @@
     createAggregateLocalVarsForOpArgs(tree, childNodeNames);
 
     // Then create the op.
-    os.indent(6) << formatv(
+    os.region().os << formatv(
         "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);\n",
         valuePackName, resultOp.getQualCppClassName(), locToUse);
-    os.indent(4) << "}\n";
+    os << "}\n";
     return resultValue;
   }
 
@@ -903,11 +896,10 @@
   // aggregate-parameter builders.
   createSeparateLocalVarsForOpArgs(tree, childNodeNames);
 
-  os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
-                          resultOp.getQualCppClassName(), locToUse);
+  os.region().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
+                            resultOp.getQualCppClassName(), locToUse);
   supplyValuesForOpArgs(tree, childNodeNames);
-  os << "\n  );\n";
-  os.indent(4) << "}\n";
+  os << "\n  );\n}\n";
   return resultValue;
 }
 
@@ -921,20 +913,20 @@
 
   // Then prepare the result types. We need to specify the types for all
   // results.
-  os.indent(6) << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
-                          "(void)tblgen_types;\n");
+  os.increaseIndent() << formatv(
+      "::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
+      "(void)tblgen_types;\n");
   int numResults = resultOp.getNumResults();
   if (numResults != 0) {
     for (int i = 0; i < numResults; ++i)
-      os.indent(6) << formatv("for (auto v : castedOp0.getODSResults({0})) {{"
-                              "tblgen_types.push_back(v.getType()); }\n",
-                              resultIndex + i);
+      os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
+                    "  tblgen_types.push_back(v.getType());\n}\n",
+                    resultIndex + i);
   }
-  os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
-                          "tblgen_values, tblgen_attrs);\n",
-                          valuePackName, resultOp.getQualCppClassName(),
-                          locToUse);
-  os.indent(4) << "}\n";
+  os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
+                "tblgen_values, tblgen_attrs);\n",
+                valuePackName, resultOp.getQualCppClassName(), locToUse);
+  os.decreaseIndent() << "}\n";
   return resultValue;
 }
 
@@ -951,16 +943,15 @@
   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
     const auto *operand =
         resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
-    if (!operand) {
-      // We do not need special handling for attributes.
+    // We do not need special handling for attributes.
+    if (!operand)
       continue;
-    }
+    auto region = os.region();
 
     std::string varName;
     if (operand->isVariadic()) {
       varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
-      os.indent(6) << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n",
-                              varName);
+      os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName);
       std::string range;
       if (node.isNestedDagArg(argIndex)) {
         range = childNodeNames[argIndex];
@@ -970,11 +961,11 @@
       // Resolve the symbol for all range use so that we have a uniform way of
       // capturing the values.
       range = symbolInfoMap.getValueAndRangeUse(range);
-      os.indent(6) << formatv("for (auto v : {0}) {1}.push_back(v);\n", range,
-                              varName);
+      os << formatv("for (auto v: {0}) {{\n  {1}.push_back(v);\n}\n", range,
+                    varName);
     } else {
       varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
-      os.indent(6) << formatv("::mlir::Value {0} = ", varName);
+      os << formatv("::mlir::Value {0} = ", varName);
       if (node.isNestedDagArg(argIndex)) {
         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
       } else {
@@ -1002,7 +993,7 @@
   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
        argIndex != numOpArgs; ++argIndex) {
     // Start each argument on its own line.
-    (os << ",\n").indent(8);
+    os << ",\n    ";
 
     Argument opArg = resultOp.getArg(argIndex);
     // Handle the case of operand first.
@@ -1043,14 +1034,16 @@
     DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
   Operator &resultOp = node.getDialectOp(opMap);
 
-  os.indent(6) << formatv("::mlir::SmallVector<::mlir::Value, 4> "
-                          "tblgen_values; (void)tblgen_values;\n");
-  os.indent(6) << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
-                          "tblgen_attrs; (void)tblgen_attrs;\n");
+  auto region = os.region();
+  os << formatv("::mlir::SmallVector<::mlir::Value, 4> "
+                "tblgen_values; (void)tblgen_values;\n");
+  os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
+                "tblgen_attrs; (void)tblgen_attrs;\n");
 
   const char *addAttrCmd =
-      "if (auto tmpAttr = {1}) "
-      "tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), tmpAttr);\n";
+      "if (auto tmpAttr = {1}) {{\n"
+      "  tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
+      "tmpAttr);\n}\n";
   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
     if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
       // The argument in the op definition.
@@ -1059,14 +1052,14 @@
       if (!subTree.isNativeCodeCall())
         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                              "for creating attribute");
-      os.indent(6) << formatv(addAttrCmd, opArgName,
-                              handleReplaceWithNativeCodeCall(subTree));
+      os << formatv(addAttrCmd, opArgName,
+                    handleReplaceWithNativeCodeCall(subTree));
     } else {
       auto leaf = node.getArgAsLeaf(argIndex);
       // The argument in the result DAG pattern.
       auto patArgName = node.getArgName(argIndex);
-      os.indent(6) << formatv(addAttrCmd, opArgName,
-                              handleOpArgument(leaf, patArgName));
+      os << formatv(addAttrCmd, opArgName,
+                    handleOpArgument(leaf, patArgName));
     }
     continue;
   }
@@ -1084,10 +1077,10 @@
       // Resolve the symbol for all range use so that we have a uniform way of
       // capturing the values.
       range = symbolInfoMap.getValueAndRangeUse(range);
-      os.indent(6) << formatv(
-          "for (auto v : {0}) tblgen_values.push_back(v);\n", range);
+      os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
+                    range);
     } else {
-      os.indent(6) << formatv("tblgen_values.push_back(", varName);
+      os << formatv("tblgen_values.push_back(", varName);
       if (node.isNestedDagArg(argIndex)) {
         os << symbolInfoMap.getValueAndRangeUse(
             childNodeNames.lookup(argIndex));
diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Support/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_unittest(MLIRSupportTests
+  IndentedOstreamTest.cpp
+)
+
+target_link_libraries(MLIRSupportTests
+  PRIVATE MLIRSupportIndentedOstream MLIRSupport)
diff --git a/mlir/unittests/Support/IndentedOstreamTest.cpp b/mlir/unittests/Support/IndentedOstreamTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Support/IndentedOstreamTest.cpp
@@ -0,0 +1,114 @@
+//===- IndentedOstreamTest.cpp - Indented raw ostream Tests ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/IndentedOstream.h"
+#include "gmock/gmock.h"
+
+using namespace mlir;
+using ::testing::StrEq;
+
+TEST(FormatTest, SingleLine) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  raw_indented_ostream ros(os);
+  ros << 10;
+  ros.flush();
+  EXPECT_THAT(os.str(), StrEq("10"));
+}
+
+TEST(FormatTest, SimpleMultiLine) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  raw_indented_ostream ros(os);
+  ros << "a";
+  ros << "b";
+  ros << "\n";
+  ros << "c";
+  ros << "\n";
+  ros.flush();
+  EXPECT_THAT(os.str(), StrEq("ab\nc\n"));
+}
+
+TEST(FormatTest, SimpleMultiLineIndent) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  raw_indented_ostream ros(os);
+  ros.indent(2) << "a";
+  ros.indent(4) << "b";
+  ros << "\n";
+  ros << "c";
+  ros << "\n";
+  ros.flush();
+  EXPECT_THAT(os.str(), StrEq("  a    b\n    c\n"));
+}
+
+TEST(FormatTest, SingleRegion) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  raw_indented_ostream ros(os);
+  ros << "before\n";
+  {
+    raw_indented_ostream::Region region(ros);
+    ros << "inside " << 10;
+    ros << "\n   two\n";
+  }
+  ros << "after";
+  ros.flush();
+  const auto *expected =
+      R"(before
+  inside 10
+     two
+after)";
+  EXPECT_THAT(os.str(), StrEq(expected));
+
+  // Repeat the above with inline form.
+  str.clear();
+  ros << "before\n";
+  ros.region().os << "inside " << 10 << "\n   two\n";
+  ros << "after";
+  ros.flush();
+  EXPECT_THAT(os.str(), StrEq(expected));
+}
+
+TEST(FormatTest, Reindent) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  raw_indented_ostream ros(os);
+
+  // String to print with some additional empty lines at the start and lines
+  // with just spaces.
+  const auto *desc = R"(
+       
+       
+         First line
+                 second line
+                 
+                 
+  )";
+  ros.reindent(desc);
+  ros.flush();
+  const auto *expected =
+      R"(First line
+        second line
+
+
+)";
+  EXPECT_THAT(os.str(), StrEq(expected));
+}
+
+TEST(FormatTest, ReindentLessIndentedLine) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  raw_indented_ostream ros(os);
+
+  // A line indented less than the first non-empty line must keep all of its
+  // characters: the smallest indentation of the non-empty lines is removed.
+  ros.reindent("    first\n  second\n      third\n");
+  ros.flush();
+  EXPECT_THAT(os.str(), StrEq("  first\nsecond\n    third\n"));
+}