diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -51,6 +51,36 @@
       (ins "::mlir::OpAsmSetValueNameFn":$setNameFn),
       "", ";"
     >,
+    InterfaceMethod<[{
+      Get the names to use for the blocks inside the regions attached to this
+      operation.
+
+      For example, if this operation has multiple blocks:
+
+      ```mlir
+        some.op() ({
+          ^bb0:
+            ...
+          ^bb1:
+            ...
+        })
+      ```
+
+      the method can invoke `setNameFn` once for each of the blocks, allowing
+      the op to print:
+
+      ```mlir
+        some.op() ({
+          ^custom_foo_name:
+            ...
+          ^custom_bar_name:
+            ...
+        })
+      ```
+    }],
+    "void", "getAsmBlockNames",
+    (ins "::mlir::OpAsmSetBlockNameFn":$setNameFn), "", ";"
+    >,
     StaticInterfaceMethod<[{
       Return the default dialect used when printing/parsing operations in
       regions nested under this operation. This allows for eliding the dialect
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1316,6 +1316,10 @@
 /// operation. See 'getAsmResultNames' below for more details.
 using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
 
+/// A functor used to set the name of blocks in regions directly nested under
+/// an operation.
+using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
+
 class OpAsmDialectInterface
     : public DialectInterface::Base<OpAsmDialectInterface> {
 public:
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -791,6 +791,13 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
+/// Info about block printing: a number which is its position in the visitation
+/// order, and a name that is used to print references to it, e.g. ^bb42.
+struct BlockInfo {
+  int ordering;
+  StringRef name;
+};
+
 /// This class manages the state of SSA value names.
 class SSANameState {
 public:
@@ -809,8 +816,8 @@
   /// operation, or empty if none exist.
   ArrayRef<int> getOpResultGroups(Operation *op);
 
-  /// Get the ID for the given block.
-  unsigned getBlockID(Block *block);
+  /// Get the info for the given block.
+  BlockInfo getBlockInfo(Block *block);
 
   /// Renumber the arguments for the specified region to the same names as the
   /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
@@ -847,8 +854,9 @@
   /// value of this map are the result numbers that start a result group.
   DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
 
-  /// This is the block ID for each block in the current.
-  DenseMap<Block *, unsigned> blockIDs;
+  /// This maps blocks to their visitation number in the current region as well
+  /// as the string representing their name.
+  DenseMap<Block *, BlockInfo> blockNames;
 
   /// This keeps track of all of the non-numeric names that are in flight,
   /// allowing us to check for duplicates.
@@ -971,9 +979,10 @@
   return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
 }
 
-unsigned SSANameState::getBlockID(Block *block) {
-  auto it = blockIDs.find(block);
-  return it != blockIDs.end() ? it->second : NameSentinel;
+BlockInfo SSANameState::getBlockInfo(Block *block) {
+  auto it = blockNames.find(block);
+  BlockInfo invalidBlock{-1, "^INVALIDBLOCK"};
+  return it != blockNames.end() ? it->second : invalidBlock;
 }
 
 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
@@ -1011,7 +1020,16 @@
   for (auto &block : region) {
     // Each block gets a unique ID, and all of the operations within it get
     // numbered as well.
-    blockIDs[&block] = nextBlockID++;
+    auto blockInfoIt = blockNames.insert({&block, {-1, ""}});
+    if (blockInfoIt.second) {
+      // This block hasn't been named through `getAsmBlockNames`, use the
+      // default `^bbNNN` format.
+      std::string name;
+      llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
+      blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator);
+    }
+    blockInfoIt.first->second.ordering = nextBlockID++;
+
     numberValuesInBlock(block);
   }
 }
@@ -1052,11 +1070,6 @@
 }
 
 void SSANameState::numberValuesInOp(Operation &op) {
-  unsigned numResults = op.getNumResults();
-  if (numResults == 0)
-    return;
-  Value resultBegin = op.getResult(0);
-
   // Function used to set the special result names for the operation.
   SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
   auto setResultNameFn = [&](Value result, StringRef name) {
@@ -1068,13 +1081,37 @@
     if (int resultNo = result.cast<OpResult>().getResultNumber())
       resultGroups.push_back(resultNo);
   };
+  // Operations can customize the printing of block names in OpAsmOpInterface.
+  auto setBlockNameFn = [&](Block *block, StringRef name) {
+    assert(block->getParentOp() == &op &&
+           "getAsmBlockNames callback invoked on a block not directly "
+           "nested under the current operation");
+    assert(!blockNames.count(block) && "block numbered multiple times");
+    SmallString<16> tmpBuffer{"^"};
+    name = sanitizeIdentifier(name, tmpBuffer);
+    if (name.data() != tmpBuffer.data()) {
+      tmpBuffer.append(name);
+      name = tmpBuffer.str();
+    }
+    name = name.copy(usedNameAllocator);
+    blockNames[block] = {-1, name};
+  };
+
   if (!printerFlags.shouldPrintGenericOpForm()) {
-    if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
+    if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
+      asmInterface.getAsmBlockNames(setBlockNameFn);
       asmInterface.getAsmResultNames(setResultNameFn);
-    else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect()))
+    } else if (auto *asmInterface =
+                   interfaces.getInterfaceFor(op.getDialect())) {
       asmInterface->getAsmResultNames(&op, setResultNameFn);
+    }
   }
 
+  unsigned numResults = op.getNumResults();
+  if (numResults == 0)
+    return;
+  Value resultBegin = op.getResult(0);
+
   // If the first result wasn't numbered, give it a default number.
   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
     ++nextValueID;
@@ -2621,11 +2658,7 @@
 }
 
 void OperationPrinter::printBlockName(Block *block) {
-  auto id = state->getSSANameState().getBlockID(block);
-  if (id != SSANameState::NameSentinel)
-    os << "^bb" << id;
-  else
-    os << "^INVALIDBLOCK";
+  os << state->getSSANameState().getBlockInfo(block).name;
 }
 
 void OperationPrinter::print(Block *block, bool printBlockArgs,
@@ -2658,18 +2691,18 @@
       os << "  // pred: ";
       printBlockName(pred);
     } else {
-      // We want to print the predecessors in increasing numeric order, not in
+      // We want to print the predecessors in a stable order, not in
       // whatever order the use-list is in, so gather and sort them.
-      SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
+      SmallVector<BlockInfo, 4> predIDs;
       for (auto *pred : block->getPredecessors())
-        predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
-      llvm::array_pod_sort(predIDs.begin(), predIDs.end());
+        predIDs.push_back(state->getSSANameState().getBlockInfo(pred));
+      llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
+        return lhs.ordering < rhs.ordering;
+      });
 
       os << "  // " << predIDs.size() << " preds: ";
 
-      interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
-        printBlockName(pred.second);
-      });
+      interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; });
     }
     os << newLine;
   }
diff --git a/mlir/test/IR/pretty_printed_region_op.mlir b/mlir/test/IR/pretty_printed_region_op.mlir
--- a/mlir/test/IR/pretty_printed_region_op.mlir
+++ b/mlir/test/IR/pretty_printed_region_op.mlir
@@ -33,3 +33,28 @@
   return %0 : f32
 }
 
+// -----
+
+// This tests the behavior of custom block names:
+// operations like `test.block_names` can define custom names for blocks in
+// nested regions.
+// CHECK-CUSTOM-LABEL: func @block_names
+func @block_names(%bool : i1) {
+  // CHECK: test.block_names
+  test.block_names {
+    // CHECK-CUSTOM: br ^foo1
+    // CHECK-GENERIC: std.br{{.*}}^bb1
+    br ^foo1
+    // CHECK-CUSTOM: ^foo1:
+    // CHECK-GENERIC: ^bb1:
+  ^foo1:
+    // CHECK-CUSTOM: br ^foo2
+    // CHECK-GENERIC: std.br{{.*}}^bb2
+    br ^foo2
+    // CHECK-CUSTOM: ^foo2:
+    // CHECK-GENERIC: ^bb2:
+  ^foo2:
+    "test.return"() : () -> ()
+  }
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -627,6 +627,24 @@
   let assemblyFormat = "regions attr-dict-with-keyword";
 }
 
+// This is used to test the OpAsmOpInterface::getAsmBlockNames() feature:
+// blocks nested in the region under this op are named `fooN`, where N is the
+// block's position in the region.
+def AsmBlockNameOp : TEST_Op<"block_names", [OpAsmOpInterface]> {
+  let regions = (region AnyRegion:$body);
+  let extraClassDeclaration = [{
+    void getAsmBlockNames(mlir::OpAsmSetBlockNameFn setNameFn) {
+      std::string name;
+      int count = 0;
+      for (::mlir::Block &block : getRegion().getBlocks()) {
+        name = "foo" + std::to_string(count++);
+        setNameFn(&block, name);
+      }
+    }
+  }];
+  let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
 // This operation requires its return type to have the trait 'TestTypeTrait'.
 def ResultTypeWithTraitOp : TEST_Op<"result_type_with_trait", []> {
   let results = (outs AnyType);