// Patch section: mlir/include/mlir/Dialect/Arith/Utils/Utils.h
// Single hunk (@@ -26,50 +26,9 @@) that only removes lines from this header:
//   * the declaration `foldDynamicIndexList(Builder &b, SmallVectorImpl &ofrs)`,
//     whose doc comment says it replaces Values produced by
//     arith::ConstantIndexOp with the constant attribute and returns `failure`
//     when no folding happened;
//   * the rewrite pattern `OpWithOffsetSizesAndStridesConstantArgumentFolder`,
//     whose matchAndRewrite copies the op's mixed offsets/sizes/strides, bails
//     out with failure() when none of the three lists could be folded, asks a
//     `ResultTypeFunc` instance for the new result type (failure() if it is
//     null), creates the new op from (loc, resultType, op.getSource(), offsets,
//     sizes, strides) and finally invokes a `CastOpFunc` instance on
//     (rewriter, op, newOp).
// The `matchConstantIndex` / `getPositionsOfShapeOne` declarations and the
// trailing "Converts an OpFoldResult to a Value" doc comment are context only.
// NOTE(review): template parameter/argument lists (after `template`,
// `op_matcher`, `SmallVectorImpl`, `SmallVector`, `ArrayRef`,
// `OpRewritePattern`, `rewriter.create`) look like they are missing from this
// copy of the patch, presumably lost in extraction; line breaks inside the
// hunk are collapsed as well. Confirm against the upstream change before
// trying to apply it.
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h @@ -26,50 +26,9 @@ /// Matches a ConstantIndexOp. detail::op_matcher matchConstantIndex(); -/// Returns `success` when any of the elements in `ofrs` was produced by -/// arith::ConstantIndexOp. In that case the constant attribute replaces the -/// Value. Returns `failure` when no folding happened. -LogicalResult foldDynamicIndexList(Builder &b, - SmallVectorImpl &ofrs); - llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef shape); -/// Pattern to rewrite a subview op with constant arguments. -template -class OpWithOffsetSizesAndStridesConstantArgumentFolder final - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpType op, - PatternRewriter &rewriter) const override { - SmallVector mixedOffsets(op.getMixedOffsets()); - SmallVector mixedSizes(op.getMixedSizes()); - SmallVector mixedStrides(op.getMixedStrides()); - - // No constant operands were folded, just return; - if (failed(foldDynamicIndexList(rewriter, mixedOffsets)) && - failed(foldDynamicIndexList(rewriter, mixedSizes)) && - failed(foldDynamicIndexList(rewriter, mixedStrides))) - return failure(); - - // Create the new op in canonical form. - ResultTypeFunc resultTypeFunc; - auto resultType = - resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides); - if (!resultType) - return failure(); - auto newOp = - rewriter.create(op.getLoc(), resultType, op.getSource(), - mixedOffsets, mixedSizes, mixedStrides); - CastOpFunc func; - func(rewriter, op, newOp); - - return success(); - } -}; - /// Converts an OpFoldResult to a Value. Returns the fold result if it casts to /// a Value or creates a ConstantIndexOp if it casts to an IntegerAttribute. /// Other attribute types are not supported. 
// Patch sections for five files follow on the next three physical lines. The
// hunks straddle those line breaks, so they are documented together here.
//
// 1. mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (@@ -135,6 +135,11 @@)
//    Adds the declaration `foldDynamicIndexList(SmallVectorImpl &ofrs)`; its doc
//    comment says it returns "success" when any element of `ofrs` is a
//    constant value (which is then replaced by an attribute) and "failure"
//    when no folding happened. No Builder parameter is taken.
//
// 2. mlir/include/mlir/Interfaces/ViewLikeInterface.h (two hunks)
//    - @@ -18,6 +18,7 @@ adds `#include "mlir/IR/PatternMatch.h"`.
//    - @@ -39,6 +40,47 @@ adds the pattern class
//      `OpWithOffsetSizesAndStridesConstantArgumentFolder` inside
//      `namespace mlir`. Its matchAndRewrite copies the op's mixed
//      offsets/sizes/strides, returns failure() when none of the three
//      `foldDynamicIndexList(...)` calls succeeded, computes the result type
//      with a temporary `ResultTypeFn()` (failure() if null), creates the new
//      op from (loc, resultType, op.getSource(), offsets, sizes, strides) and
//      calls a temporary `CastOpFunc()` on (rewriter, op, newOp). The added
//      doc comment states the op is assumed to have a builder with exactly
//      that signature.
//
// 3. mlir/lib/Dialect/Arith/Utils/Utils.cpp (@@ -25,25 +25,6 @@)
//    Removes the definition `mlir::foldDynamicIndexList(Builder &b, ...)`,
//    which skipped entries where `ofr.is()` held and otherwise looked for a
//    defining op via `getDefiningOp()`, replacing the entry with
//    `b.getIndexAttr(cstOp.value())`.
//
// 4. mlir/lib/Dialect/SCF/IR/SCF.cpp (@@ -1513,9 +1513,9 @@) and
// 5. mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (@@ -2317,9 +2317,9 @@)
//    Call-site updates only: the three chained `foldDynamicIndexList` calls in
//    each file drop their leading `rewriter` argument; the surrounding
//    `return failure();` logic is untouched context.
//
// NOTE(review): template parameter/argument lists (e.g. after `template`,
// `SmallVectorImpl`, `SmallVector`, `ArrayRef`, `std::optional`,
// `llvm::function_ref`, `dyn_cast_if_present`, `getDefiningOp`, `ofr.is`,
// `custom`) appear to be missing in this copy, presumably lost in extraction,
// and intra-hunk line breaks are collapsed. Confirm against the upstream
// change before trying to apply it.
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -135,6 +135,11 @@ getValuesSortedByKey(ArrayRef keys, ArrayRef values, llvm::function_ref compare); +/// Returns "success" when any of the elements in `ofrs` is a constant value. In +/// that case the value is replaced by an attribute. Returns "failure" when no +/// folding happened. +LogicalResult foldDynamicIndexList(SmallVectorImpl &ofrs); + /// Return the number of iterations for a loop with a lower bound `lb`, upper /// bound `ub` and step `step`. std::optional constantTripCount(OpFoldResult lb, OpFoldResult ub, diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" namespace mlir { @@ -39,6 +40,47 @@ namespace mlir { +/// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as +/// constant arguments. This pattern assumes that the op has a suitable builder +/// that takes a result type, a "source" operand and mixed offsets, sizes and +/// strides. +/// +/// `OpType` is the type of op to which this pattern is applied. `ResultTypeFn` +/// returns the new result type of the op, based on the new offsets, sizes and +/// strides. `CastOpFunc` is used to generate a cast op if the result type of +/// the op has changed. 
+template +class OpWithOffsetSizesAndStridesConstantArgumentFolder final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpType op, + PatternRewriter &rewriter) const override { + SmallVector mixedOffsets(op.getMixedOffsets()); + SmallVector mixedSizes(op.getMixedSizes()); + SmallVector mixedStrides(op.getMixedStrides()); + + // No constant operands were folded, just return; + if (failed(foldDynamicIndexList(mixedOffsets)) && + failed(foldDynamicIndexList(mixedSizes)) && + failed(foldDynamicIndexList(mixedStrides))) + return failure(); + + // Create the new op in canonical form. + auto resultType = + ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides); + if (!resultType) + return failure(); + auto newOp = + rewriter.create(op.getLoc(), resultType, op.getSource(), + mixedOffsets, mixedSizes, mixedStrides); + CastOpFunc()(rewriter, op, newOp); + + return success(); + } +}; + /// Printer hook for custom directive in assemblyFormat. /// /// custom($values, $integers) diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -25,25 +25,6 @@ return detail::op_matcher(); } -// Returns `success` when any of the elements in `ofrs` was produced by -// arith::ConstantIndexOp. In that case the constant attribute replaces the -// Value. Returns `failure` when no folding happened. -LogicalResult mlir::foldDynamicIndexList(Builder &b, - SmallVectorImpl &ofrs) { - bool valuesChanged = false; - for (OpFoldResult &ofr : ofrs) { - if (ofr.is()) - continue; - // Newly static, move from Value to constant. 
- if (auto cstOp = llvm::dyn_cast_if_present(ofr) - .getDefiningOp()) { - ofr = b.getIndexAttr(cstOp.value()); - valuesChanged = true; - } - } - return success(valuesChanged); -} - llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank, ArrayRef shape) { llvm::SmallBitVector dimsToProject(shape.size()); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1513,9 +1513,9 @@ SmallVector mixedLowerBound(op.getMixedLowerBound()); SmallVector mixedUpperBound(op.getMixedUpperBound()); SmallVector mixedStep(op.getMixedStep()); - if (failed(foldDynamicIndexList(rewriter, mixedLowerBound)) && - failed(foldDynamicIndexList(rewriter, mixedUpperBound)) && - failed(foldDynamicIndexList(rewriter, mixedStep))) + if (failed(foldDynamicIndexList(mixedLowerBound)) && + failed(foldDynamicIndexList(mixedUpperBound)) && + failed(foldDynamicIndexList(mixedStep))) return failure(); rewriter.updateRootInPlace(op, [&]() { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2317,9 +2317,9 @@ SmallVector mixedStrides(insertSliceOp.getMixedStrides()); // No constant operands were folded, just return; - if (failed(foldDynamicIndexList(rewriter, mixedOffsets)) && - failed(foldDynamicIndexList(rewriter, mixedSizes)) && - failed(foldDynamicIndexList(rewriter, mixedStrides))) + if (failed(foldDynamicIndexList(mixedOffsets)) && + failed(foldDynamicIndexList(mixedSizes)) && + failed(foldDynamicIndexList(mixedStrides))) return failure(); // Create the new op in canonical form. 
// Patch section: mlir/lib/Dialect/Utils/StaticValueUtils.cpp
// Single hunk (@@ -242,4 +242,18 @@) that adds the definition of
// `foldDynamicIndexList(SmallVectorImpl &ofrs)` just before the closing
// `} // namespace mlir`.
// What the added function does, as written:
//   * walks `ofrs` by reference and skips every entry for which `ofr.is()`
//     is true;
//   * for the remaining entries calls
//     `matchPattern(ofr.get(), m_Constant(&attr))`; on a match the entry is
//     overwritten in place with `attr` and `valuesChanged` is set;
//   * returns `success(valuesChanged)`, i.e. failure when nothing was replaced
//     (including an empty list).
// NOTE(review): `m_Constant(&attr)` presumably binds whatever constant
// attribute defines the value, with no check that it is an index/integer
// attribute -- confirm that every caller only passes index-typed values.
// NOTE(review): the template arguments of `SmallVectorImpl`, `ofr.is` and
// `ofr.get` appear to be missing from this copy of the patch (likely
// Attribute/Value related, but verify against the upstream change), and
// intra-hunk line breaks are collapsed.
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -242,4 +242,18 @@ return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant); } +LogicalResult foldDynamicIndexList(SmallVectorImpl &ofrs) { + bool valuesChanged = false; + for (OpFoldResult &ofr : ofrs) { + if (ofr.is()) + continue; + Attribute attr; + if (matchPattern(ofr.get(), m_Constant(&attr))) { + ofr = attr; + valuesChanged = true; + } + } + return success(valuesChanged); +} + } // namespace mlir