diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -294,6 +294,66 @@
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Folds subview(subview(x)) to a single subview(x).
+///
+/// Only unit strides are supported. In that case, every offset of the folded
+/// subview is the sum of the two nested offsets and its sizes are those of the
+/// outer subview. Dims dropped by the inner subview keep its offset and size.
+class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
+public:
+  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::SubViewOp subView,
+                                PatternRewriter &rewriter) const override {
+    Location loc = subView.getLoc();
+    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
+    if (!srcSubView)
+      return failure();
+    int64_t srcRank = srcSubView.getSourceType().getRank();
+
+    // TODO: Only stride 1 is supported.
+    auto hasUnitStrides = [](memref::SubViewOp op) {
+      return llvm::all_of(op.getMixedStrides(), [](OpFoldResult ofr) {
+        return isConstantIntValue(ofr, 1);
+      });
+    };
+    if (!hasUnitStrides(subView) || !hasUnitStrides(srcSubView))
+      return rewriter.notifyMatchFailure(subView, "expected unit strides");
+
+    // Get original offsets and sizes.
+    SmallVector<OpFoldResult> offsets = subView.getMixedOffsets();
+    SmallVector<OpFoldResult> srcOffsets = srcSubView.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = subView.getMixedSizes();
+    SmallVector<OpFoldResult> srcSizes = srcSubView.getMixedSizes();
+
+    // Compute new offsets (srcOffset + offset) and sizes. The loop walks the
+    // dims of srcSubView's source; `dim` indexes the dims of subView's source.
+    llvm::SmallBitVector srcReducedDims = srcSubView.getDroppedDims();
+    AffineExpr sym0, sym1;
+    bindSymbols(subView.getContext(), sym0, sym1);
+    SmallVector<OpFoldResult> newOffsets, newSizes;
+    int64_t dim = 0;
+    for (int64_t srcDim = 0; srcDim < srcRank; ++srcDim) {
+      if (srcReducedDims[srcDim]) {
+        // Dim is reduced in srcSubView: keep its offset and (unit) size.
+        assert(isConstantIntValue(srcSizes[srcDim], 1) && "expected size 1");
+        newOffsets.push_back(srcOffsets[srcDim]);
+        newSizes.push_back(srcSizes[srcDim]);
+        continue;
+      }
+      newOffsets.push_back(makeComposedFoldedAffineApply(
+          rewriter, loc, sym0 + sym1, {srcOffsets[srcDim], offsets[dim]}));
+      newSizes.push_back(sizes[dim]);
+      ++dim;
+    }
+
+    // Replace original op.
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
+        subView, subView.getType(), srcSubView.getSource(), newOffsets,
+        newSizes, srcSubView.getMixedStrides());
+    return success();
+  }
+};
 } // namespace
 
 static SmallVector<Value>
@@ -533,8 +593,8 @@
                LoadOpOfCollapseShapeOpFolder<AffineLoadOp>,
                LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
                StoreOpOfCollapseShapeOpFolder<AffineStoreOp>,
-               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>>(
-      patterns.getContext());
+               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
+               SubViewOfSubViewFolder>(patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -fold-memref-alias-ops -split-input-file %s -o - | FileCheck %s
+// RUN: mlir-opt -fold-memref-alias-ops -split-input-file %s | FileCheck %s
 
 func.func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
   %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
@@ -465,3 +465,40 @@
 // CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index
 // CHECK-NEXT: affine.for %{{.*}} = 0 to 3 {
 // CHECK-NEXT: affine.load %[[ARG0]][%[[ZERO]]] : memref<1xf32>
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 2)>
+// CHECK-LABEL: func @subview_of_subview(
+//  CHECK-SAME:     %[[m:.*]]: memref<1x1024xf32, 3>, %[[pos:.*]]: index
+//       CHECK:   %[[add:.*]] = affine.apply #[[$map]]()[%[[pos]]]
+//       CHECK:   memref.subview %[[m]][4, %[[add]]] [1, 1] [1, 1] : memref<1x1024xf32, 3> to memref<f32, strided<[], offset: ?>, 3>
+func.func @subview_of_subview(%m: memref<1x1024xf32, 3>, %pos: index)
+    -> memref<f32, strided<[], offset: ?>, 3>
+{
+  %0 = memref.subview %m[3, %pos] [1, 2] [1, 1]
+      : memref<1x1024xf32, 3>
+        to memref<1x2xf32, strided<[1024, 1], offset: ?>, 3>
+  %1 = memref.subview %0[1, 2] [1, 1] [1, 1]
+      : memref<1x2xf32, strided<[1024, 1], offset: ?>, 3>
+        to memref<f32, strided<[], offset: ?>, 3>
+  return %1 : memref<f32, strided<[], offset: ?>, 3>
+}
+
+// -----
+
+// CHECK-LABEL: func @subview_of_subview_rank_reducing(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>
+//       CHECK:   memref.subview %[[m]][3, 7, 8] [1, 1, 1] [1, 1, 1] : memref<?x?x?xf32> to memref<f32, strided<[], offset: ?>>
+func.func @subview_of_subview_rank_reducing(%m: memref<?x?x?xf32>,
+                                            %sz: index, %pos: index)
+    -> memref<f32, strided<[], offset: ?>>
+{
+  %0 = memref.subview %m[3, 1, 8] [1, %sz, 1] [1, 1, 1]
+      : memref<?x?x?xf32>
+        to memref<?xf32, strided<[?], offset: ?>>
+  %1 = memref.subview %0[6] [1] [1]
+      : memref<?xf32, strided<[?], offset: ?>>
+        to memref<f32, strided<[], offset: ?>>
+  return %1 : memref<f32, strided<[], offset: ?>>
+}