diff --git a/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h b/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h --- a/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h +++ b/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H #define MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/ViewLikeInterface.h" @@ -22,7 +23,8 @@ /// - Combined offsets = producer_offsets * consumer_strides + consumer_offsets /// - Combined sizes = consumer_sizes /// - Combined strides = producer_strides * consumer_strides -// TODO: unify this API with resolveSourceIndicesOffsetsAndStrides or deprecate. +// TODO: unify this API with resolveIndicesIntoOpWithOffsetsAndStrides or +// deprecate. LogicalResult mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> producerOffsets, @@ -38,7 +40,8 @@ /// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use /// when combining a `producer` slice op **into** a `consumer` slice op. -// TODO: unify this API with resolveSourceIndicesOffsetsAndStrides or deprecate. +// TODO: unify this API with resolveIndicesIntoOpWithOffsetsAndStrides or +// deprecate. LogicalResult mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc, OffsetSizeAndStrideOpInterface producer, @@ -48,8 +51,8 @@ SmallVector<OpFoldResult> &combinedSizes, SmallVector<OpFoldResult> &combinedStrides); -/// Given the 'indicesVals' of a load/store operation operating on an op with -/// offsets and strides, return the combined indices. +/// Given the 'consumerIndices' of a load/store operation operating on an op +/// with offsets and strides, return the combined indices. 
/// /// For example, using `memref.load` and `memref.subview` as an illustration: /// @@ -64,13 +67,26 @@ /// /// ``` /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : -/// memref<12x42xf32> /// ``` -void resolveSourceIndicesOffsetsAndStrides( - RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> mixedOffsets, - ArrayRef<OpFoldResult> mixedStrides, - const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals, - SmallVectorImpl<Value> &sourceIndices); +void resolveIndicesIntoOpWithOffsetsAndStrides( + RewriterBase &rewriter, Location loc, + ArrayRef<OpFoldResult> mixedSourceOffsets, + ArrayRef<OpFoldResult> mixedSourceStrides, + const llvm::SmallBitVector &rankReducedDims, + ArrayRef<OpFoldResult> consumerIndices, + SmallVectorImpl<Value> &resolvedIndices); + +inline void resolveIndicesIntoOpWithOffsetsAndStrides( + RewriterBase &rewriter, Location loc, + ArrayRef<OpFoldResult> mixedSourceOffsets, + ArrayRef<OpFoldResult> mixedSourceStrides, + const llvm::SmallBitVector &rankReducedDims, ValueRange consumerIndices, + SmallVectorImpl<Value> &resolvedIndices) { + return resolveIndicesIntoOpWithOffsetsAndStrides( + rewriter, loc, mixedSourceOffsets, mixedSourceStrides, rankReducedDims, + getAsOpFoldResult(consumerIndices), resolvedIndices); +} } // namespace mlir diff --git a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp @@ -77,32 +77,34 @@ combinedOffsets, combinedSizes, combinedStrides); } -void mlir::resolveSourceIndicesOffsetsAndStrides( - RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> mixedOffsets, - ArrayRef<OpFoldResult> mixedStrides, - const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals, - SmallVectorImpl<Value> &sourceIndices) { +void mlir::resolveIndicesIntoOpWithOffsetsAndStrides( + RewriterBase &rewriter, Location loc, + ArrayRef<OpFoldResult> mixedSourceOffsets, + ArrayRef<OpFoldResult> mixedSourceStrides, + const llvm::SmallBitVector &rankReducedDims, + 
ArrayRef<OpFoldResult> consumerIndices, + SmallVectorImpl<Value> &resolvedIndices) { OpFoldResult zero = rewriter.getIndexAttr(0); // For each dimension that is rank-reduced, add a zero to the indices. int64_t indicesDim = 0; SmallVector<OpFoldResult> indices; - for (auto dim : llvm::seq<unsigned>(0, mixedOffsets.size())) { + for (auto dim : llvm::seq<unsigned>(0, mixedSourceOffsets.size())) { OpFoldResult ofr = - (rankReducedDims.test(dim)) ? zero : indicesVals[indicesDim++]; + (rankReducedDims.test(dim)) ? zero : consumerIndices[indicesDim++]; indices.push_back(ofr); } - sourceIndices.resize(indices.size()); - sourceIndices.clear(); + resolvedIndices.resize(indices.size()); + resolvedIndices.clear(); for (auto [offset, index, stride] : - llvm::zip_equal(mixedOffsets, indices, mixedStrides)) { + llvm::zip_equal(mixedSourceOffsets, indices, mixedSourceStrides)) { AffineExpr off, idx, str; bindSymbols(rewriter.getContext(), off, idx, str); OpFoldResult ofr = makeComposedFoldedAffineApply( rewriter, loc, AffineMap::get(0, 3, off + idx * str), {offset, index, stride}); - sourceIndices.push_back( + resolvedIndices.push_back( getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); } } diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -254,11 +254,16 @@ return failure(); int64_t srcRank = srcSubView.getSourceType().getRank(); - // TODO: Only stride 1 is supported. - for (auto s : {subView.getMixedStrides(), srcSubView.getMixedStrides()}) - if (!llvm::all_of( - s, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) - return failure(); + + // TODO: relax unit stride assumption. 
+ if (!subView.hasUnitStride()) { + return rewriter.notifyMatchFailure(subView, + "requires unit strides"); + } + if (!srcSubView.hasUnitStride()) { + return rewriter.notifyMatchFailure(srcSubView, + "requires unit strides"); + } // Get original offsets and sizes. SmallVector<OpFoldResult> offsets = subView.getMixedOffsets(); @@ -372,7 +382,7 @@ indices.assign(expandedIndices.begin(), expandedIndices.end()); } SmallVector<Value> sourceIndices; - resolveSourceIndicesOffsetsAndStrides( + resolveIndicesIntoOpWithOffsetsAndStrides( rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(), subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices, sourceIndices); @@ -492,7 +502,7 @@ indices.assign(expandedIndices.begin(), expandedIndices.end()); } SmallVector<Value> sourceIndices; - resolveSourceIndicesOffsetsAndStrides( + resolveIndicesIntoOpWithOffsetsAndStrides( rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(), subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices, sourceIndices); diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" @@ -21,6 +22,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" +#include <type_traits> namespace mlir { namespace tensor { @@ -98,7 +100,7 @@ SmallVector<Value> indices(readOp.getIndices().begin(), readOp.getIndices().end()); SmallVector<Value> sourceIndices; - resolveSourceIndicesOffsetsAndStrides( + resolveIndicesIntoOpWithOffsetsAndStrides( rewriter, readOp.getLoc(), 
extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(), indices, sourceIndices); @@ -130,7 +132,7 @@ SmallVector<Value> indices(writeOp.getIndices().begin(), writeOp.getIndices().end()); SmallVector<Value> sourceIndices; - resolveSourceIndicesOffsetsAndStrides( + resolveIndicesIntoOpWithOffsetsAndStrides( rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices, sourceIndices); @@ -145,9 +147,65 @@ return success(); } +template <typename OpTy> +struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> { + using OpRewritePattern<OpTy>::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy insertSliceOp, + PatternRewriter &rewriter) const override { + auto sourceInsertSliceOp = + insertSliceOp.getSource() + .template getDefiningOp<tensor::InsertSliceOp>(); + if (!sourceInsertSliceOp) + return failure(); + + // TODO: relax unit stride assumption where possible. + if (!insertSliceOp.hasUnitStride()) { + return rewriter.notifyMatchFailure(insertSliceOp, + "requires unit strides"); + } + if (!sourceInsertSliceOp.hasUnitStride()) { + return rewriter.notifyMatchFailure(sourceInsertSliceOp, + "requires unit strides"); + } + if (insertSliceOp.getMixedSizes() != sourceInsertSliceOp.getMixedSizes()) { + return rewriter.notifyMatchFailure( + sourceInsertSliceOp, + "requires matching sizes to fold, otherwise a copy is needed"); + } + + // If we are inside an InParallel region, set the insertion point outside. + if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) { + rewriter.setInsertionPoint( + insertSliceOp->template getParentOfType<scf::InParallelOp>()); + } + + // Resolve offsets according to source offsets and strides. + SmallVector<Value> resolvedOffsets; + resolveIndicesIntoOpWithOffsetsAndStrides( + rewriter, insertSliceOp.getLoc(), sourceInsertSliceOp.getMixedOffsets(), + sourceInsertSliceOp.getMixedStrides(), + sourceInsertSliceOp.getDroppedDims(), insertSliceOp.getMixedOffsets(), + resolvedOffsets); + + // Reset the insertion point. 
+ rewriter.setInsertionPoint(insertSliceOp); + // Replace original op. + rewriter.replaceOpWithNewOp<OpTy>( + insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(), + getAsOpFoldResult(resolvedOffsets), sourceInsertSliceOp.getMixedSizes(), + sourceInsertSliceOp.getMixedStrides()); + + return success(); + } +}; + void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) { - patterns.add<TransferReadOfExtractSliceOpFolder, - InsertSliceOfTransferWriteOpFolder>(patterns.getContext()); + patterns.add<TransferReadOfExtractSliceOpFolder, + InsertSliceOfTransferWriteOpFolder, + InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>, + InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>( + patterns.getContext()); } //===----------------------------------------------------------------------===// // Pass registration