diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -47,6 +47,9 @@ LocationMap *locationMap = nullptr); ~AsmState(); + /// Return the printer flags used when printing with this state. + const OpPrintingFlags &getPrinterFlags() const; + /// Return an instance of the internal implementation. Returns nullptr if the /// state has not been initialized. detail::AsmStateImpl &getImpl() { return *impl; } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -112,9 +112,8 @@ void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) { state->print(os, flags); } - void print(raw_ostream &os, AsmState &asmState, - OpPrintingFlags flags = llvm::None) { - state->print(os, asmState, flags); + void print(raw_ostream &os, AsmState &asmState) { + state->print(os, asmState); } /// Dump this operation. diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -192,8 +192,7 @@ bool isBeforeInBlock(Operation *other); void print(raw_ostream &os, const OpPrintingFlags &flags = llvm::None); - void print(raw_ostream &os, AsmState &state, - const OpPrintingFlags &flags = llvm::None); + void print(raw_ostream &os, AsmState &state); // Uses the flags in `state`. void dump(); //===--------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1216,6 +1216,9 @@ /// Get the state used for SSA names. SSANameState &getSSANameState() { return nameState; } + /// Get the flags that OperationPrinter uses when printing with this state. + const OpPrintingFlags &getPrinterFlags() const { return printerFlags; } + /// Register the location, line and column, within the buffer that the given /// operation was printed at.
void registerOperationLocation(Operation *op, unsigned line, unsigned col) { @@ -1247,6 +1250,10 @@ : impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {} AsmState::~AsmState() = default; +const OpPrintingFlags &AsmState::getPrinterFlags() const { + return impl->getPrinterFlags(); +} + //===----------------------------------------------------------------------===// // AsmPrinter::Impl //===----------------------------------------------------------------------===// @@ -2405,9 +2412,9 @@ using Impl = AsmPrinter::Impl; using Impl::printType; - explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags, - AsmStateImpl &state) - : Impl(os, flags, &state), OpAsmPrinter(static_cast<Impl &>(*this)) {} + explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state) + : Impl(os, state.getPrinterFlags(), &state), + OpAsmPrinter(static_cast<Impl &>(*this)) {} /// Print the given top-level operation. void printTopLevelOperation(Operation *op); @@ -2893,7 +2900,7 @@ if (!getParent() && !printerFlags.shouldUseLocalScope()) { AsmState state(this, printerFlags); state.getImpl().initializeAliases(this); - print(os, state, printerFlags); + print(os, state); return; } @@ -2914,12 +2921,11 @@ } while (true); AsmState state(op, printerFlags); - print(os, state, printerFlags); + print(os, state); } -void Operation::print(raw_ostream &os, AsmState &state, - const OpPrintingFlags &flags) { - OperationPrinter printer(os, flags, state.getImpl()); - if (!getParent() && !flags.shouldUseLocalScope()) +void Operation::print(raw_ostream &os, AsmState &state) { + OperationPrinter printer(os, state.getImpl()); + if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) printer.printTopLevelOperation(this); else printer.print(this); @@ -2944,7 +2950,7 @@ print(os, state); } void Block::print(raw_ostream &os, AsmState &state) { - OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this); + OperationPrinter(os, state.getImpl()).print(this); // Uses `state`'s flags. } void Block::dump() {
print(llvm::errs()); } @@ -2960,6 +2966,6 @@ printAsOperand(os, state); } void Block::printAsOperand(raw_ostream &os, AsmState &state) { - OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl()); + OperationPrinter printer(os, state.getImpl()); printer.printBlockName(this); } diff --git a/mlir/lib/Transforms/LocationSnapshot.cpp b/mlir/lib/Transforms/LocationSnapshot.cpp --- a/mlir/lib/Transforms/LocationSnapshot.cpp +++ b/mlir/lib/Transforms/LocationSnapshot.cpp @@ -27,7 +27,7 @@ // Print the IR to the stream, and collect the raw line+column information. AsmState::LocationMap opToLineCol; AsmState state(op, flags, &opToLineCol); - op->print(os, state, flags); + op->print(os, state); // `state` already carries `flags`. Builder builder(op->getContext()); Optional<StringAttr> tagIdentifier;