diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3065,6 +3065,24 @@
     return getViewSource();
   }
 
+  // Fold subview(subview(x)), where both subviews have the same sizes, the
+  // second subview has all-zero offsets and all-one strides, and its result
+  // type equals its source type. (I.e., the second subview is a no-op.)
+  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
+    auto srcSizes = srcSubview.getMixedSizes();
+    auto sizes = getMixedSizes();
+    auto offsets = getMixedOffsets();
+    bool allOffsetsZero = llvm::all_of(
+        offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
+    auto strides = getMixedStrides();
+    bool allStridesOne = llvm::all_of(
+        strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
+    bool allSizesSame = llvm::equal(sizes, srcSizes);
+    if (allOffsetsZero && allStridesOne && allSizesSame &&
+        resultShapedType == sourceShapedType)
+      return getViewSource();
+  }
+
   return {};
 }
 
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -875,3 +875,22 @@
       : memref<1x?xf32, 3> into memref<?xf32, 3>
   return %1 : memref<?xf32, 3>
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_trivial_subviews(
+//  CHECK-SAME:     %[[m:.*]]: memref<?xf32, strided<[?], offset: ?>>
+//       CHECK:   %[[subview:.*]] = memref.subview %[[m]][5]
+//       CHECK:   return %[[subview]]
+func.func @fold_trivial_subviews(%m: memref<?xf32, strided<[?], offset: ?>>,
+                                 %sz: index)
+    -> memref<?xf32, strided<[?], offset: ?>>
+{
+  %0 = memref.subview %m[5] [%sz] [1]
+      : memref<?xf32, strided<[?], offset: ?>>
+        to memref<?xf32, strided<[?], offset: ?>>
+  %1 = memref.subview %0[0] [%sz] [1]
+      : memref<?xf32, strided<[?], offset: ?>>
+        to memref<?xf32, strided<[?], offset: ?>>
+  return %1 : memref<?xf32, strided<[?], offset: ?>>
+}