diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -0,0 +1,46 @@
+//===- AsmState.h - State class for AsmPrinter ------------------*- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_ASMSTATE_H_
+#define MLIR_IR_ASMSTATE_H_
+
+#include <memory>
+
+namespace mlir {
+class Operation;
+
+namespace detail {
+class AsmStateImpl;
+} // end namespace detail
+
+/// This class provides management for the lifetime of the state used when
+/// printing the IR. It allows for alleviating the cost of recomputing the
+/// internal state of the asm printer.
+///
+/// The IR should not be mutated in-between invocations using this state, and
+/// the IR being printed must not be at a higher level than the IR originally
+/// used to initialize this state. This means that if a child operation is
+/// provided, the parent operation cannot reuse this state.
+class AsmState {
+public:
+  /// Initialize the asm state at the level of the given operation.
+  explicit AsmState(Operation *op);
+  ~AsmState();
+
+  /// Return the internal implementation. It is allocated by the constructor,
+  /// so the returned reference is always valid while this object is alive.
+  detail::AsmStateImpl &getImpl() { return *impl; }
+
+private:
+  /// A pointer to allocated storage for the impl state.
+  std::unique_ptr<detail::AsmStateImpl> impl;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_IR_ASMSTATE_H_
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -307,12 +307,14 @@
   }
 
   void print(raw_ostream &os);
+  void print(raw_ostream &os, AsmState &state);
   void dump();
 
   /// Print out the name of the block without printing its body.
   /// NOTE: The printType argument is ignored.  We keep it for compatibility
   /// with LLVM dominator machinery that expects it to exist.
   void printAsOperand(raw_ostream &os, bool printType = true);
+  void printAsOperand(raw_ostream &os, AsmState &state);
 
 private:
   /// Pair of the parent object that owns this block and a bit that signifies if
diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h
--- a/mlir/include/mlir/IR/Module.h
+++ b/mlir/include/mlir/IR/Module.h
@@ -57,6 +57,8 @@
 
   /// Print the this module in the custom top-level form.
   void print(raw_ostream &os, OpPrintingFlags flags = llvm::None);
+  void print(raw_ostream &os, AsmState &state,
+             OpPrintingFlags flags = llvm::None);
   void dump();
 
   //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -122,6 +122,10 @@
   void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
     state->print(os, flags);
   }
+  void print(raw_ostream &os, AsmState &asmState,
+             OpPrintingFlags flags = llvm::None) {
+    state->print(os, asmState, flags);
+  }
 
   /// Dump this operation.
   void dump() { state->dump(); }
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -186,6 +186,8 @@
   bool isBeforeInBlock(Operation *other);
 
   void print(raw_ostream &os, OpPrintingFlags flags = llvm::None);
+  void print(raw_ostream &os, AsmState &state,
+             OpPrintingFlags flags = llvm::None);
   void dump();
 
   //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -18,6 +18,7 @@
 #include "mlir/Support/LLVM.h"
 
 namespace mlir {
+class AsmState;
 class Block;
 class BlockArgument;
 class Operation;
@@ -192,8 +193,12 @@
   // Utilities
 
   void print(raw_ostream &os);
+  void print(raw_ostream &os, AsmState &state);
   void dump();
 
+  /// Print this value as if it were an operand.
+  void printAsOperand(raw_ostream &os, AsmState &state);
+
   /// Methods for supporting PointerLikeTypeTraits.
   void *getAsOpaquePointer() const { return static_cast<void *>(impl); }
   static Value getFromOpaquePointer(const void *pointer) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/AsmState.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -36,6 +37,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Regex.h"
 using namespace mlir;
+using namespace mlir::detail;
 
 void Identifier::print(raw_ostream &os) const { os << str(); }
 
@@ -759,13 +761,14 @@
 }
 
 //===----------------------------------------------------------------------===//
-// ModuleState
+// AsmState
 //===----------------------------------------------------------------------===//
 
-namespace {
-class ModuleState {
+namespace mlir {
+namespace detail {
+class AsmStateImpl {
 public:
-  explicit ModuleState(Operation *op)
+  explicit AsmStateImpl(Operation *op)
       : interfaces(op->getContext()), nameState(op, interfaces) {}
 
   /// Initialize the alias state to enable the printing of aliases.
@@ -795,7 +798,11 @@
   /// The state used for SSA value names.
   SSANameState nameState;
 };
-} // end anonymous namespace
+} // end namespace detail
+} // end namespace mlir
+
+AsmState::AsmState(Operation *op) : impl(std::make_unique<AsmStateImpl>(op)) {}
+AsmState::~AsmState() = default;
 
 //===----------------------------------------------------------------------===//
 // ModulePrinter
@@ -805,7 +812,7 @@
 class ModulePrinter {
 public:
   ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
-                ModuleState *state = nullptr)
+                AsmStateImpl *state = nullptr)
       : os(os), printerFlags(flags), state(state) {}
   explicit ModulePrinter(ModulePrinter &printer)
       : os(printer.os), printerFlags(printer.printerFlags),
@@ -819,8 +826,6 @@
     mlir::interleaveComma(c, os, each_fn);
   }
 
-  void print(ModuleOp module);
-
   /// Print the given attribute. If 'mayElideType' is true, some attributes are
   /// printed without the type when the type matches the default used in the
   /// parser (for example i64 is the default for integer attributes).
@@ -865,7 +870,7 @@
   OpPrintingFlags printerFlags;
 
   /// An optional printer state for the module.
-  ModuleState *state;
+  AsmStateImpl *state;
 };
 } // end anonymous namespace
 
@@ -1812,10 +1817,12 @@
 /// This class contains the logic for printing operations, regions, and blocks.
 class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
 public:
-  explicit OperationPrinter(ModulePrinter &other) : ModulePrinter(other) {
-    assert(state && "expected valid state when printing operation");
-  }
+  explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
+                            AsmStateImpl &state)
+      : ModulePrinter(os, flags, &state) {}
 
+  /// Print the given top-level module.
+  void print(ModuleOp op);
   /// Print the given operation with its indent and location.
   void print(Operation *op);
   /// Print the bare location, not including indentation/location/etc.
@@ -1900,6 +1907,15 @@
 };
 } // end anonymous namespace
 
+void OperationPrinter::print(ModuleOp op) {
+  // Output the aliases at the top level.
+  state->getAliasState().printAttributeAliases(os);
+  state->getAliasState().printTypeAliases(os);
+
+  // Print the module.
+  print(op.getOperation());
+}
+
 void OperationPrinter::print(Operation *op) {
   os.indent(currentIndent);
   printOperation(op);
@@ -2107,18 +2123,6 @@
   });
 }
 
-void ModulePrinter::print(ModuleOp module) {
-  assert(state && "expected valid state when printing an operation");
-
-  // Output the aliases at the top level.
-  state->getAliasState().printAttributeAliases(os);
-  state->getAliasState().printTypeAliases(os);
-
-  // Print the module.
-  OperationPrinter(*this).print(module);
-  os << '\n';
-}
-
 //===----------------------------------------------------------------------===//
 // print and dump methods
 //===----------------------------------------------------------------------===//
@@ -2178,18 +2182,34 @@
   assert(isa<BlockArgument>());
   os << "<block argument>\n";
 }
+void Value::print(raw_ostream &os, AsmState &state) {
+  if (auto *op = getDefiningOp())
+    return op->print(os, state);
+
+  // TODO: Improve this.
+  assert(isa<BlockArgument>());
+  os << "<block argument>\n";
+}
 
 void Value::dump() {
   print(llvm::errs());
   llvm::errs() << "\n";
 }
 
+void Value::printAsOperand(raw_ostream &os, AsmState &state) {
+  // TODO(riverriddle) This doesn't necessarily capture all potential cases.
+  // Currently, region arguments can be shadowed when printing the main
+  // operation. If the IR hasn't been printed, this will produce the old SSA
+  // name and not the shadowed name.
+  state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
+                                                 os);
+}
+
 void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
   // Handle top-level operations or local printing.
   if (!getParent() || flags.shouldUseLocalScope()) {
-    ModuleState state(this);
-    ModulePrinter modulePrinter(os, flags, &state);
-    OperationPrinter(modulePrinter).print(this);
+    AsmState state(this);
+    print(os, state, flags);
     return;
   }
 
@@ -2202,9 +2222,11 @@
   while (auto *nextOp = parentOp->getParentOp())
     parentOp = nextOp;
 
-  ModuleState state(parentOp);
-  ModulePrinter modulePrinter(os, flags, &state);
-  OperationPrinter(modulePrinter).print(this);
+  AsmState state(parentOp);
+  print(os, state, flags);
+}
+void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
+  OperationPrinter(os, flags, state.getImpl()).print(this);
 }
 
 void Operation::dump() {
@@ -2222,9 +2244,11 @@
   while (auto *nextOp = parentOp->getParentOp())
     parentOp = nextOp;
 
-  ModuleState state(parentOp);
-  ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state);
-  OperationPrinter(modulePrinter).print(this);
+  AsmState state(parentOp);
+  print(os, state);
+}
+void Block::print(raw_ostream &os, AsmState &state) {
+  OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
 }
 
 void Block::dump() { print(llvm::errs()); }
@@ -2240,18 +2264,24 @@
   while (auto *nextOp = parentOp->getParentOp())
     parentOp = nextOp;
 
-  ModuleState state(parentOp);
-  ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state);
-  OperationPrinter(modulePrinter).printBlockName(this);
+  AsmState state(parentOp);
+  printAsOperand(os, state);
+}
+void Block::printAsOperand(raw_ostream &os, AsmState &state) {
+  OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
+  printer.printBlockName(this);
 }
 
 void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) {
-  ModuleState state(*this);
+  AsmState state(*this);
 
   // Don't populate aliases when printing at local scope.
   if (!flags.shouldUseLocalScope())
-    state.initializeAliases(*this);
-  ModulePrinter(os, flags, &state).print(*this);
+    state.getImpl().initializeAliases(*this);
+  print(os, state, flags);
+}
+void ModuleOp::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
+  OperationPrinter(os, flags, state.getImpl()).print(*this);
 }
 
 void ModuleOp::dump() { print(llvm::errs()); }