diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -942,7 +942,7 @@
   // emitting the regions first (e.g. if the regions are huge, backpatching the
   // op encoding mask is more annoying).
   if (numRegions) {
-    bool isIsolatedFromAbove = op->hasTrait<OpTrait::IsIsolatedFromAbove>();
+    bool isIsolatedFromAbove = numberingState.isIsolatedFromAbove(op);
     emitter.emitVarIntWithFlag(numRegions, isIsolatedFromAbove);
 
     // If the region is not isolated from above, or we are emitting bytecode
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
--- a/mlir/lib/Bytecode/Writer/IRNumbering.h
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -126,6 +126,22 @@
   llvm::MapVector<StringRef, DialectResourceNumbering *> resourceMap;
 };
 
+//===----------------------------------------------------------------------===//
+// Operation Numbering
+//===----------------------------------------------------------------------===//
+
+/// This class represents the numbering entry of an operation.
+struct OperationNumbering {
+  OperationNumbering(unsigned number) : number(number) {}
+
+  /// The number assigned to this operation.
+  unsigned number;
+
+  /// A flag indicating if this operation's regions are isolated. If unset, the
+  /// operation isn't yet known to be isolated.
+  std::optional<bool> isIsolatedFromAbove;
+};
+
 //===----------------------------------------------------------------------===//
 // IRNumberingState
 //===----------------------------------------------------------------------===//
@@ -154,8 +170,8 @@
     return blockIDs[block];
   }
   unsigned getNumber(Operation *op) {
-    assert(operationIDs.count(op) && "operation not numbered");
-    return operationIDs[op];
+    assert(operations.count(op) && "operation not numbered");
+    return operations[op]->number;
   }
   unsigned getNumber(OperationName opName) {
     assert(opNames.count(opName) && "opName not numbered");
@@ -186,14 +202,23 @@
     return blockOperationCounts[block];
   }
 
+  /// Return if the given operation is isolated from above.
+  bool isIsolatedFromAbove(Operation *op) {
+    assert(operations.count(op) && "operation not numbered");
+    return operations[op]->isIsolatedFromAbove.value_or(false);
+  }
+
   /// Get the set desired bytecode version to emit.
   int64_t getDesiredBytecodeVersion() const;
-  
+
 private:
   /// This class is used to provide a fake dialect writer for numbering nested
   /// attributes and types.
   struct NumberingDialectWriter;
 
+  /// Compute the global numbering state for the given root operation.
+  void computeGlobalNumberingState(Operation *rootOp);
+
   /// Number the given IR unit for bytecode emission.
   void number(Attribute attr);
   void number(Block &block);
@@ -212,6 +237,7 @@
   /// Mapping from IR to the respective numbering entries.
   DenseMap<Attribute, AttributeNumbering *> attrs;
+  DenseMap<Operation *, OperationNumbering *> operations;
   DenseMap<OperationName, OpNameNumbering *> opNames;
   DenseMap<Type, TypeNumbering *> types;
   DenseMap<Dialect *, DialectNumbering *> registeredDialects;
@@ -228,12 +254,12 @@
   /// Allocators used for the various numbering entries.
   llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator;
   llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator;
+  llvm::SpecificBumpPtrAllocator<OperationNumbering> opAllocator;
   llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator;
   llvm::SpecificBumpPtrAllocator<DialectResourceNumbering> resourceAllocator;
   llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator;
 
-  /// The value ID for each Operation, Block and Value.
-  DenseMap<Operation *, unsigned> operationIDs;
+  /// The value ID for each Block and Value.
   DenseMap<Block *, unsigned> blockIDs;
   DenseMap<Value, unsigned> valueIDs;
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -115,19 +115,29 @@
 IRNumberingState::IRNumberingState(Operation *op,
                                   const BytecodeWriterConfig &config)
     : config(config) {
-  // Compute a global operation ID numbering according to the pre-order walk of
-  // the IR. This is used as reference to construct use-list orders.
-  unsigned operationID = 0;
-  op->walk(
-      [&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
+  computeGlobalNumberingState(op);
 
   // Number the root operation.
   number(*op);
 
-  // Push all of the regions of the root operation onto the worklist.
+  // A worklist of region contexts to number and the next value id before that
+  // region.
   SmallVector<std::pair<Region *, unsigned>, 8> numberContext;
-  for (Region &region : op->getRegions())
-    numberContext.emplace_back(&region, nextValueID);
+
+  // Functor to push the regions of the given operation onto the numbering
+  // context.
+  auto addOpRegionsToNumber = [&](Operation *op) {
+    MutableArrayRef<Region> regions = op->getRegions();
+    if (regions.empty())
+      return;
+
+    // Isolated regions don't share value numbers with their parent, so we can
+    // start numbering these regions at zero.
+    unsigned opFirstValueID = isIsolatedFromAbove(op) ? 0 : nextValueID;
+    for (Region &region : regions)
+      numberContext.emplace_back(&region, opFirstValueID);
+  };
+  addOpRegionsToNumber(op);
 
   // Iteratively process each of the nested regions.
   while (!numberContext.empty()) {
@@ -136,14 +146,8 @@
     number(*region);
 
     // Traverse into nested regions.
-    for (Operation &op : region->getOps()) {
-      // Isolated regions don't share value numbers with their parent, so we can
-      // start numbering these regions at zero.
-      unsigned opFirstValueID =
-          op.hasTrait<OpTrait::IsIsolatedFromAbove>() ?
 0 : nextValueID;
-      for (Region &region : op.getRegions())
-        numberContext.emplace_back(&region, opFirstValueID);
-    }
+    for (Operation &op : region->getOps())
+      addOpRegionsToNumber(&op);
   }
 
   // Number each of the dialects. For now this is just in the order they were
@@ -178,6 +182,116 @@
   finalizeDialectResourceNumberings(op);
 }
 
+void IRNumberingState::computeGlobalNumberingState(Operation *rootOp) {
+  // A simple state struct tracking data used when walking operations.
+  struct StackState {
+    /// The operation currently being walked.
+    Operation *op;
+
+    /// The numbering of the operation.
+    OperationNumbering *numbering;
+
+    /// A flag indicating if the current state or one of its parents has
+    /// unresolved isolation status. This is tracked separately from the
+    /// isIsolatedFromAbove bit on `numbering` because we need to be able to
+    /// handle the given case:
+    /// top.op {
+    ///   %value = ...
+    ///   middle.op {
+    ///     %value2 = ...
+    ///     inner.op {
+    ///       // Here we mark `inner.op` as not isolated. Note `middle.op`
+    ///       // isn't known not isolated yet.
+    ///       use.op %value2
+    ///
+    ///       // Here inner.op is already known to be non-isolated, but
+    ///       // `middle.op` is now also discovered to be non-isolated.
+    ///       use.op %value
+    ///     }
+    ///   }
+    /// }
+    bool hasUnresolvedIsolation;
+  };
+
+  // Compute a global operation ID numbering according to the pre-order walk of
+  // the IR. This is used as reference to construct use-list orders.
+  unsigned operationID = 0;
+
+  // Walk each of the operations within the IR, tracking a stack of operations
+  // as we recurse into nested regions. This walk method hooks in at two stages
+  // during the walk:
+  //
+  //   BeforeAllRegions:
+  //     Here we generate a numbering for the operation and push it onto the
+  //     stack if it has regions. We also compute the isolation status of parent
+  //     regions at this stage.
+  //     This is done by checking the parent regions of
+  //     operands used by the operation, and marking each region between the
+  //     the operand region and the current as not isolated. See
+  //     StackState::hasUnresolvedIsolation above for an example.
+  //
+  //   AfterAllRegions:
+  //     Here we pop the operation from the stack, and if it hasn't been marked
+  //     as non-isolated, we mark it as so. A non-isolated use would have been
+  //     found while walking the regions, so it is safe to mark the operation at
+  //     this point.
+  //
+  SmallVector<StackState> opStack;
+  rootOp->walk([&](Operation *op, const WalkStage &stage) {
+    // After visiting all nested regions, we pop the operation from the stack.
+    if (stage.isAfterAllRegions()) {
+      // If no non-isolated uses were found, we can safely mark this operation
+      // as isolated from above.
+      OperationNumbering *numbering = opStack.pop_back_val().numbering;
+      if (!numbering->isIsolatedFromAbove.has_value())
+        numbering->isIsolatedFromAbove = true;
+      return;
+    }
+
+    // When visiting before nested regions, we process "IsolatedFromAbove"
+    // checks and compute the number for this operation.
+    if (!stage.isBeforeAllRegions())
+      return;
+    // Update the isolation status of parent regions if any have yet to be
+    // resolved.
+    if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) {
+      Region *parentRegion = op->getParentRegion();
+      for (Value operand : op->getOperands()) {
+        Region *operandRegion = operand.getParentRegion();
+        if (operandRegion == parentRegion)
+          continue;
+        // We've found a use of an operand outside of the current region,
+        // walk the operation stack searching for the parent operation,
+        // marking every region on the way as not isolated.
+        Operation *operandContainerOp = operandRegion->getParentOp();
+        auto it = std::find_if(
+            opStack.rbegin(), opStack.rend(), [=](const StackState &it) {
+              // We only need to mark up to the container region, or the first
+              // that has an unresolved status.
+              return !it.hasUnresolvedIsolation || it.op == operandContainerOp;
+            });
+        assert(it != opStack.rend() && "expected to find the container");
+        for (auto &state : llvm::make_range(opStack.rbegin(), it)) {
+          // If we stopped at a region that knows its isolation status, we can
+          // stop updating the isolation status for the parent regions.
+          state.hasUnresolvedIsolation = it->hasUnresolvedIsolation;
+          state.numbering->isIsolatedFromAbove = false;
+        }
+      }
+    }
+
+    // Compute the number for this op and push it onto the stack.
+    auto *numbering =
+        new (opAllocator.Allocate()) OperationNumbering(operationID++);
+    if (op->hasTrait<OpTrait::IsIsolatedFromAbove>())
+      numbering->isIsolatedFromAbove = true;
+    operations.try_emplace(op, numbering);
+    if (op->getNumRegions()) {
+      opStack.emplace_back(StackState{
+          op, numbering, !numbering->isIsolatedFromAbove.has_value()});
+    }
+  });
+}
+
 void IRNumberingState::number(Attribute attr) {
   auto it = attrs.insert({attr, nullptr});
   if (!it.second) {
diff --git a/mlir/test/Bytecode/bytecode-lazy-loading.mlir b/mlir/test/Bytecode/bytecode-lazy-loading.mlir
--- a/mlir/test/Bytecode/bytecode-lazy-loading.mlir
+++ b/mlir/test/Bytecode/bytecode-lazy-loading.mlir
@@ -23,6 +23,21 @@
   }, {
     "test.unknown_op"() : () -> ()
   }
+
+  // Ensure operations that aren't tagged as IsolatedFromAbove can
+  // still be lazy loaded if they don't have references to values
+  // defined above.
+  "test.one_region_op"() ({
+    "test.unknown_op"() : () -> ()
+  }) : () -> ()
+
+  // Similar test as above, but check that if one region has a reference
+  // to a value defined above, we don't lazy load the operation.
+ "test.two_region_op"() ({ + "test.unknown_op"() : () -> () + }, { + "test.consumer"(%0) : (index) -> () + }) : () -> () return } @@ -53,7 +68,12 @@ // CHECK: test.consumer // CHECK: isolated_region // CHECK-NOT: test.consumer -// CHECK: Has 3 ops to materialize +// CHECK: test.one_region_op +// CHECK-NOT: test.op +// CHECK: test.two_region_op +// CHECK: test.unknown_op +// CHECK: test.consumer +// CHECK: Has 4 ops to materialize // CHECK: Before Materializing... // CHECK: test.isolated_region @@ -62,7 +82,7 @@ // CHECK: test.isolated_region // CHECK: ^bb0(%arg0: index): // CHECK: test.consumer -// CHECK: Has 2 ops to materialize +// CHECK: Has 3 ops to materialize // CHECK: Before Materializing... // CHECK: test.isolated_region @@ -70,7 +90,7 @@ // CHECK: Materializing... // CHECK: test.isolated_region // CHECK: test.consumer -// CHECK: Has 1 ops to materialize +// CHECK: Has 2 ops to materialize // CHECK: Before Materializing... // CHECK: test.isolated_regions @@ -79,4 +99,12 @@ // CHECK: test.isolated_regions // CHECK: test.unknown_op // CHECK: test.unknown_op +// CHECK: Has 1 ops to materialize + +// CHECK: Before Materializing... +// CHECK: test.one_region_op +// CHECK-NOT: test.unknown_op +// CHECK: Materializing... +// CHECK: test.one_region_op +// CHECK: test.unknown_op // CHECK: Has 0 ops to materialize