diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -13,6 +13,7 @@ #ifndef MLIR_IR_ASMSTATE_H_ #define MLIR_IR_ASMSTATE_H_ +#include "mlir/IR/OperationSupport.h" #include "mlir/Support/LLVM.h" #include @@ -41,7 +42,9 @@ /// Initialize the asm state at the level of the given operation. A location /// map may optionally be provided to be populated when printing. - AsmState(Operation *op, LocationMap *locationMap = nullptr); + AsmState(Operation *op, + const OpPrintingFlags &printerFlags = OpPrintingFlags(), + LocationMap *locationMap = nullptr); ~AsmState(); /// Return an instance of the internal implementation. Returns nullptr if the diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -204,9 +204,9 @@ /// take O(N) where N is the number of operations within the parent block. bool isBeforeInBlock(Operation *other); - void print(raw_ostream &os, OpPrintingFlags flags = llvm::None); + void print(raw_ostream &os, const OpPrintingFlags &flags = llvm::None); void print(raw_ostream &os, AsmState &state, - OpPrintingFlags flags = llvm::None); + const OpPrintingFlags &flags = llvm::None); void dump(); //===--------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -356,9 +356,9 @@ /// in the output, and trims down unnecessary output. class DummyAliasOperationPrinter : private OpAsmPrinter { public: - explicit DummyAliasOperationPrinter(const OpPrintingFlags &flags, + explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags, AliasInitializer &initializer) - : printerFlags(flags), initializer(initializer) {} + : printerFlags(printerFlags), initializer(initializer) {} /// Print the given operation. 
void print(Operation *op) { @@ -767,7 +767,7 @@ /// A sentinel value used for values with names set. enum : unsigned { NameSentinel = ~0U }; - SSANameState(Operation *op, + SSANameState(Operation *op, const OpPrintingFlags &printerFlags, DialectInterfaceCollection &interfaces); /// Print the SSA identifier for the given value to 'stream'. If @@ -833,14 +833,18 @@ /// This is the next ID to assign when a name conflict is detected. unsigned nextConflictID = 0; + /// These are the printing flags. They control, e.g., whether to print in + /// generic form. + OpPrintingFlags printerFlags; + DialectInterfaceCollection &interfaces; }; } // end anonymous namespace SSANameState::SSANameState( - Operation *op, + Operation *op, const OpPrintingFlags &printerFlags, DialectInterfaceCollection &interfaces) - : interfaces(interfaces) { + : printerFlags(printerFlags), interfaces(interfaces) { llvm::SaveAndRestore valueIDSaver(nextValueID); llvm::SaveAndRestore argumentIDSaver(nextArgumentID); llvm::SaveAndRestore conflictIDSaver(nextConflictID); @@ -1134,12 +1138,13 @@ namespace detail { class AsmStateImpl { public: - explicit AsmStateImpl(Operation *op, AsmState::LocationMap *locationMap) - : interfaces(op->getContext()), nameState(op, interfaces), - locationMap(locationMap) {} + explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags, + AsmState::LocationMap *locationMap) + : interfaces(op->getContext()), nameState(op, printerFlags, interfaces), + printerFlags(printerFlags), locationMap(locationMap) {} /// Initialize the alias state to enable the printing of aliases. - void initializeAliases(Operation *op, const OpPrintingFlags &printerFlags) { + void initializeAliases(Operation *op) { aliasState.initialize(op, printerFlags, interfaces); } @@ -1172,14 +1177,18 @@ /// The state used for SSA value names. SSANameState nameState; + /// Flags that control op output. + OpPrintingFlags printerFlags; + /// An optional location map to be populated. 
AsmState::LocationMap *locationMap; }; } // end namespace detail } // end namespace mlir -AsmState::AsmState(Operation *op, LocationMap *locationMap) - : impl(std::make_unique(op, locationMap)) {} +AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags, + LocationMap *locationMap) + : impl(std::make_unique(op, printerFlags, locationMap)) {} AsmState::~AsmState() {} //===----------------------------------------------------------------------===// @@ -2760,18 +2769,18 @@ os); } -void Operation::print(raw_ostream &os, OpPrintingFlags flags) { +void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { // If this is a top level operation, we also print aliases. - if (!getParent() && !flags.shouldUseLocalScope()) { - AsmState state(this); - state.getImpl().initializeAliases(this, flags); - print(os, state, flags); + if (!getParent() && !printerFlags.shouldUseLocalScope()) { + AsmState state(this, printerFlags); + state.getImpl().initializeAliases(this); + print(os, state, printerFlags); return; } // Find the operation to number from based upon the provided flags. Operation *op = this; - bool shouldUseLocalScope = flags.shouldUseLocalScope(); + bool shouldUseLocalScope = printerFlags.shouldUseLocalScope(); do { // If we are printing local scope, stop at the first operation that is // isolated from above. 
@@ -2785,10 +2794,11 @@ op = parentOp; } while (true); - AsmState state(op); - print(os, state, flags); + AsmState state(op, printerFlags); + print(os, state, printerFlags); } -void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) { +void Operation::print(raw_ostream &os, AsmState &state, + const OpPrintingFlags &flags) { OperationPrinter printer(os, flags, state.getImpl()); if (!getParent() && !flags.shouldUseLocalScope()) printer.printTopLevelOperation(this); diff --git a/mlir/lib/Transforms/LocationSnapshot.cpp b/mlir/lib/Transforms/LocationSnapshot.cpp --- a/mlir/lib/Transforms/LocationSnapshot.cpp +++ b/mlir/lib/Transforms/LocationSnapshot.cpp @@ -22,11 +22,11 @@ /// NameLoc with the given tag as the name, and then fused with the existing /// locations. Otherwise, the existing locations are replaced. static void generateLocationsFromIR(raw_ostream &os, StringRef fileName, - Operation *op, OpPrintingFlags flags, + Operation *op, const OpPrintingFlags &flags, StringRef tag) { // Print the IR to the stream, and collect the raw line+column information. AsmState::LocationMap opToLineCol; - AsmState state(op, &opToLineCol); + AsmState state(op, flags, &opToLineCol); op->print(os, state, flags); Builder builder(op->getContext());