diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -25,7 +25,6 @@ class AffineApplyOp; class AffineBound; class AffineValueMap; -class RewriterBase; /// TODO: These should be renamed if they are on the mlir namespace. /// Ideally, they should go in a mlir::affine:: namespace. @@ -384,21 +383,20 @@ /// Constructs an AffineApplyOp that applies `map` to `operands` after composing /// the map with the maps of any other AffineApplyOp supplying the operands, /// then immediately attempts to fold it. If folding results in a constant -/// value, erases all created ops. The `map` must be a single-result affine map. -OpFoldResult makeComposedFoldedAffineApply(RewriterBase &b, Location loc, +/// value, no ops are actually created. The `map` must be a single-result affine +/// map. +OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef operands); /// Variant of `makeComposedFoldedAffineApply` that applies to an expression. -OpFoldResult makeComposedFoldedAffineApply(RewriterBase &b, Location loc, +OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineExpr expr, ArrayRef operands); /// Variant of `makeComposedFoldedAffineApply` suitable for multi-result maps. /// Note that this may create as many affine.apply operations as the map has /// results given that affine.apply must be single-result. -SmallVector -makeComposedFoldedMultiResultAffineApply(RewriterBase &b, Location loc, - AffineMap map, - ArrayRef operands); +SmallVector makeComposedFoldedMultiResultAffineApply( + OpBuilder &b, Location loc, AffineMap map, ArrayRef operands); /// Returns an AffineMinOp obtained by composing `map` and `operands` with /// AffineApplyOps supplying those operands. @@ -407,15 +405,15 @@ /// Constructs an AffineMinOp that computes a minimum across the results of /// applying `map` to `operands`, then immediately attempts to fold it. If -/// folding results in a constant value, erases all created ops. -OpFoldResult makeComposedFoldedAffineMin(RewriterBase &b, Location loc, +/// folding results in a constant value, no ops are actually created. +OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef operands); /// Constructs an AffineMinOp that computes a maximum across the results of /// applying `map` to `operands`, then immediately attempts to fold it. If -/// folding results in a constant value, erases all created ops. -OpFoldResult makeComposedFoldedAffineMax(RewriterBase &b, Location loc, +/// folding results in a constant value, no ops are actually created. +OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef operands); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -709,11 +710,19 @@ /// Given a list of `OpFoldResult`, build the necessary operations to populate /// `actualValues` with values produced by operations. In particular, for any /// attribute-typed element in `values`, call the constant materializer -/// associated with the Affine dialect to produce an operation. +/// associated with the Affine dialect to produce an operation. Do NOT notify +/// the builder listener about the constant ops being created as they are +/// intended to be removed after being folded into affine constructs; this is +/// not suitable for use beyond the Affine dialect. static void materializeConstants(OpBuilder &b, Location loc, ArrayRef values, SmallVectorImpl &constants, SmallVectorImpl &actualValues) { + OpBuilder::Listener *listener = b.getListener(); + b.setListener(nullptr); + auto listenerResetter = + llvm::make_scope_exit([listener, &b] { b.setListener(listener); }); + actualValues.reserve(values.size()); auto *dialect = b.getContext()->getLoadedDialect(); for (OpFoldResult ofr : values) { @@ -742,7 +751,7 @@ template static std::enable_if_t(), OpFoldResult> -createOrFold(RewriterBase &b, Location loc, ValueRange operands, +createOrFold(OpBuilder &b, Location loc, ValueRange operands, Args &&...leadingArguments) { // Identify the constant operands and extract their values as attributes. // Note that we cannot use the original values directly because the list of @@ -759,17 +768,30 @@ // Create the operation and immediately attempt to fold it. On success, // delete the operation and prepare the (unmaterialized) value for being - // returned. On failure, return the operation result value. + // returned. On failure, return the operation result value. Temporarily remove + // the listener to avoid notifying it when the op is created as it may be + // removed immediately and there is no way of notifying the caller about that + // without resorting to RewriterBase. + // // TODO: arguably, the main folder (createOrFold) API should support this use // case instead of indiscriminately materializing constants. + OpBuilder::Listener *listener = b.getListener(); + b.setListener(nullptr); + auto listenerResetter = + llvm::make_scope_exit([listener, &b] { b.setListener(listener); }); OpTy op = b.create(loc, std::forward(leadingArguments)..., operands); SmallVector foldResults; if (succeeded(op->fold(constantOperands, foldResults)) && !foldResults.empty()) { - b.eraseOp(op); + op->erase(); return foldResults.front(); } + + // Notify the listener now that we definitely know that the operation will + // persist. Use the original listener stored in the variable. + if (listener) + listener->notifyOperationInserted(op); return op->getResult(0); } @@ -821,8 +843,7 @@ } OpFoldResult -mlir::makeComposedFoldedAffineApply(RewriterBase &b, Location loc, - AffineMap map, +mlir::makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef operands) { assert(map.getNumResults() == 1 && "building affine.apply with !=1 result"); @@ -835,13 +856,12 @@ // Constants are always folded into affine min/max because they can be // represented as constant expressions, so delete them. for (Operation *op : constants) - b.eraseOp(op); + op->erase(); return result; } OpFoldResult -mlir::makeComposedFoldedAffineApply(RewriterBase &b, Location loc, - AffineExpr expr, +mlir::makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineExpr expr, ArrayRef operands) { return makeComposedFoldedAffineApply( b, loc, AffineMap::inferFromExprList(ArrayRef{expr}).front(), @@ -849,7 +869,7 @@ } SmallVector mlir::makeComposedFoldedMultiResultAffineApply( - RewriterBase &b, Location loc, AffineMap map, + OpBuilder &b, Location loc, AffineMap map, ArrayRef operands) { return llvm::to_vector(llvm::map_range( llvm::seq(0, map.getNumResults()), [&](unsigned i) { @@ -866,7 +886,7 @@ } template -static OpFoldResult makeComposedFoldedMinMax(RewriterBase &b, Location loc, +static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef operands) { SmallVector constants; @@ -879,18 +899,18 @@ // Constants are always folded into affine min/max because they can be // represented as constant expressions, so delete them. for (Operation *op : constants) - b.eraseOp(op); + op->erase(); return result; } OpFoldResult -mlir::makeComposedFoldedAffineMin(RewriterBase &b, Location loc, AffineMap map, +mlir::makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef operands) { return makeComposedFoldedMinMax(b, loc, map, operands); } OpFoldResult -mlir::makeComposedFoldedAffineMax(RewriterBase &b, Location loc, AffineMap map, +mlir::makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef operands) { return makeComposedFoldedMinMax(b, loc, map, operands); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -132,12 +132,11 @@ SmallVector allShapes = op.createFlatListOfOperandDims(b, b.getLoc()); AffineMap shapesToLoops = op.getShapesToLoopsMap(); - IRRewriter rewriter(b); SmallVector loopRanges = - makeComposedFoldedMultiResultAffineApply(rewriter, op.getLoc(), - shapesToLoops, allShapes); + makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops, + allShapes); Value tripCount = - materializeOpFoldResult(rewriter, op.getLoc(), loopRanges[dimension]); + materializeOpFoldResult(b, op.getLoc(), loopRanges[dimension]); // Compute the tile sizes and the respective numbers of tiles. AffineExpr s0 = b.getAffineSymbolExpr(0); @@ -206,19 +205,17 @@ /// Build an `affine_max` of all the `vals`. static OpFoldResult buildMax(OpBuilder &b, Location loc, ArrayRef vals) { - IRRewriter rewriter(b); return makeComposedFoldedAffineMax( - rewriter, loc, - AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()), vals); + b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()), + vals); } /// Build an `affine_min` of all the `vals`. static OpFoldResult buildMin(OpBuilder &b, Location loc, ArrayRef vals) { - IRRewriter rewriter(b); return makeComposedFoldedAffineMin( - rewriter, loc, - AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()), vals); + b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()), + vals); } /// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`. The @@ -386,7 +383,7 @@ // Insert a tile `source` into the destination tensor `dest`. The position at // which the tile is inserted (as well as size of tile) is taken from a given // ExtractSliceOp `sliceOp`. -static Value insertSliceIntoTensor(RewriterBase &b, Location loc, +static Value insertSliceIntoTensor(OpBuilder &b, Location loc, tensor::ExtractSliceOp sliceOp, Value source, Value dest) { return b.create( @@ -478,10 +475,9 @@ static_cast(op.getNumInputsAndOutputs()) && "expect the number of operands and inputs and outputs to match"); SmallVector valuesToTile = operandValuesToUse; - IRRewriter rewriter(b); SmallVector sizeBounds = - makeComposedFoldedMultiResultAffineApply( - rewriter, loc, shapeSizesToLoopsMap, allShapeSizes); + makeComposedFoldedMultiResultAffineApply(b, loc, shapeSizesToLoopsMap, + allShapeSizes); SmallVector tiledOperands = makeTiledShapes( b, loc, op, valuesToTile, getAsOpFoldResult(interchangedIvs), tileSizes, sizeBounds, @@ -616,10 +612,8 @@ auto sliceOp = tiledOutput.getDefiningOp(); assert(sliceOp && "expected ExtractSliceOp"); // Insert the tile into the output tensor. - // TODO: Propagate RewriterBase everywhere. - IRRewriter rewriter(b); Value yieldValue = - insertSliceIntoTensor(rewriter, loc, sliceOp, sliceOp, iterArgs[0]); + insertSliceIntoTensor(b, loc, sliceOp, sliceOp, iterArgs[0]); return scf::ValueVector({yieldValue}); }); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -108,11 +108,10 @@ linalgOp.createFlatListOfOperandDims(b, loc); AffineMap map = linalgOp.getShapesToLoopsMap(); - IRRewriter rewriter(b); return llvm::to_vector( llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) { - OpFoldResult ofr = makeComposedFoldedAffineApply( - rewriter, loc, loopExpr, allShapesSizes); + OpFoldResult ofr = + makeComposedFoldedAffineApply(b, loc, loopExpr, allShapesSizes); return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)}; })); } @@ -156,10 +155,9 @@ AffineExpr d0; bindDims(b.getContext(), d0); - IRRewriter rewriter(b); SmallVector subShapeSizes = llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) { - return makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, ofr); + return makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr); })); OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);