diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -51,6 +51,18 @@
       (ins "::mlir::OpAsmSetValueNameFn":$setNameFn),
       "", ";"
     >,
+    InterfaceMethod<[{
+      Get the name to use for the given block, which belongs to a region
+      attached to this operation. The name is printed verbatim after `^`, so
+      it should be a valid block identifier. Return an empty string to use the
+      default `bb<N>` naming. A numeric suffix is appended when needed to keep
+      block names unique within a region. Custom names are not used when
+      printing the generic form.
+    }],
+      "std::string", "getAsmBlockName",
+      (ins "::mlir::Block *":$block),
+      "", "return \"\";"
+    >,
     StaticInterfaceMethod<[{
       Return the default dialect used when printing/parsing operations in
       regions nested under this operation. This allows for eliding the dialect
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -810,6 +810,6 @@
   ArrayRef<int> getOpResultGroups(Operation *op);
 
-  /// Get the ID for the given block.
-  unsigned getBlockID(Block *block);
+  /// Get the name of the given block, without the leading `^`.
+  StringRef getBlockID(Block *block);
 
   /// Renumber the arguments for the specified region to the same names as the
@@ -847,8 +847,8 @@
   /// value of this map are the result numbers that start a result group.
   DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
 
-  /// This is the block ID for each block in the current.
-  DenseMap<Block *, unsigned> blockIDs;
+  /// This is the name (without the leading `^`) assigned to each block.
+  DenseMap<Block *, std::string> blockIDs;
 
   /// This keeps track of all of the non-numeric names that are in flight,
   /// allowing us to check for duplicates.
@@ -971,9 +971,9 @@
   return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
 }
 
-unsigned SSANameState::getBlockID(Block *block) {
+StringRef SSANameState::getBlockID(Block *block) {
   auto it = blockIDs.find(block);
-  return it != blockIDs.end() ? it->second : NameSentinel;
+  return it != blockIDs.end() ? it->second : StringRef("INVALIDBLOCK");
 }
 
 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
@@ -1007,11 +1007,40 @@
 
 void SSANameState::numberValuesInRegion(Region &region) {
   // Number the values within this region in a breadth-first order.
-  unsigned nextBlockID = 0;
+  // Operations implementing OpAsmOpInterface may pick the names of the blocks
+  // in their regions; custom names are ignored in the generic form.
+  auto asmInterface = dyn_cast_or_null<OpAsmOpInterface>(region.getParentOp());
+  bool useCustomNames =
+      asmInterface && !printerFlags.shouldPrintGenericOpForm();
+
+  // Block names already used in this region. Counters remember the next
+  // numeric suffix to try, so that naming stays linear in the number of
+  // blocks instead of rescanning from zero for every block.
+  StringSet<> usedBlockNames;
+  unsigned nextDefaultID = 0;
+  StringMap<unsigned> nextCustomSuffix;
+
   for (auto &block : region) {
-    // Each block gets a unique ID, and all of the operations within it get
-    // numbered as well.
-    blockIDs[&block] = nextBlockID++;
+    // Each block gets a name unique within the region. Custom names are kept
+    // when free, else get the lowest free suffix >= 1; defaults are bb<N>.
+    std::string blockName;
+    if (useCustomNames)
+      blockName = asmInterface.getAsmBlockName(&block);
+    if (blockName.empty()) {
+      do {
+        blockName = "bb" + std::to_string(nextDefaultID++);
+      } while (!usedBlockNames.insert(blockName).second);
+    } else if (!usedBlockNames.insert(blockName).second) {
+      std::string baseName = std::move(blockName);
+      unsigned &suffix =
+          nextCustomSuffix.try_emplace(baseName, 1u).first->second;
+      do {
+        blockName = baseName + std::to_string(suffix++);
+      } while (!usedBlockNames.insert(blockName).second);
+    }
+    blockIDs[&block] = std::move(blockName);
+
+    // All of the operations within the block get numbered as well.
     numberValuesInBlock(block);
   }
 }
@@ -2621,11 +2650,7 @@
 }
 
 void OperationPrinter::printBlockName(Block *block) {
-  auto id = state->getSSANameState().getBlockID(block);
-  if (id != SSANameState::NameSentinel)
-    os << "^bb" << id;
-  else
-    os << "^INVALIDBLOCK";
+  os << "^" << state->getSSANameState().getBlockID(block);
 }
 
 void OperationPrinter::print(Block *block, bool printBlockArgs,
@@ -2658,18 +2683,19 @@
       os << "  // pred: ";
       printBlockName(pred);
     } else {
-      // We want to print the predecessors in increasing numeric order, not in
+      // We want to print the predecessors in a deterministic order, not in
       // whatever order the use-list is in, so gather and sort them.
-      SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
+      SmallVector<StringRef, 4> predIDs;
       for (auto *pred : block->getPredecessors())
-        predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
-      llvm::array_pod_sort(predIDs.begin(), predIDs.end());
+        predIDs.push_back(state->getSSANameState().getBlockID(pred));
+      // Compare embedded numbers by value so that ^bb8 sorts before ^bb10.
+      llvm::sort(predIDs, [](StringRef lhs, StringRef rhs) {
+        return lhs.compare_numeric(rhs) < 0;
+      });
 
       os << "  // " << predIDs.size() << " preds: ";
 
-      interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
-        printBlockName(pred.second);
-      });
+      interleaveComma(predIDs, [&](StringRef pred) { os << "^" << pred; });
     }
     os << newLine;
   }
diff --git a/mlir/test/IR/pretty_printed_region_op.mlir b/mlir/test/IR/pretty_printed_region_op.mlir
--- a/mlir/test/IR/pretty_printed_region_op.mlir
+++ b/mlir/test/IR/pretty_printed_region_op.mlir
@@ -33,3 +33,28 @@
   return %0 : f32
 }
 
+// -----
+
+// This tests the behavior of custom block names: operations like
+// `test.block_names` can name the blocks in their regions, and colliding
+// names get a numeric suffix. The generic form keeps the default names.
+// CHECK-CUSTOM-LABEL: func @block_names
+func @block_names(%bool : i1) {
+  // CHECK: test.block_names
+  test.block_names {
+    // CHECK-CUSTOM: br ^foo1
+    // CHECK-GENERIC: std.br{{.*}}^bb1
+    br ^foo1
+    // CHECK-CUSTOM: ^foo1:
+    // CHECK-GENERIC: ^bb1:
+  ^foo1:
+    // CHECK-CUSTOM: br ^foo2
+    // CHECK-GENERIC: std.br{{.*}}^bb2
+    br ^foo2
+    // CHECK-CUSTOM: ^foo2:
+    // CHECK-GENERIC: ^bb2:
+  ^foo2:
+    "test.return"() : () -> ()
+  }
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -627,6 +627,19 @@
   let assemblyFormat = "regions attr-dict-with-keyword";
 }
 
+// This is used to test the OpAsmOpInterface::getAsmBlockName() feature:
+// blocks nested in a region under this op are all named "foo" by the
+// interface, so the printer has to disambiguate them with numeric suffixes.
+def AsmBlockNameOp : TEST_Op<"block_names", [OpAsmOpInterface]> {
+  let regions = (region AnyRegion:$body);
+  let extraClassDeclaration = [{
+    std::string getAsmBlockName(mlir::Block *) {
+      return "foo";
+    }
+  }];
+  let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
 // This operation requires its return type to have the trait 'TestTypeTrait'.
 def ResultTypeWithTraitOp : TEST_Op<"result_type_with_trait", []> {
   let results = (outs AnyType);