diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -624,7 +624,9 @@ protected: /// Initialize the builder. - explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx) {} + explicit RewriterBase(MLIRContext *ctx, + OpBuilder::Listener *listener = nullptr) + : OpBuilder(ctx, listener) {} explicit RewriterBase(const OpBuilder &otherBuilder) : OpBuilder(otherBuilder) {} virtual ~RewriterBase(); @@ -648,7 +650,8 @@ /// such as a `PatternRewriter`, is not available. class IRRewriter : public RewriterBase { public: - explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {} + explicit IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr) + : RewriterBase(ctx, listener) {} explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {} }; diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h --- a/mlir/include/mlir/Transforms/FoldUtils.h +++ b/mlir/include/mlir/Transforms/FoldUtils.h @@ -32,8 +32,8 @@ /// generated along the way. class OperationFolder { public: - OperationFolder(MLIRContext *ctx, RewriterBase::Listener *listener = nullptr) - : interfaces(ctx), listener(listener) {} + OperationFolder(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr) + : interfaces(ctx), rewriter(ctx, listener) {} /// Tries to perform folding on the given `op`, including unifying /// deduplicated constants. If successful, replaces `op`'s uses with @@ -61,10 +61,11 @@ /// Clear out any constants cached inside of the folder. void clear(); - /// Get or create a constant using the given builder. On success this returns - /// the constant operation, nullptr otherwise. - Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect, - Attribute value, Type type, Location loc); + /// Get or create a constant for use in the specified block. The constant may + /// be created in a parent block. On success this returns the constant + /// operation, nullptr otherwise. + Value getOrCreateConstant(Block *block, Dialect *dialect, Attribute value, + Type type, Location loc); private: /// This map keeps track of uniqued constants by dialect, attribute, and type. @@ -74,29 +75,25 @@ using ConstantMap = DenseMap, Operation *>; - /// Erase the given operation and notify the listener. - void eraseOp(Operation *op); - /// Returns true if the given operation is an already folded constant that is /// owned by this folder. bool isFolderOwnedConstant(Operation *op) const; /// Tries to perform folding on the given `op`. If successful, populates /// `results` with the results of the folding. - LogicalResult tryToFold(OpBuilder &builder, Operation *op, - SmallVectorImpl &results); + LogicalResult tryToFold(Operation *op, SmallVectorImpl &results); - /// Try to process a set of fold results, generating constants as necessary. - /// Populates `results` on success, otherwise leaves it unchanged. - LogicalResult processFoldResults(OpBuilder &builder, Operation *op, + /// Try to process a set of fold results. Populates `results` on success, + /// otherwise leaves it unchanged. + LogicalResult processFoldResults(Operation *op, SmallVectorImpl &results, ArrayRef foldResults); /// Try to get or create a new constant entry. On success this returns the /// constant operation, nullptr otherwise. Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants, - Dialect *dialect, OpBuilder &builder, - Attribute value, Type type, Location loc); + Dialect *dialect, Attribute value, + Type type, Location loc); /// A mapping between an insertion region and the constants that have been /// created within it. @@ -109,8 +106,8 @@ /// A collection of dialect folder interfaces. DialectInterfaceCollection interfaces; - /// An optional listener that is notified of all IR changes. - RewriterBase::Listener *listener = nullptr; + /// A rewriter that performs all IR modifications. + IRRewriter rewriter; }; } // namespace mlir diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp --- a/mlir/lib/Transforms/SCCP.cpp +++ b/mlir/lib/Transforms/SCCP.cpp @@ -51,9 +51,9 @@ // Attempt to materialize a constant for the given value. Dialect *dialect = latticeValue.getConstantDialect(); - Value constant = folder.getOrCreateConstant(builder, dialect, - latticeValue.getConstantValue(), - value.getType(), value.getLoc()); + Value constant = folder.getOrCreateConstant( + builder.getInsertionBlock(), dialect, latticeValue.getConstantValue(), + value.getType(), value.getLoc()); if (!constant) return failure(); diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -84,26 +84,25 @@ // Try to fold the operation. SmallVector results; - OpBuilder builder(op, listener); - if (failed(tryToFold(builder, op, results))) + if (failed(tryToFold(op, results))) return failure(); // Check to see if the operation was just updated in place. if (results.empty()) { if (inPlaceUpdate) *inPlaceUpdate = true; - if (listener) - listener->notifyOperationModified(op); + if (auto *rewriteListener = dyn_cast_if_present( + rewriter.getListener())) { + // Folding API does not notify listeners, so we have to notify manually. + rewriteListener->notifyOperationModified(op); + } return success(); } // Constant folding succeeded. Replace all of the result values and erase the // operation. - if (listener) - listener->notifyOperationReplaced(op, results); - for (unsigned i = 0, e = results.size(); i != e; ++i) - op->getResult(i).replaceAllUsesWith(results[i]); - eraseOp(op); + notifyRemoval(op); + rewriter.replaceOp(op, results); return success(); } @@ -141,10 +140,8 @@ // If there is an existing constant, replace `op`. if (folderConstOp) { - if (listener) - listener->notifyOperationReplaced(op, folderConstOp->getResults()); - op->replaceAllUsesWith(folderConstOp); - eraseOp(op); + notifyRemoval(op); + rewriter.replaceOp(op, folderConstOp->getResults()); return false; } @@ -162,13 +159,6 @@ return true; } -void OperationFolder::eraseOp(Operation *op) { - notifyRemoval(op); - if (listener) - listener->notifyOperationRemoved(op); - op->erase(); -} - /// Notifies that the given constant `op` should be remove from this /// OperationFolder's internal bookkeeping. void OperationFolder::notifyRemoval(Operation *op) { @@ -202,22 +192,18 @@ /// Get or create a constant using the given builder. On success this returns /// the constant operation, nullptr otherwise. -Value OperationFolder::getOrCreateConstant(OpBuilder &builder, Dialect *dialect, +Value OperationFolder::getOrCreateConstant(Block *block, Dialect *dialect, Attribute value, Type type, Location loc) { - OpBuilder::InsertionGuard foldGuard(builder); - - // Use the builder insertion block to find an insertion point for the - // constant. - auto *insertRegion = - getInsertionRegion(interfaces, builder.getInsertionBlock()); + // Find an insertion point for the constant. + auto *insertRegion = getInsertionRegion(interfaces, block); auto &entry = insertRegion->front(); - builder.setInsertionPoint(&entry, entry.begin()); + rewriter.setInsertionPoint(&entry, entry.begin()); // Get the constant map for the insertion region of this operation. auto &uniquedConstants = foldScopes[insertRegion]; - Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, - builder, value, type, loc); + Operation *constOp = + tryGetOrCreateConstant(uniquedConstants, dialect, value, type, loc); return constOp ? constOp->getResult(0) : Value(); } @@ -227,7 +213,7 @@ /// Tries to perform folding on the given `op`. If successful, populates /// `results` with the results of the folding. -LogicalResult OperationFolder::tryToFold(OpBuilder &builder, Operation *op, +LogicalResult OperationFolder::tryToFold(Operation *op, SmallVectorImpl &results) { SmallVector operandConstants; @@ -257,13 +243,13 @@ // fold. SmallVector foldResults; if (failed(op->fold(operandConstants, foldResults)) || - failed(processFoldResults(builder, op, results, foldResults))) + failed(processFoldResults(op, results, foldResults))) return success(updatedOpOperands); return success(); } LogicalResult -OperationFolder::processFoldResults(OpBuilder &builder, Operation *op, +OperationFolder::processFoldResults(Operation *op, SmallVectorImpl &results, ArrayRef foldResults) { // Check to see if the operation was just updated in place. @@ -273,11 +259,9 @@ // Create a builder to insert new operations into the entry block of the // insertion region. - auto *insertRegion = - getInsertionRegion(interfaces, builder.getInsertionBlock()); + auto *insertRegion = getInsertionRegion(interfaces, op->getBlock()); auto &entry = insertRegion->front(); - OpBuilder::InsertionGuard foldGuard(builder); - builder.setInsertionPoint(&entry, entry.begin()); + rewriter.setInsertionPoint(&entry, entry.begin()); // Get the constant map for the insertion region of this operation. auto &uniquedConstants = foldScopes[insertRegion]; @@ -300,9 +284,8 @@ // Check to see if there is a canonicalized version of this constant. auto res = op->getResult(i); Attribute attrRepl = foldResults[i].get(); - if (auto *constOp = - tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl, - res.getType(), op->getLoc())) { + if (auto *constOp = tryGetOrCreateConstant( + uniquedConstants, dialect, attrRepl, res.getType(), op->getLoc())) { // Ensure that this constant dominates the operation we are replacing it // with. This may not automatically happen if the operation being folded // was inserted before the constant within the insertion block. @@ -316,8 +299,10 @@ // If materialization fails, cleanup any operations generated for the // previous results and return failure. for (Operation &op : llvm::make_early_inc_range( - llvm::make_range(entry.begin(), builder.getInsertionPoint()))) - eraseOp(&op); + llvm::make_range(entry.begin(), rewriter.getInsertionPoint()))) { + notifyRemoval(&op); + rewriter.eraseOp(&op); + } results.clear(); return failure(); @@ -328,9 +313,10 @@ /// Try to get or create a new constant entry. On success this returns the /// constant operation value, nullptr otherwise. -Operation *OperationFolder::tryGetOrCreateConstant( - ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder, - Attribute value, Type type, Location loc) { +Operation * +OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants, + Dialect *dialect, Attribute value, + Type type, Location loc) { // Check if an existing mapping already exists. auto constKey = std::make_tuple(dialect, value, type); Operation *&constOp = uniquedConstants[constKey]; @@ -338,7 +324,7 @@ return constOp; // If one doesn't exist, try to materialize one. - if (!(constOp = materializeConstant(dialect, builder, value, type, loc))) + if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc))) return nullptr; // Check to see if the generated constant is in the expected dialect. @@ -355,7 +341,8 @@ // If an existing operation in the new dialect already exists, delete the // materialized operation in favor of the existing one. if (auto *existingOp = uniquedConstants.lookup(newKey)) { - eraseOp(constOp); + notifyRemoval(constOp); + rewriter.eraseOp(constOp); referencedDialects[existingOp].push_back(dialect); return constOp = existingOp; } diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp --- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp +++ b/mlir/test/lib/Transforms/TestIntRangeInference.cpp @@ -39,8 +39,9 @@ maybeDefiningOp ? maybeDefiningOp->getDialect() : value.getParentRegion()->getParentOp()->getDialect(); Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue); - Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr, - value.getType(), value.getLoc()); + Value constant = + folder.getOrCreateConstant(b.getInsertionBlock(), valueDialect, constAttr, + value.getType(), value.getLoc()); if (!constant) return failure();