diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -75,11 +75,25 @@
   template <typename OpTy, typename... Args>
   void create(OpBuilder &builder, SmallVectorImpl<Value *> &results,
               Location location, Args &&... args) {
-    Operation *op = builder.create<OpTy>(location, std::forward<Args>(args)...);
-    if (failed(tryToFold(op, results)))
+    // Build the op detached; the builder's create methods would insert it
+    // eagerly. It is inserted only if folding fails, or if the fold succeeds
+    // without producing replacement values, i.e. the op was folded in place
+    // (this includes every successful fold of a zero-result op). Testing the
+    // op's result count instead would destroy an op with results that was
+    // folded in place and hand the caller no values.
+    OperationState state(location, OpTy::getOperationName());
+    OpTy::build(&builder, state, std::forward<Args>(args)...);
+    Operation *op = Operation::create(state);
+
+    // Clear first so that `results.empty()` reflects only this fold's output.
+    results.clear();
+    if (failed(tryToFold(op, results)) || results.empty()) {
+      builder.insert(op);
       results.assign(op->result_begin(), op->result_end());
-    else if (op->getNumResults() != 0)
-      op->erase();
+      return;
+    }
+    // Folded away: the op was never inserted, so just destroy it.
+    op->destroy();
   }
 
   /// Overload to create or fold a single result operation.