diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -75,11 +75,20 @@
   template <typename OpTy, typename... Args>
   void create(OpBuilder &builder, SmallVectorImpl<Value> &results,
               Location location, Args &&... args) {
-    Operation *op = builder.create<OpTy>(location, std::forward<Args>(args)...);
-    if (failed(tryToFold(op, results)))
+    // The op needs to be inserted only if the fold (below) fails, or the fold
+    // succeeded without producing replacement values (an in-place fold, where
+    // `results` is empty even though the op may have results). Using the
+    // builder's create methods would insert the op eagerly, so avoid them here.
+    OperationState state(location, OpTy::getOperationName());
+    OpTy::build(&builder, state, std::forward<Args>(args)...);
+    Operation *op = Operation::create(state);
+
+    if (failed(tryToFold(builder, op, results)) || results.empty()) {
+      builder.insert(op);
       results.assign(op->result_begin(), op->result_end());
-    else if (op->getNumResults() != 0)
-      op->erase();
+      return;
+    }
+    op->destroy();
   }
 
   /// Overload to create or fold a single result operation.
@@ -120,7 +129,7 @@
   /// Tries to perform folding on the given `op`. If successful, populates
   /// `results` with the results of the folding.
   LogicalResult tryToFold(
-      Operation *op, SmallVectorImpl<Value> &results,
+      OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
       function_ref<void(Operation *)> processGeneratedConstants = nullptr);
 
   /// Try to get or create a new constant entry. On success this returns the
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -24,8 +24,8 @@
 /// inserted into.
 static Region *getInsertionRegion(
     DialectInterfaceCollection<OpFolderDialectInterface> &interfaces,
-    Operation *op) {
-  while (Region *region = op->getParentRegion()) {
+    Block *insertionBlock) {
+  while (Region *region = insertionBlock->getParent()) {
     // Insert in this region for any of the following scenarios:
     // * The parent is unregistered, or is known to be isolated from above.
     // * The parent is a top-level operation.
@@ -40,7 +40,7 @@
       return region;
 
     // Traverse up the parent looking for an insertion region.
-    op = parentOp;
+    insertionBlock = parentOp->getBlock();
   }
   llvm_unreachable("expected valid insertion region");
 }
@@ -82,7 +82,8 @@
 
   // Try to fold the operation.
   SmallVector<Value, 8> results;
-  if (failed(tryToFold(op, results, processGeneratedConstants)))
+  OpBuilder builder(op);
+  if (failed(tryToFold(builder, op, results, processGeneratedConstants)))
     return failure();
 
   // Check to see if the operation was just updated in place.
@@ -117,7 +118,8 @@
   assert(constValue);
 
   // Get the constant map that this operation was uniqued in.
-  auto &uniquedConstants = foldScopes[getInsertionRegion(interfaces, op)];
+  auto &uniquedConstants =
+      foldScopes[getInsertionRegion(interfaces, op->getBlock())];
 
   // Erase all of the references to this operation.
   auto type = op->getResult(0).getType();
@@ -135,7 +137,7 @@
 /// Tries to perform folding on the given `op`. If successful, populates
 /// `results` with the results of the folding.
 LogicalResult OperationFolder::tryToFold(
-    Operation *op, SmallVectorImpl<Value> &results,
+    OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
     function_ref<void(Operation *)> processGeneratedConstants) {
   SmallVector<Attribute, 8> operandConstants;
   SmallVector<OpFoldResult, 8> foldResults;
@@ -164,9 +166,11 @@
 
   // Create a builder to insert new operations into the entry block of the
   // insertion region.
-  auto *insertRegion = getInsertionRegion(interfaces, op);
+  auto *insertRegion =
+      getInsertionRegion(interfaces, builder.getInsertionBlock());
   auto &entry = insertRegion->front();
-  OpBuilder builder(&entry, entry.begin());
+  OpBuilder::InsertionGuard foldGuard(builder);
+  builder.setInsertionPoint(&entry, entry.begin());
 
   // Get the constant map for the insertion region of this operation.
   auto &uniquedConstants = foldScopes[insertRegion];