Review notes (reviewer annotation). Free text placed before the first diff
header is skipped by `git apply` and by `patch`; the patch body below is
unchanged.

Summary: this patch changes mlir/include/mlir/Transforms/FoldUtils.h and
mlir/lib/Transforms/Utils/FoldUtils.cpp so that OperationFolder::create no
longer creates the op through builder.create (which would insert it). The op
is built detached via OperationState + OpTy::build + Operation::create. It is
then inserted with builder.insert only when tryToFold fails or the op has
zero results (the patch treats that as an in-place fold). Otherwise it is
released with op->destroy().

tryToFold gains a leading `OpBuilder &builder` parameter. The existing caller
in FoldUtils.cpp passes `OpBuilder builder(op)`, presumably positioned just
before `op`. Inside tryToFold the constant-insertion region is now looked up
from the builder's insertion point instead of from `op`. The builder is then
moved to the start of that region's entry block under an
OpBuilder::InsertionGuard, which is expected to restore the caller's
insertion point on return.

NOTE(review): `*builder.getInsertionPoint()` dereferences the insertion
iterator. If the caller's builder is positioned at the end of its block (for
example after setInsertionPointToEnd), that iterator is end() and the
dereference is undefined behavior. The `create` path forwards a
caller-supplied builder, so nothing visible here rules that case out.
Deriving the region from builder.getInsertionBlock() would avoid it, but
that requires getInsertionRegion to accept a Block*. Its definition is not
part of this patch. TODO: confirm and fix before landing.

NOTE(review): in the `create` path, `op` has no parent block while its fold
hook runs. Presumably no folder inspects the op's parent or siblings; verify.

NOTE(review): this copy of the patch appears mangled. Its hunks are collapsed
onto long lines, and template argument lists look stripped (e.g.
`template void create(`, `SmallVectorImpl &results`, `std::forward(args)`).
It will not apply as-is; regenerate it from the source tree.

diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h --- a/mlir/include/mlir/Transforms/FoldUtils.h +++ b/mlir/include/mlir/Transforms/FoldUtils.h @@ -75,11 +75,20 @@ template void create(OpBuilder &builder, SmallVectorImpl &results, Location location, Args &&... args) { - Operation *op = builder.create(location, std::forward(args)...); - if (failed(tryToFold(op, results))) + // The op needs to be inserted only if the fold (below) fails, or the number + // of results of the op is zero (which is treated as an in-place + // fold). Using create methods of the builder will insert the op, so not + // using it here. + OperationState state(location, OpTy::getOperationName()); + OpTy::build(&builder, state, std::forward(args)...); + Operation *op = Operation::create(state); + + if (failed(tryToFold(builder, op, results)) || op->getNumResults() == 0) { + builder.insert(op); results.assign(op->result_begin(), op->result_end()); - else if (op->getNumResults() != 0) - op->erase(); + return; + } + op->destroy(); } /// Overload to create or fold a single result operation. @@ -120,7 +129,7 @@ /// Tries to perform folding on the given `op`. If successful, populates /// `results` with the results of the folding. LogicalResult tryToFold( - Operation *op, SmallVectorImpl &results, + OpBuilder &builder, Operation *op, SmallVectorImpl &results, function_ref processGeneratedConstants = nullptr); /// Try to get or create a new constant entry. On success this returns the diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -82,7 +82,8 @@ // Try to fold the operation. SmallVector results; - if (failed(tryToFold(op, results, processGeneratedConstants))) + OpBuilder builder(op); + if (failed(tryToFold(builder, op, results, processGeneratedConstants))) return failure(); // Constant folding succeeded. 
We will start replacing this op's uses and @@ -135,7 +136,7 @@ /// Tries to perform folding on the given `op`. If successful, populates /// `results` with the results of the folding. LogicalResult OperationFolder::tryToFold( - Operation *op, SmallVectorImpl &results, + OpBuilder &builder, Operation *op, SmallVectorImpl &results, function_ref processGeneratedConstants) { SmallVector operandConstants; SmallVector foldResults; @@ -164,9 +165,11 @@ // Create a builder to insert new operations into the entry block of the // insertion region. - auto *insertRegion = getInsertionRegion(interfaces, op); + Operation &insertionPoint = *builder.getInsertionPoint(); + auto *insertRegion = getInsertionRegion(interfaces, &insertionPoint); auto &entry = insertRegion->front(); - OpBuilder builder(&entry, entry.begin()); + OpBuilder::InsertionGuard foldGuard(builder); + builder.setInsertionPoint(&entry, entry.begin()); // Get the constant map for the insertion region of this operation. auto &uniquedConstants = foldScopes[insertRegion];