diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -248,14 +248,7 @@ p << getOperationName() << ' ' << getAttr(inType()); if (hasLenParams()) { // print the LEN parameters to a derived type in parens - p << '('; - p.printOperands(getLenParams()); - p << " : "; - mlir::interleaveComma(getLenParams(), p.getStream(), - [&](const auto &opnd) { - p.printType(opnd.getType()); - }); - p << ')'; + p << '(' << getLenParams() << " : " << getLenParams().getTypes() << ')'; } // print the shape of the allocation (if any); all must be index type for (auto sh : getShapeOperands()) { diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -50,6 +50,10 @@ template using IterOfRange = decltype(std::begin(std::declval())); +template +using ValueOfRange = typename std::remove_reference()))>::type; + } // end namespace detail //===----------------------------------------------------------------------===// @@ -1658,6 +1662,69 @@ replace(Cont, ContIt, ContEnd, R.begin(), R.end()); } +/// An STL-style algorithm similar to std::for_each that applies a second +/// functor between every pair of elements. +/// +/// This provides the control flow logic to, for example, print a +/// comma-separated list: +/// \code +/// interleave(names.begin(), names.end(), +/// [&](StringRef name) { os << name; }, +/// [&] { os << ", "; }); +/// \endcode +template ::value && + !std::is_constructible::value>::type> +inline void interleave(ForwardIterator begin, ForwardIterator end, + UnaryFunctor each_fn, NullaryFunctor between_fn) { + if (begin == end) + return; + each_fn(*begin); + ++begin; + for (; begin != end; ++begin) { + between_fn(); + each_fn(*begin); + } +} + +template ::value && + !std::is_constructible::value>::type> +inline void interleave(const Container &c, UnaryFunctor each_fn, + NullaryFunctor between_fn) { + interleave(c.begin(), c.end(), each_fn, between_fn); +} + +/// Overload of interleave for the common case of string separator. +template > +inline void interleave(const Container &c, StreamT &os, UnaryFunctor each_fn, + const StringRef &separator) { + interleave(c.begin(), c.end(), each_fn, [&] { os << separator; }); +} +template > +inline void interleave(const Container &c, StreamT &os, + const StringRef &separator) { + interleave( + c, os, [&](const T &a) { os << a; }, separator); +} + +template > +inline void interleaveComma(const Container &c, StreamT &os, + UnaryFunctor each_fn) { + interleave(c, os, each_fn, ", "); +} +template > +inline void interleaveComma(const Container &c, StreamT &os) { + interleaveComma(c, os, [&](const T &a) { os << a; }); +} + //===----------------------------------------------------------------------===// // Extra additions to //===----------------------------------------------------------------------===// diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md --- a/mlir/docs/Tutorials/Toy/Ch-6.md +++ b/mlir/docs/Tutorials/Toy/Ch-6.md @@ -61,7 +61,7 @@ ```c++ mlir::ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); target.addLegalOp(); ``` diff --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md --- a/mlir/docs/Tutorials/Toy/Ch-7.md +++ b/mlir/docs/Tutorials/Toy/Ch-7.md @@ -319,7 +319,7 @@ // Print the struct type according to the parser format. printer << "struct<"; - mlir::interleaveComma(structType.getElementTypes(), printer); + llvm::interleaveComma(structType.getElementTypes(), printer); printer << '>'; } ``` diff --git a/mlir/examples/toy/Ch1/parser/AST.cpp b/mlir/examples/toy/Ch1/parser/AST.cpp --- a/mlir/examples/toy/Ch1/parser/AST.cpp +++ b/mlir/examples/toy/Ch1/parser/AST.cpp @@ -127,12 +127,12 @@ // Print the dimension for this literal first llvm::errs() << "<"; - mlir::interleaveComma(literal->getDims(), llvm::errs()); + llvm::interleaveComma(literal->getDims(), llvm::errs()); llvm::errs() << ">"; // Now print the content, recursing on every element of the list llvm::errs() << "[ "; - mlir::interleaveComma(literal->getValues(), llvm::errs(), + llvm::interleaveComma(literal->getValues(), llvm::errs(), [&](auto &elt) { printLitHelper(elt.get()); }); llvm::errs() << "]"; } @@ -194,7 +194,7 @@ /// Print type: only the shape is printed in between '<' and '>' void ASTDumper::dump(const VarType &type) { llvm::errs() << "<"; - mlir::interleaveComma(type.shape, llvm::errs()); + llvm::interleaveComma(type.shape, llvm::errs()); llvm::errs() << ">"; } @@ -205,7 +205,7 @@ llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; indent(); llvm::errs() << "Params: ["; - mlir::interleaveComma(node->getArgs(), llvm::errs(), + llvm::interleaveComma(node->getArgs(), llvm::errs(), [](auto &arg) { llvm::errs() << arg->getName(); }); llvm::errs() << "]\n"; } diff --git a/mlir/examples/toy/Ch2/parser/AST.cpp b/mlir/examples/toy/Ch2/parser/AST.cpp --- a/mlir/examples/toy/Ch2/parser/AST.cpp +++ b/mlir/examples/toy/Ch2/parser/AST.cpp @@ -127,12 +127,12 @@ // Print the dimension for this literal first llvm::errs() << "<"; - mlir::interleaveComma(literal->getDims(), llvm::errs()); + llvm::interleaveComma(literal->getDims(), llvm::errs()); llvm::errs() << ">"; // Now print the content, recursing on every element of the list llvm::errs() << "[ "; - mlir::interleaveComma(literal->getValues(), llvm::errs(), + llvm::interleaveComma(literal->getValues(), llvm::errs(), [&](auto &elt) { printLitHelper(elt.get()); }); llvm::errs() << "]"; } @@ -194,7 +194,7 @@ /// Print type: only the shape is printed in between '<' and '>' void ASTDumper::dump(const VarType &type) { llvm::errs() << "<"; - mlir::interleaveComma(type.shape, llvm::errs()); + llvm::interleaveComma(type.shape, llvm::errs()); llvm::errs() << ">"; } @@ -205,7 +205,7 @@ llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; indent(); llvm::errs() << "Params: ["; - mlir::interleaveComma(node->getArgs(), llvm::errs(), + llvm::interleaveComma(node->getArgs(), llvm::errs(), [](auto &arg) { llvm::errs() << arg->getName(); }); llvm::errs() << "]\n"; } diff --git a/mlir/examples/toy/Ch3/parser/AST.cpp b/mlir/examples/toy/Ch3/parser/AST.cpp --- a/mlir/examples/toy/Ch3/parser/AST.cpp +++ b/mlir/examples/toy/Ch3/parser/AST.cpp @@ -127,12 +127,12 @@ // Print the dimension for this literal first llvm::errs() << "<"; - mlir::interleaveComma(literal->getDims(), llvm::errs()); + llvm::interleaveComma(literal->getDims(), llvm::errs()); llvm::errs() << ">"; // Now print the content, recursing on every element of the list llvm::errs() << "[ "; - mlir::interleaveComma(literal->getValues(), llvm::errs(), + llvm::interleaveComma(literal->getValues(), llvm::errs(), [&](auto &elt) { printLitHelper(elt.get()); }); llvm::errs() << "]"; } @@ -194,7 +194,7 @@ /// Print type: only the shape is printed in between '<' and '>' void ASTDumper::dump(const VarType &type) { llvm::errs() << "<"; - mlir::interleaveComma(type.shape, llvm::errs()); + llvm::interleaveComma(type.shape, llvm::errs()); llvm::errs() << ">"; } @@ -205,7 +205,7 @@ llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; indent(); llvm::errs() << "Params: ["; - mlir::interleaveComma(node->getArgs(), llvm::errs(), + llvm::interleaveComma(node->getArgs(), llvm::errs(), [](auto &arg) { llvm::errs() << arg->getName(); }); llvm::errs() << "]\n"; } diff --git a/mlir/examples/toy/Ch4/parser/AST.cpp b/mlir/examples/toy/Ch4/parser/AST.cpp --- a/mlir/examples/toy/Ch4/parser/AST.cpp +++ b/mlir/examples/toy/Ch4/parser/AST.cpp @@ -127,12 +127,12 @@ // Print the dimension for this literal first llvm::errs() << "<"; - mlir::interleaveComma(literal->getDims(), llvm::errs()); + llvm::interleaveComma(literal->getDims(), llvm::errs()); llvm::errs() << ">"; // Now print the content, recursing on every element of the list llvm::errs() << "[ "; - mlir::interleaveComma(literal->getValues(), llvm::errs(), + llvm::interleaveComma(literal->getValues(), llvm::errs(), [&](auto &elt) { printLitHelper(elt.get()); }); llvm::errs() << "]"; } @@ -194,7 +194,7 @@ /// Print type: only the shape is printed in between '<' and '>' void ASTDumper::dump(const VarType &type) { llvm::errs() << "<"; - mlir::interleaveComma(type.shape, llvm::errs()); + llvm::interleaveComma(type.shape, llvm::errs()); llvm::errs() << ">"; } @@ -205,7 +205,7 @@ llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; indent(); llvm::errs() << "Params: ["; - mlir::interleaveComma(node->getArgs(), llvm::errs(), + llvm::interleaveComma(node->getArgs(), llvm::errs(), [](auto &arg) { llvm::errs() << arg->getName(); }); llvm::errs() << "]\n"; } diff --git a/mlir/examples/toy/Ch5/parser/AST.cpp b/mlir/examples/toy/Ch5/parser/AST.cpp --- a/mlir/examples/toy/Ch5/parser/AST.cpp +++ b/mlir/examples/toy/Ch5/parser/AST.cpp @@ -127,12 +127,12 @@ // Print the dimension for this literal first llvm::errs() << "<"; - mlir::interleaveComma(literal->getDims(), llvm::errs()); + llvm::interleaveComma(literal->getDims(), llvm::errs()); llvm::errs() << ">"; // Now print the content, recursing on every element of the list llvm::errs() << "[ "; - mlir::interleaveComma(literal->getValues(), llvm::errs(), + llvm::interleaveComma(literal->getValues(), llvm::errs(), [&](auto &elt) { printLitHelper(elt.get()); }); llvm::errs() << "]"; } @@ -194,7 +194,7 @@ /// Print type: only the shape is printed in between '<' and '>' void ASTDumper::dump(const VarType &type) { llvm::errs() << "<"; - mlir::interleaveComma(type.shape, llvm::errs()); + llvm::interleaveComma(type.shape, llvm::errs()); llvm::errs() << ">"; } @@ -205,7 +205,7 @@ llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; indent(); llvm::errs() << "Params: ["; - mlir::interleaveComma(node->getArgs(), llvm::errs(), + llvm::interleaveComma(node->getArgs(), llvm::errs(), [](auto &arg) { llvm::errs() << arg->getName(); }); llvm::errs() << "]\n"; } diff --git a/mlir/examples/toy/Ch6/parser/AST.cpp b/mlir/examples/toy/Ch6/parser/AST.cpp --- a/mlir/examples/toy/Ch6/parser/AST.cpp +++ b/mlir/examples/toy/Ch6/parser/AST.cpp @@ -127,12 +127,12 @@ // Print the dimension for this literal first llvm::errs() << "<"; - mlir::interleaveComma(literal->getDims(), llvm::errs()); + llvm::interleaveComma(literal->getDims(), llvm::errs()); llvm::errs() << ">"; // Now print the content, recursing on every element of the list llvm::errs() << "[ "; - mlir::interleaveComma(literal->getValues(), llvm::errs(), + llvm::interleaveComma(literal->getValues(), llvm::errs(), [&](auto &elt) { printLitHelper(elt.get()); }); llvm::errs() << "]"; } @@ -194,7 +194,7 @@ /// Print type: only the shape is printed in between '<' and '>' void ASTDumper::dump(const VarType &type) { llvm::errs() << "<"; - mlir::interleaveComma(type.shape, llvm::errs()); + llvm::interleaveComma(type.shape, llvm::errs()); llvm::errs() << ">"; } @@ -205,7 +205,7 @@ llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; indent(); llvm::errs() << "Params: ["; - mlir::interleaveComma(node->getArgs(), llvm::errs(), + llvm::interleaveComma(node->getArgs(), llvm::errs(), [](auto &arg) { llvm::errs() << arg->getName(); }); llvm::errs() << "]\n"; } diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -537,7 +537,7 @@ // Print the struct type according to the parser format. printer << "struct<"; - mlir::interleaveComma(structType.getElementTypes(), printer); + llvm::interleaveComma(structType.getElementTypes(), printer); printer << '>'; } diff --git a/mlir/examples/toy/Ch7/parser/AST.cpp b/mlir/examples/toy/Ch7/parser/AST.cpp --- a/mlir/examples/toy/Ch7/parser/AST.cpp +++ b/mlir/examples/toy/Ch7/parser/AST.cpp @@ -130,12 +130,12 @@ // Print the dimension for this literal first llvm::errs() << "<"; - mlir::interleaveComma(literal->getDims(), llvm::errs()); + llvm::interleaveComma(literal->getDims(), llvm::errs()); llvm::errs() << ">"; // Now print the content, recursing on every element of the list llvm::errs() << "[ "; - mlir::interleaveComma(literal->getValues(), llvm::errs(), + llvm::interleaveComma(literal->getValues(), llvm::errs(), [&](auto &elt) { printLitHelper(elt.get()); }); llvm::errs() << "]"; } @@ -210,7 +210,7 @@ if (!type.name.empty()) llvm::errs() << type.name; else - mlir::interleaveComma(type.shape, llvm::errs()); + llvm::interleaveComma(type.shape, llvm::errs()); llvm::errs() << ">"; } @@ -221,7 +221,7 @@ llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n"; indent(); llvm::errs() << "Params: ["; - mlir::interleaveComma(node->getArgs(), llvm::errs(), + llvm::interleaveComma(node->getArgs(), llvm::errs(), [](auto &arg) { llvm::errs() << arg->getName(); }); llvm::errs() << "]\n"; } diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -232,9 +232,8 @@ /// is ','. template class Container> Diagnostic &appendRange(const Container &c, const char *delim = ", ") { - interleave( - c, [&](const detail::ValueOfRange> &a) { *this << a; }, - [&]() { *this << delim; }); + llvm::interleave( + c, [this](const auto &a) { *this << a; }, [&]() { *this << delim; }); return *this; } diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -117,7 +117,7 @@ (*types.begin()).template isa(); if (wrapped) os << '('; - interleaveComma(types, *this); + llvm::interleaveComma(types, *this); if (wrapped) os << ')'; } @@ -131,7 +131,7 @@ void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) { auto &os = getStream(); os << "("; - interleaveComma(inputs, *this); + llvm::interleaveComma(inputs, *this); os << ")"; printArrowTypeList(results); } @@ -199,11 +199,11 @@ template inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const ValueTypeRange &types) { - interleaveComma(types, p); + llvm::interleaveComma(types, p); return p; } inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef types) { - interleaveComma(types, p); + llvm::interleaveComma(types, p); return p; } diff --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h --- a/mlir/include/mlir/Pass/PassOptions.h +++ b/mlir/include/mlir/Pass/PassOptions.h @@ -194,7 +194,7 @@ auto printElementFn = [&](const DataType &value) { printValue(os, this->getParser(), value); }; - interleave(*this, os, printElementFn, ","); + llvm::interleave(*this, os, printElementFn, ","); } /// Copy the value from the given option into this one. diff --git a/mlir/include/mlir/Support/STLExtras.h b/mlir/include/mlir/Support/STLExtras.h --- a/mlir/include/mlir/Support/STLExtras.h +++ b/mlir/include/mlir/Support/STLExtras.h @@ -19,75 +19,6 @@ namespace mlir { -namespace detail { -template -using ValueOfRange = typename std::remove_reference()))>::type; -} // end namespace detail - -/// An STL-style algorithm similar to std::for_each that applies a second -/// functor between every pair of elements. -/// -/// This provides the control flow logic to, for example, print a -/// comma-separated list: -/// \code -/// interleave(names.begin(), names.end(), -/// [&](StringRef name) { os << name; }, -/// [&] { os << ", "; }); -/// \endcode -template ::value && - !std::is_constructible::value>::type> -inline void interleave(ForwardIterator begin, ForwardIterator end, - UnaryFunctor each_fn, NullaryFunctor between_fn) { - if (begin == end) - return; - each_fn(*begin); - ++begin; - for (; begin != end; ++begin) { - between_fn(); - each_fn(*begin); - } -} - -template ::value && - !std::is_constructible::value>::type> -inline void interleave(const Container &c, UnaryFunctor each_fn, - NullaryFunctor between_fn) { - interleave(c.begin(), c.end(), each_fn, between_fn); -} - -/// Overload of interleave for the common case of string separator. -template > -inline void interleave(const Container &c, raw_ostream &os, - UnaryFunctor each_fn, const StringRef &separator) { - interleave(c.begin(), c.end(), each_fn, [&] { os << separator; }); -} -template > -inline void interleave(const Container &c, raw_ostream &os, - const StringRef &separator) { - interleave( - c, os, [&](const T &a) { os << a; }, separator); -} - -template > -inline void interleaveComma(const Container &c, raw_ostream &os, - UnaryFunctor each_fn) { - interleave(c, os, each_fn, ", "); -} -template > -inline void interleaveComma(const Container &c, raw_ostream &os) { - interleaveComma(c, os, [&](const T &a) { os << a; }); -} - } // end namespace mlir #endif // MLIR_SUPPORT_STLEXTRAS_H diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2304,7 +2304,7 @@ } if (!elideSteps) { p << " step ("; - interleaveComma(steps, p); + llvm::interleaveComma(steps, p); p << ')'; } p.printRegion(op.region(), /*printEntryBlockArgs=*/false, diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -610,8 +610,8 @@ return; p << ' ' << keyword << '('; - interleaveComma(values, p, - [&p](BlockArgument v) { p << v << " : " << v.getType(); }); + llvm::interleaveComma( + values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); }); p << ')'; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1056,7 +1056,7 @@ appendMangledType(ss, memref.getElementType()); } else if (auto vec = t.dyn_cast()) { ss << "vector"; - interleave( + llvm::interleave( vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; }); appendMangledType(ss, vec.getElementType()); } else if (t.isSignlessIntOrIndexOrFloat()) { @@ -1074,7 +1074,7 @@ llvm::raw_string_ostream ss(name); ss << "_"; auto types = op->getOperandTypes(); - interleave( + llvm::interleave( types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); }, [&]() { ss << "_"; }); return ss.str(); diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp --- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp +++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp @@ -107,7 +107,7 @@ auto regionArgs = op.getRegionIterArgs(); auto operands = op.getIterOperands(); - mlir::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { + llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); p << ")"; diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -354,7 +354,7 @@ ArrayRef scales = type.getScales(); ArrayRef zeroPoints = type.getZeroPoints(); out << "{"; - interleave( + llvm::interleave( llvm::seq(0, scales.size()), out, [&](size_t index) { printQuantParams(scales[index], zeroPoints[index], out); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -566,12 +566,12 @@ auto eachFn = [&os](spirv::Decoration decoration) { os << stringifyDecoration(decoration); }; - interleaveComma(decorations, os, eachFn); + llvm::interleaveComma(decorations, os, eachFn); os << "]"; } }; - interleaveComma(llvm::seq(0, type.getNumElements()), os, - printMember); + llvm::interleaveComma(llvm::seq(0, type.getNumElements()), os, + printMember); os << ">"; } @@ -764,11 +764,11 @@ auto &os = printer.getStream(); printer << spirv::VerCapExtAttr::getKindName() << "<" << spirv::stringifyVersion(triple.getVersion()) << ", ["; - interleaveComma(triple.getCapabilities(), os, [&](spirv::Capability cap) { - os << spirv::stringifyCapability(cap); - }); + llvm::interleaveComma( + triple.getCapabilities(), os, + [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); }); printer << "], ["; - interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) { + llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) { os << attr.cast().getValue(); }); printer << "]>"; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1064,7 +1064,7 @@ if (auto weights = branchOp.branch_weights()) { printer << " ["; - interleaveComma(weights->getValue(), printer, [&](Attribute a) { + llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) { printer << a.cast().getInt(); }); printer << "]"; @@ -1465,7 +1465,7 @@ auto interfaceVars = entryPointOp.interface().getValue(); if (!interfaceVars.empty()) { printer << ", "; - interleaveComma(interfaceVars, printer); + llvm::interleaveComma(interfaceVars, printer); } } @@ -1521,7 +1521,7 @@ if (!values.size()) return; printer << ", "; - interleaveComma(values, printer, [&](Attribute a) { + llvm::interleaveComma(values, printer, [&](Attribute a) { printer << a.cast().getInt(); }); } diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1512,7 +1512,7 @@ p.printOperands(op.getOperands()); p.printOptionalAttrDict(op.getAttrs()); p << " : "; - interleaveComma(op.getOperation()->getOperandTypes(), p); + llvm::interleaveComma(op.getOperation()->getOperandTypes(), p); } static LogicalResult verify(TupleOp op) { return success(); } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -933,7 +933,7 @@ template inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { - mlir::interleaveComma(c, os, each_fn); + llvm::interleaveComma(c, os, each_fn); } /// This enum describes the different kinds of elision for the type of an diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -216,11 +216,12 @@ types.size() > 1 || types[0].isa() || !attrs[0].empty(); if (needsParens) os << '('; - interleaveComma(llvm::zip(types, attrs), os, - [&](const std::tuple> &t) { - p.printType(std::get<0>(t)); - p.printOptionalAttrDict(std::get<1>(t)); - }); + llvm::interleaveComma( + llvm::zip(types, attrs), os, + [&](const std::tuple> &t) { + p.printType(std::get<0>(t)); + p.printOptionalAttrDict(std::get<1>(t)); + }); if (needsParens) os << ')'; } diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -52,11 +52,12 @@ void Pass::printAsTextualPipeline(raw_ostream &os) { // Special case for adaptors to use the 'op_name(sub_passes)' format. if (auto *adaptor = getAdaptorPassBase(this)) { - interleaveComma(adaptor->getPassManagers(), os, [&](OpPassManager &pm) { - os << pm.getOpName() << "("; - pm.printAsTextualPipeline(os); - os << ")"; - }); + llvm::interleaveComma(adaptor->getPassManagers(), os, + [&](OpPassManager &pm) { + os << pm.getOpName() << "("; + pm.printAsTextualPipeline(os); + os << ")"; + }); return; } // Otherwise, print the pass argument followed by its options. If the pass @@ -295,9 +296,10 @@ impl->passes, [](const std::unique_ptr &pass) { return !isa(pass); }); - interleaveComma(filteredPasses, os, [&](const std::unique_ptr &pass) { - pass->printAsTextualPipeline(os); - }); + llvm::interleaveComma(filteredPasses, os, + [&](const std::unique_ptr &pass) { + pass->printAsTextualPipeline(os); + }); } //===----------------------------------------------------------------------===// @@ -358,7 +360,7 @@ std::string OpToOpPassAdaptorBase::getName() { std::string name = "Pipeline Collection : ["; llvm::raw_string_ostream os(name); - interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) { + llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) { os << '\'' << pm.getOpName() << '\''; }); os << ']'; diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -184,7 +184,7 @@ // Interleave the options with ' '. os << '{'; - interleave( + llvm::interleave( orderedOps, os, [&](OptionBase *option) { option->print(os); }, " "); os << '}'; } diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1250,7 +1250,7 @@ auto &os = rewriterImpl.logger; os.getOStream() << "\n"; os.startLine() << "* Pattern : '" << pattern->getRootKind() << " -> ("; - interleaveComma(pattern->getGeneratedOps(), llvm::dbgs()); + llvm::interleaveComma(pattern->getGeneratedOps(), llvm::dbgs()); os.getOStream() << ")' {\n"; os.indent(); }); diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -65,7 +65,7 @@ } // Print resultant types - interleaveComma(op->getResultTypes(), os); + llvm::interleaveComma(op->getResultTypes(), os); os << "\n"; // A value used to elide large container attribute. diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp --- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp @@ -120,7 +120,7 @@ opInst->emitRemark("NOT MATCHED"); } else { outs << "\nmatched: " << *opInst << " with shape ratio: "; - interleaveComma(MutableArrayRef(*ratio), outs); + llvm::interleaveComma(MutableArrayRef(*ratio), outs); } } } diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -51,7 +51,7 @@ os << opName << " extensions: ["; for (const auto &exts : extension.getExtensions()) { os << " ["; - interleaveComma(exts, os, [&](spirv::Extension ext) { + llvm::interleaveComma(exts, os, [&](spirv::Extension ext) { os << spirv::stringifyExtension(ext); }); os << "]"; @@ -63,7 +63,7 @@ os << opName << " capabilities: ["; for (const auto &caps : capability.getCapabilities()) { os << " ["; - interleaveComma(caps, os, [&](spirv::Capability cap) { + llvm::interleaveComma(caps, os, [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); }); os << "]"; diff --git a/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp b/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp --- a/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp +++ b/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp @@ -38,7 +38,7 @@ else llvm::outs() << offset; llvm::outs() << " strides: "; - interleaveComma(strides, llvm::outs(), [&](int64_t v) { + llvm::interleaveComma(strides, llvm::outs(), [&](int64_t v) { if (v == MemRefType::getDynamicStrideOrOffset()) llvm::outs() << "?"; else diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1480,22 +1480,23 @@ std::string iteratorsStr; llvm::raw_string_ostream ss(iteratorsStr); unsigned pos = 0; - interleaveComma(state.dims, ss, [&](std::pair p) { - bool reduction = false; - for (auto &expr : state.expressions) { - visitPostorder(*expr, [&](const Expression &e) { - if (auto *pTensorExpr = dyn_cast(&e)) { - if (pTensorExpr->reductionDimensions.count(pos) > 0) - reduction = true; + llvm::interleaveComma( + state.dims, ss, [&](std::pair p) { + bool reduction = false; + for (auto &expr : state.expressions) { + visitPostorder(*expr, [&](const Expression &e) { + if (auto *pTensorExpr = dyn_cast(&e)) { + if (pTensorExpr->reductionDimensions.count(pos) > 0) + reduction = true; + } + }); + if (reduction) + break; } + ss << (reduction ? "getReductionIteratorTypeName()" + : "getParallelIteratorTypeName()"); + pos++; }); - if (reduction) - break; - } - ss << (reduction ? "getReductionIteratorTypeName()" - : "getParallelIteratorTypeName()"); - pos++; - }); ss.flush(); os << llvm::formatv(referenceReferenceIteratorsFmt, opId, iteratorsStr); @@ -1515,8 +1516,9 @@ std::string dimsStr; llvm::raw_string_ostream ss(dimsStr); - interleaveComma(state.dims, ss, - [&](std::pair p) { ss << p.second; }); + llvm::interleaveComma( + state.dims, ss, + [&](std::pair p) { ss << p.second; }); ss.flush(); std::string mapsStr; @@ -1524,7 +1526,7 @@ SmallVector orderedUses(state.orderedTensorArgs.size()); for (auto it : state.orderedTensorArgs) orderedUses[it.second] = it.first; - interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) { + llvm::interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) { assert(u.indexingMap); const char *mapFmt = "\n\tAffineMap::get({0}, 0, {1})"; if (u.indexingMap.isEmpty()) { @@ -1535,7 +1537,7 @@ std::string exprsStr; llvm::raw_string_ostream exprsStringStream(exprsStr); exprsStringStream << "{"; - interleaveComma(u.indexingMap.getResults(), exprsStringStream); + llvm::interleaveComma(u.indexingMap.getResults(), exprsStringStream); exprsStringStream << "}"; exprsStringStream.flush(); @@ -1563,10 +1565,10 @@ } else { std::string subExprs; llvm::raw_string_ostream subExprsStringStream(subExprs); - interleaveComma(pTensorExpr->expressions, subExprsStringStream, - [&](const std::unique_ptr &e) { - printExpr(subExprsStringStream, *e); - }); + llvm::interleaveComma(pTensorExpr->expressions, subExprsStringStream, + [&](const std::unique_ptr &e) { + printExpr(subExprsStringStream, *e); + }); subExprsStringStream.flush(); const char *tensorExprFmt = "\n ValueHandle _{0} = {1}({2});"; os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->opId, subExprs); @@ -1586,10 +1588,11 @@ unsigned idx = 0; std::string valueHandleStr; llvm::raw_string_ostream valueHandleStringStream(valueHandleStr); - interleaveComma(state.orderedTensorArgs, valueHandleStringStream, [&](auto) { - valueHandleStringStream << "_" << idx << "(args[" << idx << "])"; - idx++; - }); + llvm::interleaveComma( + state.orderedTensorArgs, valueHandleStringStream, [&](auto) { + valueHandleStringStream << "_" << idx << "(args[" << idx << "])"; + idx++; + }); std::string expressionsStr; llvm::raw_string_ostream expressionStringStream(expressionsStr); @@ -1601,10 +1604,10 @@ std::string yieldStr; llvm::raw_string_ostream yieldStringStream(yieldStr); - interleaveComma(state.expressions, yieldStringStream, - [&](const std::unique_ptr &e) { - printExpr(yieldStringStream, *e); - }); + llvm::interleaveComma(state.expressions, yieldStringStream, + [&](const std::unique_ptr &e) { + printExpr(yieldStringStream, *e); + }); valueHandleStringStream.flush(); expressionStringStream.flush(); diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp @@ -183,7 +183,7 @@ template void printBracketedRange(const Range &range, llvm::raw_ostream &os) { os << '['; - mlir::interleaveComma(range, os); + llvm::interleaveComma(range, os); os << ']'; } @@ -213,7 +213,7 @@ printBracketedRange(traits, os); os << ", " << (intr.getNumResults() == 0 ? 0 : 1) << ">, Arguments<(ins" << (operands.empty() ? "" : " "); - mlir::interleaveComma(operands, os); + llvm::interleaveComma(operands, os); os << ")>;\n\n"; return false; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1107,14 +1107,16 @@ body << " " << builderOpState << ".addAttribute(\"operand_segment_sizes\", " "odsBuilder->getI32VectorAttr({"; - interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { - if (op.getOperand(i).isOptional()) - body << "(" << getArgumentName(op, i) << " ? 1 : 0)"; - else if (op.getOperand(i).isVariadic()) - body << "static_cast(" << getArgumentName(op, i) << ".size())"; - else - body << "1"; - }); + llvm::interleaveComma( + llvm::seq(0, op.getNumOperands()), body, [&](int i) { + if (op.getOperand(i).isOptional()) + body << "(" << getArgumentName(op, i) << " ? 1 : 0)"; + else if (op.getOperand(i).isVariadic()) + body << "static_cast(" << getArgumentName(op, i) + << ".size())"; + else + body << "1"; + }); body << "}));\n"; } @@ -1212,7 +1214,7 @@ continue; std::string args; llvm::raw_string_ostream os(args); - mlir::interleaveComma(method.getArguments(), os, + llvm::interleaveComma(method.getArguments(), os, [&](const OpInterfaceMethod::Argument &arg) { os << arg.type << " " << arg.name; }); @@ -1766,7 +1768,7 @@ static void emitOpList(const std::vector &defs, raw_ostream &os) { IfDefScope scope("GET_OP_LIST", os); - interleave( + llvm::interleave( // TODO: We are constructing the Operator wrapper instance just for // getting it's qualified class name here. Reduce the overhead by having a // lightweight version of Operator class just for that purpose. diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -795,7 +795,7 @@ body << " if (parser.resolveOperands("; if (op.getNumOperands() > 1) { body << "llvm::concat("; - interleaveComma(op.getOperands(), body, [&](auto &operand) { + llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) { body << operand.name << "Operands"; }); body << ")"; @@ -815,11 +815,12 @@ // the case of a single range, so guard it here. if (op.getNumOperands() > 1) { body << "llvm::concat("; - interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { - body << "ArrayRef("; - emitTypeResolver(operandTypes[i], op.getOperand(i).name); - body << ")"; - }); + llvm::interleaveComma( + llvm::seq(0, op.getNumOperands()), body, [&](int i) { + body << "ArrayRef("; + emitTypeResolver(operandTypes[i], op.getOperand(i).name); + body << ")"; + }); body << ")"; } else { emitTypeResolver(operandTypes.front(), op.getOperand(0).name); @@ -875,7 +876,7 @@ else body << "1"; }; - interleaveComma(op.getOperands(), body, interleaveFn); + llvm::interleaveComma(op.getOperands(), body, interleaveFn); body << "}));\n"; } } @@ -897,7 +898,7 @@ // Elide the variadic segment size attributes if necessary. if (!fmt.allOperands && op.getTrait("OpTrait::AttrSizedOperandSegments")) body << "\"operand_segment_sizes\", "; - interleaveComma(usedAttributes, body, [&](const NamedAttribute *attr) { + llvm::interleaveComma(usedAttributes, body, [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; }); body << "});\n"; @@ -1016,13 +1017,13 @@ } else if (auto *successor = dyn_cast(element)) { const NamedSuccessor *var = successor->getVar(); if (var->isVariadic()) - body << " interleaveComma(" << var->name << "(), p);\n"; + body << " llvm::interleaveComma(" << var->name << "(), p);\n"; else body << " p << " << var->name << "();\n"; } else if (isa(element)) { body << " p << getOperation()->getOperands();\n"; } else if (isa(element)) { - body << " interleaveComma(getOperation()->getSuccessors(), p);\n"; + body << " llvm::interleaveComma(getOperation()->getSuccessors(), p);\n"; } else if (auto *dir = dyn_cast(element)) { body << " p << "; genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -36,10 +36,10 @@ os << method.getName() << '('; if (addOperationArg) os << "Operation *tablegen_opaque_op" << (method.arg_empty() ? "" : ", "); - interleaveComma(method.getArguments(), os, - [&](const OpInterfaceMethod::Argument &arg) { - os << arg.type << " " << arg.name; - }); + llvm::interleaveComma(method.getArguments(), os, + [&](const OpInterfaceMethod::Argument &arg) { + os << arg.type << " " << arg.name; + }); os << ')'; } @@ -72,7 +72,7 @@ os << " {\n return getImpl()->" << method.getName() << '('; if (!method.isStatic()) os << "getOperation()" << (method.arg_empty() ? "" : ", "); - interleaveComma( + llvm::interleaveComma( method.getArguments(), os, [&](const OpInterfaceMethod::Argument &arg) { os << arg.name; }); os << ");\n }\n"; @@ -135,7 +135,7 @@ // Add the arguments to the call. os << method.getName() << '('; - interleaveComma( + llvm::interleaveComma( method.getArguments(), os, [&](const OpInterfaceMethod::Argument &arg) { os << arg.name; }); os << ");\n }\n"; @@ -255,10 +255,10 @@ if (method.isStatic()) os << "static "; emitCPPType(method.getReturnType(), os) << method.getName() << '('; - interleaveComma(method.getArguments(), os, - [&](const OpInterfaceMethod::Argument &arg) { - emitCPPType(arg.type, os) << arg.name; - }); + llvm::interleaveComma(method.getArguments(), os, + [&](const OpInterfaceMethod::Argument &arg) { + emitCPPType(arg.type, os) << arg.name; + }); os << ");\n```\n"; // Emit the description. diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -500,7 +500,7 @@ llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { return lhs->getOperationName() < rhs->getOperationName(); }); - interleaveComma(sortedResultOps, os, [&](const Operator *op) { + llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) { os << '"' << op->getOperationName() << '"'; }); os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -1305,7 +1305,7 @@ os << " case Capability::" << enumerant.getSymbol() << ": {static const Capability implies[" << impliedCapsDefs.size() << "] = {"; - mlir::interleaveComma(impliedCapsDefs, os, [&](const Record *capDef) { + llvm::interleaveComma(impliedCapsDefs, os, [&](const Record *capDef) { os << "Capability::" << EnumAttrCase(capDef).getSymbol(); }); os << "}; return ArrayRef(implies, " << impliedCapsDefs.size()