diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -29,6 +29,7 @@
 class MLIRContext;
 class Operation;
 class OperationName;
+class OpPrintingFlags;
 class Type;
 class Value;
 
@@ -218,6 +219,10 @@
   Diagnostic &operator<<(Operation *val) {
     return *this << *val;
   }
+  /// Append an operation with the given printing flags. The operation is
+  /// always printed with a local scope, and large elements attributes are
+  /// elided using the default limit unless `flags` already specifies one.
+  Diagnostic &appendOp(Operation &val, const OpPrintingFlags &flags);
 
   /// Stream in a Value.
   Diagnostic &operator<<(Value val);
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -127,9 +127,23 @@
 
 /// Stream in an Operation.
 Diagnostic &Diagnostic::operator<<(Operation &val) {
+  return appendOp(val, OpPrintingFlags());
+}
+
+/// Append an operation printed with the given flags. The flags are adjusted
+/// so that the operation is always printed with a local scope and large
+/// elements attributes are elided; an elision limit already chosen by the
+/// caller is kept rather than overwritten.
+Diagnostic &Diagnostic::appendOp(Operation &val, const OpPrintingFlags &flags) {
   std::string str;
   llvm::raw_string_ostream os(str);
-  val.print(os, OpPrintingFlags().useLocalScope().elideLargeElementsAttrs());
+  OpPrintingFlags adjustedFlags(flags);
+  adjustedFlags.useLocalScope();
+  // Only fall back to the default elision limit when the caller did not set
+  // one; unconditionally calling elideLargeElementsAttrs() would clobber it.
+  if (!adjustedFlags.getLargeElementsAttrLimit())
+    adjustedFlags.elideLargeElementsAttrs();
+  val.print(os, adjustedFlags);
   return *this << os.str();
 }
 
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -276,17 +276,9 @@
 InFlightDiagnostic Operation::emitError(const Twine &message) {
   InFlightDiagnostic diag = mlir::emitError(getLoc(), message);
   if (getContext()->shouldPrintOpOnDiagnostic()) {
-    // Print out the operation explicitly here so that we can print the generic
-    // form.
-    // TODO: It would be nice if we could instead provide the
-    // specific printing flags when adding the operation as an argument to the
-    // diagnostic.
-    std::string printedOp;
-    {
-      llvm::raw_string_ostream os(printedOp);
-      print(os, OpPrintingFlags().printGenericOpForm().useLocalScope());
-    }
-    diag.attachNote(getLoc()) << "see current operation: " << printedOp;
+    diag.attachNote(getLoc())
+        .append("see current operation: ")
+        .appendOp(*this, OpPrintingFlags().printGenericOpForm());
   }
   return diag;
 }