diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/AsmState.h @@ -0,0 +1,52 @@ +//===- AsmState.h - State class for AsmPrinter ------------------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the AsmState class. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_ASMSTATE_H_ +#define MLIR_IR_ASMSTATE_H_ + +#include <memory> + +namespace mlir { +class Operation; + +namespace detail { +class AsmStateImpl; +} // end namespace detail + +/// This class provides management for the lifetime of the state used when +/// printing the IR. It allows for alleviating the cost of recomputing the +/// internal state of the asm printer. +/// +/// The IR should not be mutated in-between invocations using this state, and +/// the IR being printed must not be a parent of the IR originally used to +/// initialize this state. This means that if a child operation is provided, a +/// parent operation cannot reuse this state. +class AsmState { +public: + /// Initialize the asm state at the level of the given operation. + AsmState(Operation *op); + ~AsmState(); + + /// Return a reference to the internal implementation, which the constructor + /// always allocates. + detail::AsmStateImpl &getImpl() { return *impl; } + +private: + AsmState() = delete; + + /// A pointer to allocated storage for the impl state. 
+ std::unique_ptr<detail::AsmStateImpl> impl; +}; + +} // end namespace mlir + +#endif // MLIR_IR_ASMSTATE_H_ diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -312,12 +312,14 @@ } void print(raw_ostream &os); + void print(raw_ostream &os, AsmState &state); void dump(); /// Print out the name of the block without printing its body. /// NOTE: The printType argument is ignored. We keep it for compatibility /// with LLVM dominator machinery that expects it to exist. void printAsOperand(raw_ostream &os, bool printType = true); + void printAsOperand(raw_ostream &os, AsmState &state); private: /// Pair of the parent object that owns this block and a bit that signifies if diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h --- a/mlir/include/mlir/IR/Module.h +++ b/mlir/include/mlir/IR/Module.h @@ -57,6 +57,8 @@ /// Print the this module in the custom top-level form. void print(raw_ostream &os, OpPrintingFlags flags = llvm::None); + void print(raw_ostream &os, AsmState &state, + OpPrintingFlags flags = llvm::None); void dump(); //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -122,6 +122,10 @@ void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) { state->print(os, flags); } + void print(raw_ostream &os, AsmState &asmState, + OpPrintingFlags flags = llvm::None) { + state->print(os, asmState, flags); + } /// Dump this operation. 
void dump() { state->dump(); } diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -187,6 +187,8 @@ bool isBeforeInBlock(Operation *other); void print(raw_ostream &os, OpPrintingFlags flags = llvm::None); + void print(raw_ostream &os, AsmState &state, + OpPrintingFlags flags = llvm::None); void dump(); //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -18,6 +18,7 @@ #include "mlir/Support/LLVM.h" namespace mlir { +class AsmState; class BlockArgument; class Operation; class OpResult; @@ -172,8 +173,12 @@ Kind getKind() const { return ownerAndKind.getInt(); } void print(raw_ostream &os); + void print(raw_ostream &os, AsmState &state); void dump(); + /// Print this value as if it were an operand, using SSA names from `state`. + void printAsOperand(raw_ostream &os, AsmState &state); + /// Methods for supporting PointerLikeTypeTraits. 
void *getAsOpaquePointer() const { return ownerAndKind.getOpaqueValue(); } static Value getFromOpaquePointer(const void *pointer) { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" @@ -37,6 +38,7 @@ #include "llvm/Support/Regex.h" #include "llvm/Support/SaveAndRestore.h" using namespace mlir; +using namespace mlir::detail; void Identifier::print(raw_ostream &os) const { os << str(); } @@ -756,13 +758,14 @@ } //===----------------------------------------------------------------------===// -// ModuleState +// AsmState //===----------------------------------------------------------------------===// -namespace { -class ModuleState { +namespace mlir { +namespace detail { +class AsmStateImpl { public: - explicit ModuleState(Operation *op) + explicit AsmStateImpl(Operation *op) : interfaces(op->getContext()), nameState(op, interfaces) {} /// Initialize the alias state to enable the printing of aliases. @@ -792,7 +795,11 @@ /// The state used for SSA value names. 
SSANameState nameState; }; -} // end anonymous namespace +} // end namespace detail +} // end namespace mlir + +AsmState::AsmState(Operation *op) : impl(std::make_unique<AsmStateImpl>(op)) {} +AsmState::~AsmState() {} //===----------------------------------------------------------------------===// // ModulePrinter @@ -802,7 +809,7 @@ class ModulePrinter { public: ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None, - ModuleState *state = nullptr) + AsmStateImpl *state = nullptr) : os(os), printerFlags(flags), state(state) {} explicit ModulePrinter(ModulePrinter &printer) : os(printer.os), printerFlags(printer.printerFlags), @@ -816,8 +823,6 @@ mlir::interleaveComma(c, os, each_fn); } - void print(ModuleOp module); - /// Print the given attribute. If 'mayElideType' is true, some attributes are /// printed without the type when the type matches the default used in the /// parser (for example i64 is the default for integer attributes). @@ -862,7 +867,7 @@ OpPrintingFlags printerFlags; /// An optional printer state for the module. - ModuleState *state; + AsmStateImpl *state; }; } // end anonymous namespace @@ -1815,10 +1820,12 @@ /// This class contains the logic for printing operations, regions, and blocks. class OperationPrinter : public ModulePrinter, private OpAsmPrinter { public: - explicit OperationPrinter(ModulePrinter &other) : ModulePrinter(other) { - assert(state && "expected valid state when printing operation"); - } + explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags, + AsmStateImpl &state) + : ModulePrinter(os, flags, &state) {} + /// Print the given top-level module, emitting its aliases first. + void print(ModuleOp op); /// Print the given operation with its indent and location. void print(Operation *op); /// Print the bare location, not including indentation/location/etc. @@ -1903,6 +1910,15 @@ }; } // end anonymous namespace +void OperationPrinter::print(ModuleOp op) { + // Output the aliases at the top level. 
+ state->getAliasState().printAttributeAliases(os); + state->getAliasState().printTypeAliases(os); + + // Print the module. + print(op.getOperation()); +} + void OperationPrinter::print(Operation *op) { os.indent(currentIndent); printOperation(op); @@ -2108,18 +2124,6 @@ }); } -void ModulePrinter::print(ModuleOp module) { - assert(state && "expected valid state when printing an operation"); - - // Output the aliases at the top level. - state->getAliasState().printAttributeAliases(os); - state->getAliasState().printTypeAliases(os); - - // Print the module. - OperationPrinter(*this).print(module); - os << '\n'; -} - //===----------------------------------------------------------------------===// // print and dump methods //===----------------------------------------------------------------------===// @@ -2179,18 +2183,34 @@ assert(isa<BlockArgument>()); os << "\n"; } +void Value::print(raw_ostream &os, AsmState &state) { + if (auto *op = getDefiningOp()) + return op->print(os, state); + + // TODO: Improve this. + assert(isa<BlockArgument>()); + os << "\n"; +} void Value::dump() { print(llvm::errs()); llvm::errs() << "\n"; } +void Value::printAsOperand(raw_ostream &os, AsmState &state) { + // TODO(riverriddle) This doesn't necessarily capture all potential cases. + // Currently, region arguments can be shadowed when printing the main + // operation. If the IR hasn't been printed, this will produce the old SSA + // name and not the shadowed name. + state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true, + os); +} + void Operation::print(raw_ostream &os, OpPrintingFlags flags) { // Handle top-level operations or local printing. 
if (!getParent() || flags.shouldUseLocalScope()) { - ModuleState state(this); - ModulePrinter modulePrinter(os, flags, &state); - OperationPrinter(modulePrinter).print(this); + AsmState state(this); + OperationPrinter(os, flags, state.getImpl()).print(this); return; } @@ -2203,9 +2223,11 @@ while (auto *nextOp = parentOp->getParentOp()) parentOp = nextOp; - ModuleState state(parentOp); - ModulePrinter modulePrinter(os, flags, &state); - OperationPrinter(modulePrinter).print(this); + AsmState state(parentOp); + print(os, state, flags); +} +void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) { + OperationPrinter(os, flags, state.getImpl()).print(this); } void Operation::dump() { @@ -2223,9 +2245,11 @@ while (auto *nextOp = parentOp->getParentOp()) parentOp = nextOp; - ModuleState state(parentOp); - ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state); - OperationPrinter(modulePrinter).print(this); + AsmState state(parentOp); + print(os, state); +} +void Block::print(raw_ostream &os, AsmState &state) { + OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this); } void Block::dump() { print(llvm::errs()); } @@ -2241,18 +2265,24 @@ while (auto *nextOp = parentOp->getParentOp()) parentOp = nextOp; - ModuleState state(parentOp); - ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state); - OperationPrinter(modulePrinter).printBlockName(this); + AsmState state(parentOp); + printAsOperand(os, state); +} +void Block::printAsOperand(raw_ostream &os, AsmState &state) { + OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl()); + printer.printBlockName(this); } void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) { - ModuleState state(*this); + AsmState state(*this); // Don't populate aliases when printing at local scope. 
if (!flags.shouldUseLocalScope()) - state.initializeAliases(*this); - ModulePrinter(os, flags, &state).print(*this); + state.getImpl().initializeAliases(*this); + print(os, state, flags); +} +void ModuleOp::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) { + OperationPrinter(os, flags, state.getImpl()).print(*this); } void ModuleOp::dump() { print(llvm::errs()); }