diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -156,54 +156,23 @@ /// Recomputes the ordering of child operations within the block. void recomputeOpOrder(); -private: - /// A utility iterator that filters out operations that are not 'OpT'. - template - class op_filter_iterator - : public llvm::filter_iterator { - static bool filter(Operation &op) { return llvm::isa(op); } - - public: - op_filter_iterator(Block::iterator it, Block::iterator end) - : llvm::filter_iterator( - it, end, &filter) {} - - /// Allow implicit conversion to the underlying block iterator. - operator Block::iterator() const { return this->wrapped(); } - }; - -public: /// This class provides iteration over the held operations of a block for a /// specific operation type. template - class op_iterator : public llvm::mapped_iterator, - OpT (*)(Operation &)> { - static OpT unwrap(Operation &op) { return cast(op); } - - public: - using reference = OpT; - - /// Initializes the iterator to the specified filter iterator. - op_iterator(op_filter_iterator it) - : llvm::mapped_iterator, OpT (*)(Operation &)>( - it, &unwrap) {} - - /// Allow implicit conversion to the underlying block iterator. - operator Block::iterator() const { return this->wrapped(); } - }; + using op_iterator = detail::op_iterator; /// Return an iterator range over the operations within this block that are of /// 'OpT'. template iterator_range> getOps() { auto endIt = end(); - return {op_filter_iterator(begin(), endIt), - op_filter_iterator(endIt, endIt)}; + return {detail::op_filter_iterator(begin(), endIt), + detail::op_filter_iterator(endIt, endIt)}; } template op_iterator op_begin() { - return op_filter_iterator(begin(), end()); + return detail::op_filter_iterator(begin(), end()); } template op_iterator op_end() { - return op_filter_iterator(end(), end()); + return detail::op_filter_iterator(end(), end()); } /// Return an iterator range over the operation within this block excluding diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h --- a/mlir/include/mlir/IR/BlockSupport.h +++ b/mlir/include/mlir/IR/BlockSupport.h @@ -75,6 +75,46 @@ friend RangeBaseT; }; +//===----------------------------------------------------------------------===// +// Operation Iterators +//===----------------------------------------------------------------------===// + +namespace detail { +/// A utility iterator that filters out operations that are not 'OpT'. +template +class op_filter_iterator + : public llvm::filter_iterator { + static bool filter(Operation &op) { return llvm::isa(op); } + +public: + op_filter_iterator(IteratorT it, IteratorT end) + : llvm::filter_iterator(it, end, + &filter) {} + + /// Allow implicit conversion to the underlying iterator. + operator IteratorT() const { return this->wrapped(); } +}; + +/// This class provides iteration over the held operations of a block for a +/// specific operation type. +template +class op_iterator + : public llvm::mapped_iterator, + OpT (*)(Operation &)> { + static OpT unwrap(Operation &op) { return cast(op); } + +public: + using reference = OpT; + + /// Initializes the iterator to the specified filter iterator. + op_iterator(op_filter_iterator it) + : llvm::mapped_iterator, + OpT (*)(Operation &)>(it, &unwrap) {} + + /// Allow implicit conversion to the underlying block iterator. + operator IteratorT() const { return this->wrapped(); } +}; +} // end namespace detail } // end namespace mlir namespace llvm { diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -30,10 +30,11 @@ /// implicitly capture global values, and all external references must use /// Function arguments or attributes that establish a symbolic connection(e.g. /// symbols referenced by name via a string attribute). -class FuncOp : public Op { +class FuncOp + : public Op { public: using Op::Op; using Op::print; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -583,6 +583,13 @@ public: Region &getRegion() { return this->getOperation()->getRegion(0); } + /// Returns a range of operations within the region of this operation. + auto getOps() { return getRegion().getOps(); } + template + auto getOps() { + return getRegion().template getOps(); + } + static LogicalResult verifyTrait(Operation *op) { return impl::verifyOneRegion(op); } diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -34,6 +34,10 @@ /// parent container. The region must have a valid parent container. Location getLoc(); + //===--------------------------------------------------------------------===// + // Block list management + //===--------------------------------------------------------------------===// + using BlockListType = llvm::iplist; BlockListType &getBlocks() { return blocks; } @@ -58,6 +62,72 @@ return &Region::blocks; } + //===--------------------------------------------------------------------===// + // Operation list utilities + //===--------------------------------------------------------------------===// + + /// This class provides iteration over the held operations of blocks directly + /// within a region. + class OpIterator final + : public llvm::iterator_facade_base { + public: + /// Initialize OpIterator for a region, specify `end` to return the iterator + /// to last operation. + explicit OpIterator(Region *region, bool end = false); + + using llvm::iterator_facade_base::operator++; + OpIterator &operator++(); + Operation *operator->() const { return &*operation; } + Operation &operator*() const { return *operation; } + + /// Compare this iterator with another. + bool operator==(const OpIterator &rhs) const { + return operation == rhs.operation; + } + bool operator!=(const OpIterator &rhs) const { return !(*this == rhs); } + + private: + void skipOverBlocksWithNoOps(); + + /// The region whose operations are being iterated over. + Region *region; + /// The block of 'region' whose operations are being iterated over. + Region::iterator block; + /// The current operation within 'block'. + Block::iterator operation; + }; + + /// This class provides iteration over the held operations of a region for a + /// specific operation type. + template + using op_iterator = detail::op_iterator; + + /// Return iterators that walk the operations nested directly within this + /// region. + OpIterator op_begin() { return OpIterator(this); } + OpIterator op_end() { return OpIterator(this, /*end=*/true); } + iterator_range getOps() { return {op_begin(), op_end()}; } + + /// Return iterators that walk operations of type 'T' nested directly within + /// this region. + template op_iterator op_begin() { + return detail::op_filter_iterator(op_begin(), op_end()); + } + template op_iterator op_end() { + return detail::op_filter_iterator(op_end(), op_end()); + } + template iterator_range> getOps() { + auto endIt = op_end(); + return {detail::op_filter_iterator(op_begin(), endIt), + detail::op_filter_iterator(endIt, endIt)}; + } + + //===--------------------------------------------------------------------===// + // Misc. utilities + //===--------------------------------------------------------------------===// + /// Return the region containing this region or nullptr if the region is /// attached to a top-level operation. Region *getParentRegion(); @@ -120,6 +190,10 @@ /// they are to be deleted. void dropAllReferences(); + //===--------------------------------------------------------------------===// + // Operation Walkers + //===--------------------------------------------------------------------===// + /// Walk the operations in this region in postorder, calling the callback for /// each operation. This method is invoked for void-returning callbacks. /// See Operation::walk for more details. @@ -142,6 +216,10 @@ return WalkResult::advance(); } + //===--------------------------------------------------------------------===// + // CFG view utilities + //===--------------------------------------------------------------------===// + /// Displays the CFG in a window. This is for use from the debugger and /// depends on Graphviz to generate the graph. /// This function is defined in ViewRegionGraph and only works with that diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp --- a/mlir/lib/Analysis/CallGraph.cpp +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -87,9 +87,8 @@ } for (Region ®ion : op->getRegions()) - for (Block &block : region) - for (Operation &nested : block) - computeCallGraph(&nested, cg, parentNode, resolveCalls); + for (Operation &nested : region.getOps()) + computeCallGraph(&nested, cg, parentNode, resolveCalls); } CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) { diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp @@ -36,18 +36,17 @@ } void runOnFunction() override { - for (Block &block : getFunction()) - for (Operation &op : llvm::make_early_inc_range(block)) { - if (auto forOp = dyn_cast(&op)) { - if (failed(convertAffineLoopNestToGPULaunch(forOp, numBlockDims, - numThreadDims))) - signalPassFailure(); - } else if (auto forOp = dyn_cast(&op)) { - if (failed(convertLoopNestToGPULaunch(forOp, numBlockDims, - numThreadDims))) - signalPassFailure(); - } + for (Operation &op : llvm::make_early_inc_range(getFunction().getOps())) { + if (auto forOp = dyn_cast(&op)) { + if (failed(convertAffineLoopNestToGPULaunch(forOp, numBlockDims, + numThreadDims))) + signalPassFailure(); + } else if (auto forOp = dyn_cast(&op)) { + if (failed( + convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims))) + signalPassFailure(); } + } } }; @@ -81,14 +80,10 @@ funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val)); workGroupSizeVal.push_back(constOp); } - for (Block &block : getFunction()) { - for (Operation &op : llvm::make_early_inc_range(block)) { - if (auto forOp = dyn_cast(&op)) { - if (failed(convertLoopToGPULaunch(forOp, numWorkGroupsVal, - workGroupSizeVal))) { - return signalPassFailure(); - } - } + for (ForOp forOp : llvm::make_early_inc_range(funcOp.getOps())) { + if (failed(convertLoopToGPULaunch(forOp, numWorkGroupsVal, + workGroupSizeVal))) { + return signalPassFailure(); } } } diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -146,34 +146,32 @@ // Traverse all operations in the region. while (!pendingRegions.empty()) { - for (Block &block : *pendingRegions.pop_back_val()) { - for (Operation &op : block) { - for (Value operand : op.getOperands()) { - // operand should be non-null here if the IR is well-formed. But - // we don't assert here as this function is called from the verifier - // and so could be called on invalid IR. - if (!operand) { - if (noteLoc) - op.emitOpError("block's operand not defined").attachNote(noteLoc); - return false; - } + for (Operation &op : pendingRegions.pop_back_val()->getOps()) { + for (Value operand : op.getOperands()) { + // operand should be non-null here if the IR is well-formed. But + // we don't assert here as this function is called from the verifier + // and so could be called on invalid IR. + if (!operand) { + if (noteLoc) + op.emitOpError("block's operand not defined").attachNote(noteLoc); + return false; + } - // Check that any value that is used by an operation is defined in the - // same region as either an operation result or a block argument. - if (operand.getParentRegion()->isProperAncestor(&limit)) { - if (noteLoc) { - op.emitOpError("using value defined outside the region") - .attachNote(noteLoc) - << "required by region isolation constraints"; - } - return false; + // Check that any value that is used by an operation is defined in the + // same region as either an operation result or a block argument. + if (operand.getParentRegion()->isProperAncestor(&limit)) { + if (noteLoc) { + op.emitOpError("using value defined outside the region") + .attachNote(noteLoc) + << "required by region isolation constraints"; } + return false; } - // Schedule any regions the operations contain for further checking. - pendingRegions.reserve(pendingRegions.size() + op.getNumRegions()); - for (Region &subRegion : op.getRegions()) - pendingRegions.push_back(&subRegion); } + // Schedule any regions the operations contain for further checking. + pendingRegions.reserve(pendingRegions.size() + op.getNumRegions()); + for (Region &subRegion : op.getRegions()) + pendingRegions.push_back(&subRegion); } } return true; @@ -219,6 +217,40 @@ first->parentValidOpOrderPair.setPointer(curParent); } +//===----------------------------------------------------------------------===// +// Region::OpIterator +//===----------------------------------------------------------------------===// + +Region::OpIterator::OpIterator(Region *region, bool end) + : region(region), block(end ? region->end() : region->begin()) { + if (!region->empty()) + skipOverBlocksWithNoOps(); +} + +Region::OpIterator &Region::OpIterator::operator++() { + // We increment over operations, if we reach the last use then move to next + // block. + if (operation != block->end()) + ++operation; + if (operation == block->end()) { + ++block; + skipOverBlocksWithNoOps(); + } + return *this; +} + +void Region::OpIterator::skipOverBlocksWithNoOps() { + while (block != region->end() && block->empty()) + ++block; + + // If we are at the last block, then set the operation to first operation of + // next block (sentinel value used for end). + if (block == region->end()) + operation = {}; + else + operation = block->begin(); +} + //===----------------------------------------------------------------------===// // RegionRange //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -245,11 +245,9 @@ assert(symbolTableOp->hasTrait()); // Look for a symbol with the given name. - for (auto &block : symbolTableOp->getRegion(0)) { - for (auto &op : block) - if (getNameIfSymbol(&op) == symbol) - return &op; - } + for (auto &op : symbolTableOp->getRegion(0).front().without_terminator()) + if (getNameIfSymbol(&op) == symbol) + return &op; return nullptr; } Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, @@ -444,21 +442,19 @@ function_ref)> callback) { SmallVector worklist(llvm::make_pointer_range(regions)); while (!worklist.empty()) { - for (Block &block : *worklist.pop_back_val()) { - for (Operation &op : block) { - if (walkSymbolRefs(&op, callback).wasInterrupted()) - return WalkResult::interrupt(); - - // Check that this isn't a potentially unknown symbol table. - if (isPotentiallyUnknownSymbolTable(&op)) - return llvm::None; - - // If this op defines a new symbol table scope, we can't traverse. Any - // symbol references nested within 'op' are different semantically. - if (!op.hasTrait()) { - for (Region ®ion : op.getRegions()) - worklist.push_back(®ion); - } + for (Operation &op : worklist.pop_back_val()->getOps()) { + if (walkSymbolRefs(&op, callback).wasInterrupted()) + return WalkResult::interrupt(); + + // Check that this isn't a potentially unknown symbol table. + if (isPotentiallyUnknownSymbolTable(&op)) + return llvm::None; + + // If this op defines a new symbol table scope, we can't traverse. Any + // symbol references nested within 'op' are different semantically. + if (!op.hasTrait()) { + for (Region ®ion : op.getRegions()) + worklist.push_back(®ion); } } } diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -122,23 +122,21 @@ // Walk each of the symbol tables looking for discardable callgraph nodes. auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) { - for (Block &block : symbolTableOp->getRegion(0)) { - for (Operation &op : block) { - // If this is a callgraph operation, check to see if it is discardable. - if (auto callable = dyn_cast(&op)) { - if (auto *node = cg.lookupNode(callable.getCallableRegion())) { - SymbolOpInterface symbol = dyn_cast(&op); - if (symbol && (allUsesVisible || symbol.isPrivate()) && - symbol.canDiscardOnUseEmpty()) { - discardableSymNodeUses.try_emplace(node, 0); - } - continue; + for (Operation &op : symbolTableOp->getRegion(0).getOps()) { + // If this is a callgraph operation, check to see if it is discardable. + if (auto callable = dyn_cast(&op)) { + if (auto *node = cg.lookupNode(callable.getCallableRegion())) { + SymbolOpInterface symbol = dyn_cast(&op); + if (symbol && (allUsesVisible || symbol.isPrivate()) && + symbol.canDiscardOnUseEmpty()) { + discardableSymNodeUses.try_emplace(node, 0); } + continue; } - // Otherwise, check for any referenced nodes. These will be always-live. - walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes, - [](CallGraphNode *, Operation *) {}); } + // Otherwise, check for any referenced nodes. These will be always-live. + walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes, + [](CallGraphNode *, Operation *) {}); } }; SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),