diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -9,12 +9,12 @@
 #ifndef MLIR_IR_BUILDERS_H
 #define MLIR_IR_BUILDERS_H
 
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/OpDefinition.h"
 
 namespace mlir {
 
 class AffineExpr;
-class BlockAndValueMapping;
 class ModuleOp;
 class UnknownLoc;
 class FileLineColLoc;
@@ -459,9 +459,27 @@
   /// cloned sub-operations to the corresponding operation that is copied,
   /// and adds those mappings to the map.
   Operation *clone(Operation &op, BlockAndValueMapping &mapper) {
-    return insert(op.clone(mapper));
+    bool alreadyHandledFirst = false;
+    auto notifyAllButFirst = [&](Operation *clonedOp) {
+      // The first operation that this callback gets called with is the op
+      // returned by op.clone(...) itself. The notification for that op will
+      // happen inside `insert`. But for any recursively nested ops, we need to
+      // notify. Note that nested ops are notified as soon as they are created,
+      // i.e. before they are linked into a block and before the top-level
+      // clone itself is inserted.
+      if (!alreadyHandledFirst) {
+        alreadyHandledFirst = true;
+        return;
+      }
+      if (listener)
+        listener->notifyOperationInserted(clonedOp);
+    };
+    return insert(op.clone(mapper, notifyAllButFirst));
+  }
+  Operation *clone(Operation &op) {
+    BlockAndValueMapping mapper;
+    return clone(op, mapper);
   }
-  Operation *clone(Operation &op) { return insert(op.clone()); }
 
   /// Creates a deep copy of this operation but keep the operation regions
   /// empty. Operands are remapped using `mapper` (if present), and `mapper` is
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -74,7 +74,15 @@
   /// them alone if no entry is present). Replaces references to cloned
   /// sub-operations to the corresponding operation that is copied, and adds
   /// those mappings to the map.
-  Operation *clone(BlockAndValueMapping &mapper);
+  ///
+  /// Additionally, if `opCreationListener` is provided, it will be called for
+  /// every new operation that is created. In the case of cloning ops with
+  /// regions, the listener is called on the parent op before its regions have
+  /// been populated, and then recursively on all operations inside the region
+  /// as the regions are being populated.
+  Operation *
+  clone(BlockAndValueMapping &mapper,
+        llvm::function_ref<void(Operation *)> opCreationListener = nullptr);
   Operation *clone();
 
   /// Create a partial copy of this operation without traversing into attached
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -212,8 +212,13 @@
   /// in the respective cloned block.
   void cloneInto(Region *dest, BlockAndValueMapping &mapper);
   /// Clone this region into 'dest' before the given position in 'dest'.
-  void cloneInto(Region *dest, Region::iterator destPos,
-                 BlockAndValueMapping &mapper);
+  ///
+  /// The provided opCreationListener is passed to any Operation::clone calls
+  /// made by this function.
+  void
+  cloneInto(Region *dest, Region::iterator destPos,
+            BlockAndValueMapping &mapper,
+            llvm::function_ref<void(Operation *)> opCreationListener = nullptr);
 
   /// Takes body of another region (that region will have no body after this
   /// operation completes). The current body of this region is cleared.
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -624,12 +624,23 @@
 /// them alone if no entry is present). Replaces references to cloned
 /// sub-operations to the corresponding operation that is copied, and adds
 /// those mappings to the map.
-Operation *Operation::clone(BlockAndValueMapping &mapper) {
+///
+/// Additionally, if `opCreationListener` is provided, it will be called for
+/// every new operation that is created. In the case of cloning ops with
+/// regions, the listener is called on the parent op before its regions have
+/// been populated, and then recursively on all operations inside the region
+/// as the regions are being populated.
+Operation *
+Operation::clone(BlockAndValueMapping &mapper,
+                 llvm::function_ref<void(Operation *)> opCreationListener) {
   auto *newOp = cloneWithoutRegions(mapper);
+  if (opCreationListener)
+    opCreationListener(newOp);
 
   // Clone the regions.
   for (unsigned i = 0; i != numRegions; ++i)
-    getRegion(i).cloneInto(&newOp->getRegion(i), mapper);
+    getRegion(i).cloneInto(&newOp->getRegion(i), newOp->getRegion(i).end(),
+                           mapper, opCreationListener);
 
   return newOp;
 }
diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -76,8 +76,12 @@
 }
 
 /// Clone this region into 'dest' before the given position in 'dest'.
-void Region::cloneInto(Region *dest, Region::iterator destPos,
-                       BlockAndValueMapping &mapper) {
+///
+/// The provided opCreationListener is passed to any Operation::clone calls
+/// made by this function.
+void Region::cloneInto(
+    Region *dest, Region::iterator destPos, BlockAndValueMapping &mapper,
+    llvm::function_ref<void(Operation *)> opCreationListener) {
   assert(dest && "expected valid region to clone into");
   assert(this != dest && "cannot clone region into itself");
 
@@ -98,7 +102,7 @@
 
     // Clone and remap the operations within this block.
     for (auto &op : block)
-      newBlock->push_back(op.clone(mapper));
+      newBlock->push_back(op.clone(mapper, opCreationListener));
 
     dest->getBlocks().insert(destPos, newBlock);
   }
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -86,3 +86,21 @@
   %0 = tensor_from_elements %arg0, %arg1 : tensor<2xindex>
   return %0 : tensor<2xindex>
 }
+
+// The dynamic_tensor_from_elements op clones each op in its body.
+// Make sure that regions nested within such ops are recursively converted.
+// CHECK-LABEL: func @recursively_convert_cloned_regions
+func @recursively_convert_cloned_regions(%arg0: tensor<?xindex>, %arg1: index, %arg2: i1) -> tensor<?xindex> {
+  %tensor = dynamic_tensor_from_elements %arg1 {
+  ^bb0(%iv: index):
+    %48 = scf.if %arg2 -> (index) {
+      scf.yield %iv : index
+    } else {
+      // CHECK-NOT: extract_element
+      %50 = extract_element %arg0[%iv] : tensor<?xindex>
+      scf.yield %50 : index
+    }
+    yield %48 : index
+  } : tensor<?xindex>
+  return %tensor : tensor<?xindex>
+}